Move dirty logic to admin

This commit is contained in:
cuom1999 2023-10-12 08:56:53 -05:00
parent e1a38d42c3
commit 11bc57f2b1
4 changed files with 38 additions and 31 deletions

View file

@ -33,6 +33,7 @@ from judge.widgets import (
CheckboxSelectMultipleWithSelectAll,
HeavyPreviewAdminPageDownWidget,
)
from judge.utils.problems import user_editable_ids, user_tester_ids
MEMORY_UNITS = (("KB", "KB"), ("MB", "MB"))
@ -359,12 +360,31 @@ class ProblemAdmin(CompareVersionAdmin):
self._rescore(request, obj.id)
def save_related(self, request, form, formsets, change):
editors = set()
testers = set()
if "curators" in form.changed_data or "authors" in form.changed_data:
editors = set(form.instance.editor_ids)
if "testers" in form.changed_data:
testers = set(form.instance.tester_ids)
super().save_related(request, form, formsets, change)
# Only rescored if we did not already do so in `save_model`
obj = form.instance
obj.curators.add(request.profile)
obj.is_organization_private = obj.organizations.count() > 0
obj.save()
if "curators" in form.changed_data or "authors" in form.changed_data:
del obj.editor_ids
editors = editors.union(set(obj.editor_ids))
if "testers" in form.changed_data:
del obj.tester_ids
testers = testers.union(set(obj.tester_ids))
for editor in editors:
user_editable_ids.dirty(editor)
for tester in testers:
user_tester_ids.dirty(tester)
# Create notification
if "is_public" in form.changed_data or "organizations" in form.changed_data:
users = set(obj.authors.all())

View file

@ -40,12 +40,10 @@ def cache_wrapper(prefix, timeout=None):
def _get(key):
if not l0_cache:
return cache.get(key)
print("GET", key, l0_cache.get(key))
return l0_cache.get(key) or cache.get(key)
def _set_l0(key, value):
if l0_cache:
print("SET", key, value)
l0_cache.set(key, value, 30)
def _set(key, value, timeout):

View file

@ -436,15 +436,23 @@ class Problem(models.Model, PageVotable, Bookmarkable):
@cached_property
def author_ids(self):
return self.authors.values_list("id", flat=True)
return Problem.authors.through.objects.filter(problem=self).values_list(
"profile_id", flat=True
)
@cached_property
def editor_ids(self):
return self.author_ids | self.curators.values_list("id", flat=True)
return self.author_ids.union(
Problem.curators.through.objects.filter(problem=self).values_list(
"profile_id", flat=True
)
)
@cached_property
def tester_ids(self):
return self.testers.values_list("id", flat=True)
return Problem.testers.through.objects.filter(problem=self).values_list(
"profile_id", flat=True
)
@cached_property
def usable_common_names(self):

View file

@ -10,8 +10,6 @@ from django.db.models import Case, Count, ExpressionWrapper, F, Max, Q, When
from django.db.models.fields import FloatField
from django.utils import timezone
from django.utils.translation import gettext as _, gettext_noop
from django.db.models.signals import pre_save
from django.dispatch import receiver
from judge.models import Problem, Submission
from judge.ml.collab_filter import CollabFilter
@ -29,9 +27,9 @@ __all__ = [
@cache_wrapper(prefix="user_tester")
def user_tester_ids(profile):
return set(
Problem.testers.through.objects.filter(profile=profile).values_list(
"problem_id", flat=True
)
Problem.testers.through.objects.filter(profile=profile)
.values_list("problem_id", flat=True)
.distinct()
)
@ -41,7 +39,9 @@ def user_editable_ids(profile):
(
Problem.objects.filter(authors=profile)
| Problem.objects.filter(curators=profile)
).values_list("id", flat=True)
)
.values_list("id", flat=True)
.distinct()
)
return result
@ -249,22 +249,3 @@ def finished_submission(sub):
keys += ["contest_complete:%d" % participation.id]
keys += ["contest_attempted:%d" % participation.id]
cache.delete_many(keys)
@receiver([pre_save], sender=Problem)
def on_problem_save(sender, instance, **kwargs):
if instance.id is None:
return
prev_editors = list(sender.objects.get(id=instance.id).editor_ids)
new_editors = list(instance.editor_ids)
if prev_editors != new_editors:
all_editors = set(prev_editors + new_editors)
for profile_id in all_editors:
user_editable_ids.dirty(profile_id)
prev_testers = list(sender.objects.get(id=instance.id).tester_ids)
new_testers = list(instance.tester_ids)
if prev_testers != new_testers:
all_testers = set(prev_testers + new_testers)
for profile_id in all_testers:
user_tester_ids.dirty(profile_id)