NDOJ/judge/bridge/judge_handler.py
import hmac
import json
import logging
import threading
import time
import os
from collections import deque, namedtuple
from operator import itemgetter
from django import db
from django.conf import settings
from django.utils import timezone
from django.db.models import F
from django.core.cache import cache
from judge import event_poster as event
from judge.bridge.base_handler import ZlibPacketHandler, proxy_list
from judge.caching import finished_submission
from judge.models import (
Judge,
Language,
LanguageLimit,
Problem,
RuntimeVersion,
Submission,
SubmissionTestCase,
)
logger = logging.getLogger("judge.bridge")
json_log = logging.getLogger("judge.json.bridge")
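# Throttle live submission update events: at most UPDATE_RATE_LIMIT posts per
# UPDATE_RATE_TIME seconds for any single submission (see on_test_case).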
UPDATE_RATE_LIMIT = 5
UPDATE_RATE_TIME = 0.5
SubmissionData = namedtuple(
"SubmissionData",
"time memory short_circuit pretests_only contest_no attempt_no user_id",
)
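# Django may be holding a dead database connection between packets; close it if it
# is unusable or obsolete so the next ORM query opens a fresh one.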
def _ensure_connection():
db.connection.close_if_unusable_or_obsolete()
class JudgeHandler(ZlibPacketHandler):
proxies = proxy_list(settings.BRIDGED_JUDGE_PROXIES or [])
def __init__(self, request, client_address, server, judges):
super().__init__(request, client_address, server)
self.judges = judges
self.handlers = {
"grading-begin": self.on_grading_begin,
"grading-end": self.on_grading_end,
"compile-error": self.on_compile_error,
"compile-message": self.on_compile_message,
"batch-begin": self.on_batch_begin,
"batch-end": self.on_batch_end,
"test-case-status": self.on_test_case,
"internal-error": self.on_internal_error,
"submission-terminated": self.on_submission_terminated,
"submission-acknowledged": self.on_submission_acknowledged,
"ping-response": self.on_ping_response,
"supported-problems": self.on_supported_problems,
"handshake": self.on_handshake,
}
self._working = False
self._no_response_job = None
self._problems = []
self.executors = {}
self.problems = {}
self.latency = None
self.time_delta = None
self.load = 1e100
self.name = None
self.batch_id = None
self.in_batch = False
self._stop_ping = threading.Event()
self._ping_average = deque(maxlen=6) # 1 minute average, just like load
self._time_delta = deque(maxlen=6)
# each value is (updates, last reset)
self.update_counter = {}
self.judge = None
self.judge_address = None
self._submission_cache_id = None
self._submission_cache = {}
def on_connect(self):
self.timeout = 15
logger.info("Judge connected from: %s", self.client_address)
json_log.info(self._make_json_log(action="connect"))
def on_disconnect(self):
self._stop_ping.set()
if self._working:
logger.error(
"Judge %s disconnected while handling submission %s",
self.name,
self._working,
)
self.judges.remove(self)
if self.name is not None:
self._disconnected()
logger.info(
"Judge disconnected from: %s with name %s", self.client_address, self.name
)
json_log.info(
self._make_json_log(action="disconnect", info="judge disconnected")
)
if self._working:
Submission.objects.filter(id=self._working).update(
status="IE", result="IE", error=""
)
json_log.error(
self._make_json_log(
sub=self._working,
action="close",
info="IE due to shutdown on grading",
)
)
def _authenticate(self, id, key):
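        # Look up the judge by name (blocked judges are excluded) and compare the
        # supplied key against the stored auth key in constant time.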
try:
judge = Judge.objects.get(name=id, is_blocked=False)
except Judge.DoesNotExist:
result = False
else:
result = hmac.compare_digest(judge.auth_key, key)
if not result:
json_log.warning(
self._make_json_log(
action="auth", judge=id, info="judge failed authentication"
)
)
return result
def _connected(self):
judge = self.judge = Judge.objects.get(name=self.name)
judge.start_time = timezone.now()
judge.online = True
judge.problems.set(Problem.objects.filter(code__in=list(self.problems.keys())))
judge.runtimes.set(Language.objects.filter(key__in=list(self.executors.keys())))
        # Delete now in case we crashed earlier and left stale versions behind from the last connection
RuntimeVersion.objects.filter(judge=judge).delete()
versions = []
for lang in judge.runtimes.all():
versions += [
RuntimeVersion(
language=lang,
name=name,
version=".".join(map(str, version)),
priority=idx,
judge=judge,
)
for idx, (name, version) in enumerate(self.executors[lang.key])
]
RuntimeVersion.objects.bulk_create(versions)
judge.last_ip = self.client_address[0]
judge.save()
self.judge_address = "[%s]:%s" % (
self.client_address[0],
self.client_address[1],
)
json_log.info(
self._make_json_log(
action="auth",
info="judge successfully authenticated",
executors=list(self.executors.keys()),
)
)
def _disconnected(self):
Judge.objects.filter(id=self.judge.id).update(online=False)
RuntimeVersion.objects.filter(judge=self.judge).delete()
def _update_ping(self):
try:
Judge.objects.filter(name=self.name).update(
ping=self.latency, load=self.load
)
except Exception as e:
            # Avoid a hard dependency on MySQL: detect "MySQL server has gone away"
            # (error 2006) by exception name and module instead of importing it.
if (
e.__class__.__name__ == "OperationalError"
and e.__module__ == "_mysql_exceptions"
and e.args[0] == 2006
):
db.connection.close()
def send(self, data):
super().send(json.dumps(data, separators=(",", ":")))
def on_handshake(self, packet):
if "id" not in packet or "key" not in packet:
logger.warning("Malformed handshake: %s", self.client_address)
self.close()
return
if not self._authenticate(packet["id"], packet["key"]):
logger.warning("Authentication failure: %s", self.client_address)
self.close()
return
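        # Authenticated: relax the socket timeout and record the problems and
        # runtimes this judge reports before announcing it to the judge list.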
self.timeout = 60
self._problems = packet["problems"]
self.problems = dict(self._problems)
self.executors = packet["executors"]
self.name = packet["id"]
self.send({"name": "handshake-success"})
logger.info("Judge authenticated: %s (%s)", self.client_address, packet["id"])
self.judges.register(self)
threading.Thread(target=self._ping_thread).start()
self._connected()
def can_judge(self, problem, executor, judge_id=None):
return (
problem in self.problems
and executor in self.executors
and (not judge_id or self.name == judge_id)
)
@property
def working(self):
return bool(self._working)
def get_related_submission_data(self, submission):
_ensure_connection()
try:
(
pid,
time,
memory,
short_circuit,
lid,
is_pretested,
sub_date,
uid,
part_virtual,
part_id,
) = (
Submission.objects.filter(id=submission).values_list(
"problem__id",
"problem__time_limit",
"problem__memory_limit",
"problem__short_circuit",
"language__id",
"is_pretested",
"date",
"user__id",
"contest__participation__virtual",
"contest__participation__id",
)
).get()
except Submission.DoesNotExist:
logger.error("Submission vanished: %s", submission)
json_log.error(
self._make_json_log(
sub=self._working,
action="request",
info="submission vanished when fetching info",
)
)
return
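        # attempt_no counts the user's earlier non-CE/IE submissions to this problem
        # within the same participation, plus one for the current attempt.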
attempt_no = (
Submission.objects.filter(
problem__id=pid,
contest__participation__id=part_id,
user__id=uid,
date__lt=sub_date,
)
.exclude(status__in=("CE", "IE"))
.count()
+ 1
)
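        # A per-language limit, if one exists, overrides the problem's default
        # time and memory limits.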
try:
time, memory = (
LanguageLimit.objects.filter(problem__id=pid, language__id=lid)
.values_list("time_limit", "memory_limit")
.get()
)
except LanguageLimit.DoesNotExist:
pass
return SubmissionData(
time=time,
memory=memory,
short_circuit=short_circuit,
pretests_only=is_pretested,
contest_no=part_virtual,
attempt_no=attempt_no,
user_id=uid,
)
def disconnect(self, force=False):
if force:
# Yank the power out.
self.close()
else:
self.send({"name": "disconnect"})
def submit(self, id, problem, language, source):
data = self.get_related_submission_data(id)
self._working = id
        self._no_response_job = threading.Timer(20, self._kill_if_no_response)
        self._no_response_job.start()  # watchdog; cancelled in on_submission_acknowledged
self.send(
{
"name": "submission-request",
"submission-id": id,
"problem-id": problem,
"language": language,
"source": source,
"time-limit": data.time,
"memory-limit": data.memory,
"short-circuit": data.short_circuit,
"meta": {
"pretests-only": data.pretests_only,
"in-contest": data.contest_no,
"attempt-no": data.attempt_no,
"user": data.user_id,
},
}
)
def _kill_if_no_response(self):
logger.error(
"Judge failed to acknowledge submission: %s: %s", self.name, self._working
)
self.close()
def on_timeout(self):
if self.name:
logger.warning("Judge seems dead: %s: %s", self.name, self._working)
def malformed_packet(self, exception):
logger.exception("Judge sent malformed packet: %s", self.name)
        super().malformed_packet(exception)
def on_submission_processing(self, packet):
_ensure_connection()
id = packet["submission-id"]
if Submission.objects.filter(id=id).update(status="P", judged_on=self.judge):
event.post("sub_%s" % Submission.get_id_secret(id), {"type": "processing"})
self._post_update_submission(id, "processing")
json_log.info(self._make_json_log(packet, action="processing"))
else:
logger.warning("Unknown submission: %s", id)
json_log.error(
self._make_json_log(
packet, action="processing", info="unknown submission"
)
)
def on_submission_wrong_acknowledge(self, packet, expected, got):
json_log.error(
self._make_json_log(
packet, action="processing", info="wrong-acknowledge", expected=expected
)
)
Submission.objects.filter(id=expected).update(
status="IE", result="IE", error=None
)
Submission.objects.filter(id=got, status="QU").update(
status="IE", result="IE", error=None
)
def on_submission_acknowledged(self, packet):
        if packet.get("submission-id", None) != self._working:
logger.error(
"Wrong acknowledgement: %s: %s, expected: %s",
self.name,
packet.get("submission-id", None),
self._working,
)
self.on_submission_wrong_acknowledge(
packet, self._working, packet.get("submission-id", None)
)
self.close()
logger.info("Submission acknowledged: %d", self._working)
if self._no_response_job:
self._no_response_job.cancel()
self._no_response_job = None
self.on_submission_processing(packet)
def abort(self):
self.send({"name": "terminate-submission"})
def get_current_submission(self):
return self._working or None
def ping(self):
self.send({"name": "ping", "when": time.time()})
def on_packet(self, data):
try:
try:
data = json.loads(data)
if "name" not in data:
raise ValueError
except ValueError:
self.on_malformed(data)
else:
handler = self.handlers.get(data["name"], self.on_malformed)
handler(data)
except Exception:
logger.exception("Error in packet handling (Judge-side): %s", self.name)
self._packet_exception()
            # Don't crash here: judges may be malicious, or may simply send
            # malformed packets. This is a server; it has to keep running.
def _packet_exception(self):
json_log.exception(
self._make_json_log(sub=self._working, info="packet processing exception")
)
def _submission_is_batch(self, id):
if not Submission.objects.filter(id=id).update(batch=True):
logger.warning("Unknown submission: %s", id)
def on_supported_problems(self, packet):
logger.info("%s: Updated problem list", self.name)
self._problems = packet["problems"]
self.problems = dict(self._problems)
if not self.working:
self.judges.update_problems(self)
self.judge.problems.set(
Problem.objects.filter(code__in=list(self.problems.keys()))
)
json_log.info(
self._make_json_log(action="update-problems", count=len(self.problems))
)
def on_grading_begin(self, packet):
logger.info("%s: Grading has begun on: %s", self.name, packet["submission-id"])
self.batch_id = None
if Submission.objects.filter(id=packet["submission-id"]).update(
status="G",
is_pretested=packet["pretested"],
current_testcase=1,
points=0,
batch=False,
judged_date=timezone.now(),
):
SubmissionTestCase.objects.filter(
submission_id=packet["submission-id"]
).delete()
event.post(
"sub_%s" % Submission.get_id_secret(packet["submission-id"]),
{"type": "grading-begin"},
)
self._post_update_submission(packet["submission-id"], "grading-begin")
json_log.info(self._make_json_log(packet, action="grading-begin"))
else:
logger.warning("Unknown submission: %s", packet["submission-id"])
json_log.error(
self._make_json_log(
packet, action="grading-begin", info="unknown submission"
)
)
def on_grading_end(self, packet):
logger.info("%s: Grading has ended on: %s", self.name, packet["submission-id"])
self._free_self(packet)
self.batch_id = None
try:
submission = Submission.objects.get(id=packet["submission-id"])
except Submission.DoesNotExist:
logger.warning("Unknown submission: %s", packet["submission-id"])
json_log.error(
self._make_json_log(
packet, action="grading-end", info="unknown submission"
)
)
return
time = 0
memory = 0
points = 0.0
total = 0
status = 0
status_codes = ["SC", "AC", "WA", "MLE", "TLE", "IR", "RTE", "OLE"]
batches = {} # batch number: (points, total)
for case in SubmissionTestCase.objects.filter(submission=submission):
time += case.time
if not case.batch:
points += case.points
total += case.total
else:
if case.batch in batches:
batches[case.batch][0] += case.points
batches[case.batch][1] += case.total
else:
batches[case.batch] = [case.points, case.total]
memory = max(memory, case.memory)
i = status_codes.index(case.status)
if i > status:
status = i
        for i in batches:
            points += batches[i][0]
            total += batches[i][1]
submission.case_points = points
submission.case_total = total
problem = submission.problem
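        # Scale case points to the problem's point value; problems that do not allow
        # partial scoring only award points for a full score.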
sub_points = round(points / total * problem.points if total > 0 else 0, 3)
if not problem.partial and sub_points != problem.points:
sub_points = 0
submission.status = "D"
submission.time = time
submission.memory = memory
submission.points = sub_points
submission.result = status_codes[status]
submission.save()
json_log.info(
self._make_json_log(
packet,
action="grading-end",
time=time,
memory=memory,
points=sub_points,
total=problem.points,
result=submission.result,
case_points=points,
case_total=total,
user=submission.user_id,
problem=problem.code,
finish=True,
)
)
submission.user._updating_stats_only = True
submission.user.calculate_points()
problem._updating_stats_only = True
problem.update_stats()
submission.update_contest()
finished_submission(submission)
event.post(
"sub_%s" % submission.id_secret,
{
"type": "grading-end",
"time": time,
"memory": memory,
"points": float(points),
"total": float(problem.points),
"result": submission.result,
},
)
if hasattr(submission, "contest"):
participation = submission.contest.participation
event.post("contest_%d" % participation.contest_id, {"type": "update"})
self._post_update_submission(submission.id, "grading-end", done=True)
# Clean up submission source file (if any)
source_file = cache.get(f"submission_source_file:{submission.id}")
if source_file:
filepath = os.path.join(settings.DMOJ_SUBMISSION_ROOT, source_file)
if os.path.exists(filepath):
os.remove(filepath)
def on_compile_error(self, packet):
logger.info(
"%s: Submission failed to compile: %s", self.name, packet["submission-id"]
)
self._free_self(packet)
if Submission.objects.filter(id=packet["submission-id"]).update(
status="CE", result="CE", error=packet["log"]
):
event.post(
"sub_%s" % Submission.get_id_secret(packet["submission-id"]),
{
"type": "compile-error",
"log": packet["log"],
},
)
self._post_update_submission(
packet["submission-id"], "compile-error", done=True
)
json_log.info(
self._make_json_log(
packet,
action="compile-error",
log=packet["log"],
finish=True,
result="CE",
)
)
else:
logger.warning("Unknown submission: %s", packet["submission-id"])
json_log.error(
self._make_json_log(
packet,
action="compile-error",
info="unknown submission",
log=packet["log"],
finish=True,
result="CE",
)
)
def on_compile_message(self, packet):
logger.info(
"%s: Submission generated compiler messages: %s",
self.name,
packet["submission-id"],
)
if Submission.objects.filter(id=packet["submission-id"]).update(
error=packet["log"]
):
event.post(
"sub_%s" % Submission.get_id_secret(packet["submission-id"]),
{"type": "compile-message"},
)
json_log.info(
self._make_json_log(packet, action="compile-message", log=packet["log"])
)
else:
logger.warning("Unknown submission: %s", packet["submission-id"])
json_log.error(
self._make_json_log(
packet,
action="compile-message",
info="unknown submission",
log=packet["log"],
)
)
def on_internal_error(self, packet):
try:
raise ValueError("\n\n" + packet["message"])
except ValueError:
logger.exception(
"Judge %s failed while handling submission %s",
self.name,
packet["submission-id"],
)
self._free_self(packet)
id = packet["submission-id"]
if Submission.objects.filter(id=id).update(
status="IE", result="IE", error=packet["message"]
):
event.post(
"sub_%s" % Submission.get_id_secret(id), {"type": "internal-error"}
)
self._post_update_submission(id, "internal-error", done=True)
json_log.info(
self._make_json_log(
packet,
action="internal-error",
message=packet["message"],
finish=True,
result="IE",
)
)
else:
logger.warning("Unknown submission: %s", id)
json_log.error(
self._make_json_log(
packet,
action="internal-error",
info="unknown submission",
message=packet["message"],
finish=True,
result="IE",
)
)
def on_submission_terminated(self, packet):
logger.info("%s: Submission aborted: %s", self.name, packet["submission-id"])
self._free_self(packet)
if Submission.objects.filter(id=packet["submission-id"]).update(
status="AB", result="AB"
):
event.post(
"sub_%s" % Submission.get_id_secret(packet["submission-id"]),
{"type": "aborted-submission"},
)
self._post_update_submission(
packet["submission-id"], "terminated", done=True
)
json_log.info(
self._make_json_log(packet, action="aborted", finish=True, result="AB")
)
else:
logger.warning("Unknown submission: %s", packet["submission-id"])
json_log.error(
self._make_json_log(
packet,
action="aborted",
info="unknown submission",
finish=True,
result="AB",
)
)
def on_batch_begin(self, packet):
logger.info("%s: Batch began on: %s", self.name, packet["submission-id"])
self.in_batch = True
if self.batch_id is None:
self.batch_id = 0
self._submission_is_batch(packet["submission-id"])
self.batch_id += 1
json_log.info(
self._make_json_log(packet, action="batch-begin", batch=self.batch_id)
)
def on_batch_end(self, packet):
self.in_batch = False
logger.info("%s: Batch ended on: %s", self.name, packet["submission-id"])
json_log.info(
self._make_json_log(packet, action="batch-end", batch=self.batch_id)
)
def on_test_case(
self,
packet,
max_feedback=SubmissionTestCase._meta.get_field("feedback").max_length,
):
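        # max_feedback is bound once, at function definition time, to the
        # max_length of the SubmissionTestCase.feedback field.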
logger.info(
"%s: %d test case(s) completed on: %s",
self.name,
len(packet["cases"]),
packet["submission-id"],
)
id = packet["submission-id"]
updates = packet["cases"]
max_position = max(map(itemgetter("position"), updates))
sum_points = sum(map(itemgetter("points"), updates))
if not Submission.objects.filter(id=id).update(
current_testcase=max_position + 1, points=F("points") + sum_points
):
logger.warning("Unknown submission: %s", id)
json_log.error(
self._make_json_log(
packet, action="test-case", info="unknown submission"
)
)
return
bulk_test_case_updates = []
for result in updates:
test_case = SubmissionTestCase(submission_id=id, case=result["position"])
status = result["status"]
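            # The judge reports each case status as a bitmask; record the most
            # severe flag set (TLE, MLE, OLE, RTE, IR, WA, SC), defaulting to AC.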
if status & 4:
test_case.status = "TLE"
elif status & 8:
test_case.status = "MLE"
elif status & 64:
test_case.status = "OLE"
elif status & 2:
test_case.status = "RTE"
elif status & 16:
test_case.status = "IR"
elif status & 1:
test_case.status = "WA"
elif status & 32:
test_case.status = "SC"
else:
test_case.status = "AC"
test_case.time = result["time"]
test_case.memory = result["memory"]
test_case.points = result["points"]
test_case.total = result["total-points"]
test_case.batch = self.batch_id if self.in_batch else None
test_case.feedback = (result.get("feedback") or "")[:max_feedback]
test_case.extended_feedback = result.get("extended-feedback") or ""
test_case.output = result["output"]
bulk_test_case_updates.append(test_case)
json_log.info(
self._make_json_log(
packet,
action="test-case",
case=test_case.case,
batch=test_case.batch,
time=test_case.time,
memory=test_case.memory,
feedback=test_case.feedback,
extended_feedback=test_case.extended_feedback,
output=test_case.output,
points=test_case.points,
total=test_case.total,
status=test_case.status,
)
)
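        # Throttle live updates: post at most UPDATE_RATE_LIMIT test-case events per
        # UPDATE_RATE_TIME seconds for this submission, then stay silent until the
        # window resets.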
do_post = True
if id in self.update_counter:
cnt, reset = self.update_counter[id]
cnt += 1
if time.monotonic() - reset > UPDATE_RATE_TIME:
del self.update_counter[id]
else:
self.update_counter[id] = (cnt, reset)
if cnt > UPDATE_RATE_LIMIT:
do_post = False
if id not in self.update_counter:
self.update_counter[id] = (1, time.monotonic())
if do_post:
event.post(
"sub_%s" % Submission.get_id_secret(id),
{
"type": "test-case",
"id": max_position,
},
)
self._post_update_submission(id, state="test-case")
SubmissionTestCase.objects.bulk_create(bulk_test_case_updates)
def on_malformed(self, packet):
logger.error("%s: Malformed packet: %s", self.name, packet)
json_log.exception(
self._make_json_log(sub=self._working, info="malformed json packet")
)
def on_ping_response(self, packet):
end = time.time()
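        # latency is a rolling average of round-trip times; time_delta estimates the
        # clock offset between bridge and judge using the midpoint of the round trip.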
self._ping_average.append(end - packet["when"])
self._time_delta.append((end + packet["when"]) / 2 - packet["time"])
self.latency = sum(self._ping_average) / len(self._ping_average)
self.time_delta = sum(self._time_delta) / len(self._time_delta)
self.load = packet["load"]
self._update_ping()
def _free_self(self, packet):
self.judges.on_judge_free(self, packet["submission-id"])
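    # Ping the judge every 10 seconds until _stop_ping is set; any failure tears
    # down the connection.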
def _ping_thread(self):
try:
while True:
self.ping()
if self._stop_ping.wait(10):
break
except Exception:
logger.exception("Ping error in %s", self.name)
self.close()
raise
def _make_json_log(self, packet=None, sub=None, **kwargs):
data = {
"judge": self.name,
"address": self.judge_address,
}
if sub is None and packet is not None:
sub = packet.get("submission-id")
if sub is not None:
data["submission"] = sub
data.update(kwargs)
return json.dumps(data)
def _post_update_submission(self, id, state, done=False):
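        # Cache the last submission's metadata so repeated status updates for the
        # same submission avoid an extra database query.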
if self._submission_cache_id == id:
data = self._submission_cache
else:
self._submission_cache = data = (
Submission.objects.filter(id=id)
.values(
"problem__is_public",
"contest_object__key",
"user_id",
"problem_id",
"status",
"language__key",
)
.get()
)
self._submission_cache_id = id
if data["problem__is_public"]:
event.post(
"submissions",
{
"type": "done-submission" if done else "update-submission",
"state": state,
"id": id,
"contest": data["contest_object__key"],
"user": data["user_id"],
"problem": data["problem_id"],
"status": data["status"],
"language": data["language__key"],
},
)
def on_cleanup(self):
db.connection.close()