from datetime import timedelta

from django.core.exceptions import ValidationError
from django.db import connection
from django.template.defaultfilters import floatformat
from django.urls import reverse
from django.utils.html import format_html
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy

from judge.contest_format.default import DefaultContestFormat
from judge.contest_format.registry import register_contest_format
from judge.timezone import from_database_time
from judge.utils.timedelta import nice_repr


@register_contest_format("icpc")
class ICPCContestFormat(DefaultContestFormat):
    """ICPC-style contest format.

    For each participation this format stores the total score, a cumulative
    time (solve times of all problems with non-zero points plus penalty), and
    the latest such solve time as the tiebreaker.

    Config keys:
        penalty: Number of penalty minutes added for each counted submission
            made before the earliest best-scoring one (submissions with no
            result, IE or CE are not counted). Must be a non-negative int.
            Defaults to 20.
    """

    name = gettext_lazy("ICPC")
    config_defaults = {"penalty": 20}
    config_validators = {"penalty": lambda x: x >= 0}

    @classmethod
    def validate(cls, config):
        """Validate a contest's format config.

        ``None`` means "use the defaults". Otherwise ``config`` must be a dict
        whose keys all appear in ``config_defaults``, whose values have the
        same type as the corresponding default, and which pass the matching
        ``config_validators`` predicate.

        Raises:
            ValidationError: if ``config`` is not ``None`` or a dict, has an
                unknown key, a value of the wrong type, or an invalid value.
        """
        if config is None:
            return

        if not isinstance(config, dict):
            raise ValidationError(
                "ICPC-styled contest expects no config or dict as config"
            )

        for key, value in config.items():
            if key not in cls.config_defaults:
                raise ValidationError('unknown config key "%s"' % key)
            expected_type = type(cls.config_defaults[key])
            # bool subclasses int, so isinstance(True, int) holds; without the
            # explicit bool check {"penalty": True} would be accepted and act
            # as a 1-minute penalty.
            if not isinstance(value, expected_type) or (
                isinstance(value, bool) and expected_type is not bool
            ):
                raise ValidationError('invalid type for config key "%s"' % key)
            if not cls.config_validators[key](value):
                raise ValidationError(
                    'invalid value "%s" for config key "%s"' % (value, key)
                )

    def __init__(self, contest, config):
        """Store ``contest`` and the effective config (defaults overlaid by ``config``)."""
        self.config = self.config_defaults.copy()
        self.config.update(config or {})
        self.contest = contest

    def update_participation(self, participation):
        """Recompute and save the score, cumtime, tiebreaker and per-problem data.

        For every contest problem the participation submitted to, the query
        returns the best points and the earliest submission date achieving
        those points. ``participation.format_data`` is keyed by contest
        problem id (as a string) with ``time`` (seconds from the
        participation start), ``points`` and ``penalty`` (number of counted
        attempts).
        """
        cumtime = 0
        last = 0
        penalty = 0
        score = 0
        format_data = {}

        with connection.cursor() as cursor:
            # The correlated subquery picks the earliest submission date among
            # this participation's submissions that reached the maximum points
            # for the problem.
            cursor.execute(
                """
                SELECT MAX(cs.points) as `points`, (
                    SELECT MIN(csub.date)
                        FROM judge_contestsubmission ccs LEFT OUTER JOIN
                             judge_submission csub ON (csub.id = ccs.submission_id)
                        WHERE ccs.problem_id = cp.id AND ccs.participation_id = %s AND ccs.points = MAX(cs.points)
                ) AS `time`, cp.id AS `prob`
                FROM judge_contestproblem cp INNER JOIN
                     judge_contestsubmission cs ON (cs.problem_id = cp.id AND cs.participation_id = %s) LEFT OUTER JOIN
                     judge_submission sub ON (sub.id = cs.submission_id)
                GROUP BY cp.id
            """,
                (participation.id, participation.id),
            )

            for points, time, prob in cursor.fetchall():
                time = from_database_time(time)
                dt = (time - participation.start).total_seconds()

                # Compute penalty
                if self.config["penalty"]:
                    # An IE can have a submission result of `None`
                    subs = (
                        participation.submissions.exclude(
                            submission__result__isnull=True
                        )
                        .exclude(submission__result__in=["IE", "CE"])
                        .filter(problem_id=prob)
                    )
                    if points:
                        # Attempts up to and including the best-scoring one,
                        # minus that one itself.
                        prev = subs.filter(submission__date__lte=time).count() - 1
                        penalty += prev * self.config["penalty"] * 60
                    else:
                        # We should always display the penalty, even if the user has a score of 0
                        prev = subs.count()
                else:
                    prev = 0

                if points:
                    cumtime += dt
                    last = max(last, dt)

                format_data[str(prob)] = {"time": dt, "points": points, "penalty": prev}
                score += points

        participation.cumtime = max(0, cumtime + penalty)
        participation.score = score
        participation.tiebreaker = last  # field is sorted from least to greatest
        participation.format_data = format_data
        participation.save()

    def display_user_problem(self, participation, contest_problem):
        """Render one scoreboard ``<td>`` for ``participation`` on ``contest_problem``.

        Returns an empty cell when no data was recorded for the problem.
        Otherwise the cell shows the points, the number of penalised attempts
        (in red, only when non-zero) and the solve time, wrapped in a link to
        the participant's submissions for that problem.
        """
        format_data = (participation.format_data or {}).get(str(contest_problem.id))
        if not format_data:
            return mark_safe("<td></td>")

        # .get(): tolerate per-problem data that lacks a "penalty" entry
        # (NOTE(review): e.g. data written before a recompute — confirm whether
        # that can happen in practice).
        penalty_count = format_data.get("penalty")
        penalty = (
            format_html(
                '<small style="color:red"> ({penalty})</small>',
                penalty=floatformat(penalty_count),
            )
            if penalty_count
            else ""
        )
        state_prefix = (
            "pretest-"
            if self.contest.run_pretests_only and contest_problem.is_pretested
            else ""
        )
        return format_html(
            '<td class="{state}"><a data-featherlight="{url}" href="#">{points}{penalty}<div class="solving-time">{time}</div></a></td>',
            state=state_prefix
            + self.best_solution_state(format_data["points"], contest_problem.points),
            url=reverse(
                "contest_user_submissions_ajax",
                args=[
                    self.contest.key,
                    participation.user.user.username,
                    contest_problem.problem.code,
                ],
            ),
            points=floatformat(format_data["points"]),
            penalty=penalty,
            time=nice_repr(timedelta(seconds=format_data["time"]), "noday"),
        )

    def get_contest_problem_label_script(self):
        """Return a Lua function that labels problems A, B, ..., Z, AA, AB, ...

        The function adds 1 to its argument before converting, so it appears
        to expect a 0-based problem index.
        """
        return """
            function(n)
                n = n + 1
                ret = ""
                while n > 0 do
                    ret = string.char((n - 1) % 26 + 65) .. ret
                    n = math.floor((n - 1) / 26)
                end
                return ret
            end
        """