import math
from operator import attrgetter, itemgetter

from django.db import migrations, models
from django.db.models import Count, OuterRef, Subquery
from django.db.models.functions import Coalesce
from django.utils import timezone


def tie_ranker(iterable, key=attrgetter('points')):
    """Yield a 1-based rank for every item of a pre-sorted iterable, averaging ties.

    Adjacent items whose ``key`` values compare equal form a tie group. Every member
    of a group receives the mean of the positions the group occupies; e.g. items at
    positions 2, 3, 4 and 5 all get ``3.5``. Only adjacent items are compared, so the
    input must already be ordered best-first. One float is yielded per item, in input
    order, so the output can be zipped with the input.

    :param iterable: items ordered best-first.
    :param key: callable returning the value that decides ties.
    :return: generator of float ranks.
    """
    first_position = 1  # 1-based position of the first item in the current tie group
    group_size = 0      # number of items collected in the current tie group
    group_key = None
    for item in iterable:
        current_key = key(item)
        # The very first item never closes a group. Testing group_size instead of
        # comparing against a None marker keeps keys that are themselves None correct.
        if group_size and current_key != group_key:
            shared_rank = first_position + (group_size - 1) / 2.0
            for _ in range(group_size):
                yield shared_rank
            first_position += group_size
            group_size = 0
        group_size += 1
        group_key = current_key
    # Flush the final group (no-op for an empty iterable).
    shared_rank = first_position + (group_size - 1) / 2.0
    for _ in range(group_size):
        yield shared_rank


def rational_approximation(t):
    """Evaluate the Abramowitz and Stegun rational approximation, formula 26.2.23.

    Callers pass ``t = sqrt(-2 * ln(q))`` for a tail probability ``0 < q <= 0.5``.
    The result approximates the corresponding standard-normal quantile; the absolute
    error should be less than 4.5e-4.
    """
    def horner(coefficients):
        # Polynomial in t, highest-degree coefficient first.
        value = coefficients[0]
        for coefficient in coefficients[1:]:
            value = value * t + coefficient
        return value

    numerator = horner((0.010328, 0.802853, 2.515517))
    denominator = horner((0.001308, 0.189269, 1.432788, 1.0))
    return t - numerator / denominator


def normal_CDF_inverse(p):
    """Approximate the inverse of the standard normal CDF (the quantile function).

    Uses the Abramowitz and Stegun 26.2.23 approximation via ``rational_approximation``,
    so the absolute error should be below 4.5e-4.

    :param p: probability, strictly between 0 and 1.
    :raises AssertionError: if ``p`` is outside (0, 1) and assertions are enabled.
    """
    assert 0.0 < p < 1

    # The rational approximation works on a tail probability of at most 0.5. The upper
    # half is therefore evaluated through 1 - p, and the lower half is mirrored using
    # the symmetry of the normal distribution.
    if p >= 0.5:
        # F^-1(p) = G^-1(1-p)
        return rational_approximation(math.sqrt(-2.0 * math.log(1.0 - p)))
    # F^-1(p) = - G^-1(p)
    return -rational_approximation(math.sqrt(-2.0 * math.log(p)))


def WP(RA, RB, VA, VB):
    """Return Phi((RB - RA) / sqrt(VA**2 + VB**2)), Phi being the standard normal CDF.

    This is the probability that the player rated ``RB`` (volatility ``VB``)
    outperforms the player rated ``RA`` (volatility ``VA``). Summing it over all
    opponents gives a player's expected rank, where 1 is first place. Equal ratings
    give exactly 0.5.
    """
    spread = math.sqrt(2 * (VA * VA + VB * VB))
    scaled_gap = (RB - RA) / spread
    return (math.erf(scaled_gap) + 1) / 2.0


def recalculate_ratings(old_rating, old_volatility, actual_rank, times_rated, is_disqualified):
    """Compute TopCoder-style ratings and volatilities after a single contest.

    All arguments are parallel lists with one entry per rated participant.

    :param old_rating: ratings before the contest.
    :param old_volatility: volatilities before the contest.
    :param actual_rank: placement in this contest; 1 is first place, N is last place.
        Tied participants share the average of the places they span (if places 2, 3,
        4 and 5 tie, all four use 3.5).
    :param times_rated: how many times each participant has been rated before.
    :param is_disqualified: truthy for disqualified participants.
    :return: ``(new_rating, new_volatility)``. With two or more participants both are
        lists of ints. With zero or one participant they are unmodified copies of the
        inputs.
    """
    N = len(old_rating)
    new_rating = old_rating[:]
    new_volatility = old_volatility[:]
    if N <= 1:
        # Nothing to compare against; the sample variance below also needs N - 1 > 0.
        return new_rating, new_volatility

    # Competition factor: mean squared volatility plus the sample variance of ratings.
    ave_rating = float(sum(old_rating)) / N
    sum1 = sum(i * i for i in old_volatility) / N
    sum2 = sum((i - ave_rating) ** 2 for i in old_rating) / (N - 1)
    CF = math.sqrt(sum1 + sum2)

    for i in range(N):
        # Expected rank: 0.5 plus the probability that each participant (including i
        # itself, which contributes 0.5) outperforms i. The explicit left-to-right
        # accumulation is kept on purpose: sum() uses compensated summation for floats
        # on Python 3.12+ and could shift the rounded results.
        ERank = 0.5
        for j in range(N):
            ERank += WP(old_rating[i], old_rating[j], old_volatility[i], old_volatility[j])

        # Expected and actual performance on the standard-normal scale.
        EPerf = -normal_CDF_inverse((ERank - 0.5) / N)
        APerf = -normal_CDF_inverse((actual_rank[i] - 0.5) / N)
        PerfAs = old_rating[i] + CF * (APerf - EPerf)

        # Weight of this contest against the old rating. It shrinks as times_rated
        # grows and is damped further for high ratings.
        Weight = 1.0 / (1 - (0.42 / (times_rated[i] + 1) + 0.18)) - 1.0
        if old_rating[i] > 2500:
            Weight *= 0.8
        elif old_rating[i] >= 2000:
            Weight *= 0.9

        # Largest rating change allowed in one contest; it also shrinks with experience.
        Cap = 150.0 + 1500.0 / (times_rated[i] + 2)

        new_rating[i] = (old_rating[i] + Weight * PerfAs) / (1.0 + Weight)

        if abs(old_rating[i] - new_rating[i]) > Cap:
            if old_rating[i] < new_rating[i]:
                new_rating[i] = old_rating[i] + Cap
            else:
                new_rating[i] = old_rating[i] - Cap

        if times_rated[i] == 0:
            new_volatility[i] = 385
        else:
            new_volatility[i] = math.sqrt(((new_rating[i] - old_rating[i]) ** 2) / Weight +
                                          (old_volatility[i] ** 2) / (Weight + 1))

        if is_disqualified[i]:
            # DQed users can manipulate TopCoder ratings to get higher volatility in order to increase their rating
            # later on, prohibit this by ensuring their volatility never increases in this situation
            new_volatility[i] = min(new_volatility[i], old_volatility[i])

    # Try to keep the sum of ratings constant.
    adjust = float(sum(old_rating) - sum(new_rating)) / N
    new_rating = [rating + adjust for rating in new_rating]
    # Inflate a little if we have to so people who placed first don't lose rating.
    best_rank = min(actual_rank)
    for i in range(N):
        if abs(actual_rank[i] - best_rank) <= 1e-3 and new_rating[i] < old_rating[i] + 1:
            new_rating[i] = old_rating[i] + 1
    return [int(round(r)) for r in new_rating], [int(round(v)) for v in new_volatility]


def tc_rate_contest(contest, Rating, Profile):
    """Rate a single contest with the TopCoder-style algorithm and persist the result.

    Selects the contest's eligible participations and gathers each user's most recent
    rating, volatility and number of existing Rating rows. These are fed to
    ``recalculate_ratings``, one ``Rating`` row is bulk-created per participant, and
    ``Profile.rating`` is refreshed for the profiles that took part.

    :param contest: the contest to rate.
    :param Rating: the ``judge.Rating`` model class (from the migration's app registry).
    :param Profile: the ``judge.Profile`` model class (from the migration's app registry).
    """
    # Rating rows of the user referenced by the outer participation row.
    rating_subquery = Rating.objects.filter(user=OuterRef('user'))
    # Newest contest end time first, so the first row is the user's most recent rating.
    # Contests therefore have to be rated in chronological order.
    rating_sorted = rating_subquery.order_by('-contest__end_time')
    # One row per participation. Order: not disqualified first, then score (descending),
    # cumtime, tiebreaker; tie_ranker below relies on this order. Annotations:
    #   submissions - number of related submissions (used by the rate_all filter)
    #   last_rating - most recent rating, 1200 if the user has none yet
    #   volatility  - most recent volatility, 535 if the user has none yet
    #   times       - number of Rating rows the user already has, 0 if none
    # Users listed in contest.rate_exclude are dropped. virtual=0 presumably keeps only
    # live (non-virtual) participations -- confirm against the participation model.
    users = contest.users.order_by('is_disqualified', '-score', 'cumtime', 'tiebreaker') \
        .annotate(submissions=Count('submission'),
                  last_rating=Coalesce(Subquery(rating_sorted.values('rating')[:1]), 1200),
                  volatility=Coalesce(Subquery(rating_sorted.values('volatility')[:1]), 535),
                  times=Coalesce(Subquery(rating_subquery.order_by().values('user_id')
                                          .annotate(count=Count('id')).values('count')), 0)) \
        .exclude(user_id__in=contest.rate_exclude.all()) \
        .filter(virtual=0).values('id', 'user_id', 'score', 'cumtime', 'tiebreaker', 'is_disqualified',
                                  'last_rating', 'volatility', 'times')
    if not contest.rate_all:
        # Only rate participants with at least one submission.
        users = users.filter(submissions__gt=0)
    # Optionally restrict to participants whose previous rating lies within
    # [rating_floor, rating_ceiling].
    if contest.rating_floor is not None:
        users = users.exclude(last_rating__lt=contest.rating_floor)
    if contest.rating_ceiling is not None:
        users = users.exclude(last_rating__gt=contest.rating_ceiling)

    # Build parallel per-participant lists, all in the query order above.
    users = list(users)
    participation_ids = list(map(itemgetter('id'), users))
    user_ids = list(map(itemgetter('user_id'), users))
    is_disqualified = list(map(itemgetter('is_disqualified'), users))
    # Equal (score, cumtime, tiebreaker) share an averaged, possibly fractional, rank.
    ranking = list(tie_ranker(users, key=itemgetter('score', 'cumtime', 'tiebreaker')))
    old_rating = list(map(itemgetter('last_rating'), users))
    old_volatility = list(map(itemgetter('volatility'), users))
    times_ranked = list(map(itemgetter('times'), users))
    rating, volatility = recalculate_ratings(old_rating, old_volatility, ranking, times_ranked, is_disqualified)

    now = timezone.now()
    # NOTE(review): `rank` receives tie_ranker's float ranks (e.g. 3.5); confirm how the
    # Rating.rank column stores fractional values.
    ratings = [Rating(user_id=i, contest=contest, rating=r, volatility=v, last_rated=now, participation_id=p, rank=z)
               for i, p, r, v, z in zip(user_ids, participation_ids, rating, volatility, ranking)]

    Rating.objects.bulk_create(ratings)

    # Set Profile.rating of every profile with a virtual=0 participation in this contest
    # to that profile's most recent Rating (by contest end time).
    Profile.objects.filter(contest_history__contest=contest, contest_history__virtual=0).update(
        rating=Subquery(Rating.objects.filter(user=OuterRef('id'))
                        .order_by('-contest__end_time').values('rating')[:1]))


# inspired by rate_all_view
def rate_tc(apps, schema_editor):
    """Reverse step of this migration: rebuild all ratings with the TopCoder-style algorithm.

    Empties the Rating table and clears every ``Profile.rating``. It then re-rates every
    rated contest that has already ended, oldest first, because each contest reads the
    previous ratings produced by earlier contests.
    """
    Contest = apps.get_model('judge', 'Contest')
    Rating = apps.get_model('judge', 'Rating')
    Profile = apps.get_model('judge', 'Profile')

    # Quote the table name with the active backend's rules instead of hard-coding MySQL
    # backticks (identical SQL on MySQL).
    table = schema_editor.connection.ops.quote_name(Rating._meta.db_table)
    with schema_editor.connection.cursor() as cursor:
        cursor.execute('TRUNCATE TABLE %s' % table)
    Profile.objects.update(rating=None)
    for contest in Contest.objects.filter(is_rated=True, end_time__lte=timezone.now()).order_by('end_time'):
        tc_rate_contest(contest, Rating, Profile)


# inspired by rate_all_view
def rate_elo_mmr(apps, schema_editor):
    """Forward step of this migration: discard every stored rating.

    Empties the Rating table and clears ``Profile.rating`` for all profiles. Ratings are
    deliberately not recomputed here.
    """
    Rating = apps.get_model('judge', 'Rating')
    Profile = apps.get_model('judge', 'Profile')

    # Quote the table name with the active backend's rules instead of hard-coding MySQL
    # backticks (identical SQL on MySQL).
    table = schema_editor.connection.ops.quote_name(Rating._meta.db_table)
    with schema_editor.connection.cursor() as cursor:
        cursor.execute('TRUNCATE TABLE %s' % table)
    Profile.objects.update(rating=None)
    # Don't populate Rating


class Migration(migrations.Migration):
    """Replace ``Rating.volatility`` with ``mean`` and ``performance`` and reset all ratings.

    Forward: add ``mean`` and ``performance`` and remove ``volatility``. Then
    ``rate_elo_mmr`` empties the Rating table and clears every ``Profile.rating``;
    nothing is recomputed.

    Backward: Django undoes the operations in reverse order. ``volatility`` is restored
    and ``performance`` and ``mean`` are dropped before ``rate_tc`` recomputes every
    TopCoder-style rating.
    """

    dependencies = [
        ('judge', '0117_auto_20211209_0612'),
    ]

    operations = [
        # No-op forward. On reverse this runs last, once `volatility` exists again, and
        # rebuilds the TopCoder-style ratings.
        migrations.RunPython(migrations.RunPython.noop, rate_tc, atomic=True),
        # NOTE(review): these float columns are NOT NULL with no default, and existing
        # Rating rows are only removed by the final operation. Confirm the target database
        # accepts adding them to a non-empty table.
        migrations.AddField(
            model_name='rating',
            name='mean',
            field=models.FloatField(verbose_name='raw rating'),
        ),
        migrations.AddField(
            model_name='rating',
            name='performance',
            field=models.FloatField(verbose_name='contest performance'),
        ),
        migrations.RemoveField(
            model_name='rating',
            name='volatility',
            field=models.IntegerField(verbose_name='volatility'),
        ),
        # Forward-only reset; reversing it is a no-op.
        migrations.RunPython(rate_elo_mmr, migrations.RunPython.noop, atomic=True),
    ]