Reformat using black

This commit is contained in:
cuom1999 2022-05-14 12:57:27 -05:00
parent efee4ad081
commit a87fb49918
221 changed files with 19127 additions and 7310 deletions

View file

@ -8,9 +8,9 @@ from netaddr import IPGlob, IPSet
from judge.utils.unicode import utf8text
logger = logging.getLogger('judge.bridge')
logger = logging.getLogger("judge.bridge")
size_pack = struct.Struct('!I')
size_pack = struct.Struct("!I")
assert size_pack.size == 4
MAX_ALLOWED_PACKET_SIZE = 8 * 1024 * 1024
@ -20,7 +20,7 @@ def proxy_list(human_readable):
globs = []
addrs = []
for item in human_readable:
if '*' in item or '-' in item:
if "*" in item or "-" in item:
globs.append(IPGlob(item))
else:
addrs.append(item)
@ -43,7 +43,7 @@ class RequestHandlerMeta(type):
try:
handler.handle()
except BaseException:
logger.exception('Error in base packet handling')
logger.exception("Error in base packet handling")
raise
finally:
handler.on_disconnect()
@ -70,8 +70,12 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
def read_sized_packet(self, size, initial=None):
if size > MAX_ALLOWED_PACKET_SIZE:
logger.log(logging.WARNING if self._got_packet else logging.INFO,
'Disconnecting client due to too-large message size (%d bytes): %s', size, self.client_address)
logger.log(
logging.WARNING if self._got_packet else logging.INFO,
"Disconnecting client due to too-large message size (%d bytes): %s",
size,
self.client_address,
)
raise Disconnect()
buffer = []
@ -86,7 +90,7 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
data = self.request.recv(remainder)
remainder -= len(data)
buffer.append(data)
self._on_packet(b''.join(buffer))
self._on_packet(b"".join(buffer))
def parse_proxy_protocol(self, line):
words = line.split()
@ -94,18 +98,18 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
if len(words) < 2:
raise Disconnect()
if words[1] == b'TCP4':
if words[1] == b"TCP4":
if len(words) != 6:
raise Disconnect()
self.client_address = (utf8text(words[2]), utf8text(words[4]))
self.server_address = (utf8text(words[3]), utf8text(words[5]))
elif words[1] == b'TCP6':
elif words[1] == b"TCP6":
self.client_address = (utf8text(words[2]), utf8text(words[4]), 0, 0)
self.server_address = (utf8text(words[3]), utf8text(words[5]), 0, 0)
elif words[1] != b'UNKNOWN':
elif words[1] != b"UNKNOWN":
raise Disconnect()
def read_size(self, buffer=b''):
def read_size(self, buffer=b""):
while len(buffer) < size_pack.size:
recv = self.request.recv(size_pack.size - len(buffer))
if not recv:
@ -113,9 +117,9 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
buffer += recv
return size_pack.unpack(buffer)[0]
def read_proxy_header(self, buffer=b''):
def read_proxy_header(self, buffer=b""):
# Max line length for PROXY protocol is 107, and we received 4 already.
while b'\r\n' not in buffer:
while b"\r\n" not in buffer:
if len(buffer) > 107:
raise Disconnect()
data = self.request.recv(107)
@ -125,7 +129,7 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
return buffer
def _on_packet(self, data):
decompressed = zlib.decompress(data).decode('utf-8')
decompressed = zlib.decompress(data).decode("utf-8")
self._got_packet = True
self.on_packet(decompressed)
@ -145,8 +149,10 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
try:
tag = self.read_size()
self._initial_tag = size_pack.pack(tag)
if self.client_address[0] in self.proxies and self._initial_tag == b'PROX':
proxy, _, remainder = self.read_proxy_header(self._initial_tag).partition(b'\r\n')
if self.client_address[0] in self.proxies and self._initial_tag == b"PROX":
proxy, _, remainder = self.read_proxy_header(
self._initial_tag
).partition(b"\r\n")
self.parse_proxy_protocol(proxy)
while remainder:
@ -154,8 +160,8 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
self.read_sized_packet(self.read_size(remainder))
break
size = size_pack.unpack(remainder[:size_pack.size])[0]
remainder = remainder[size_pack.size:]
size = size_pack.unpack(remainder[: size_pack.size])[0]
remainder = remainder[size_pack.size :]
if len(remainder) <= size:
self.read_sized_packet(size, remainder)
break
@ -171,25 +177,36 @@ class ZlibPacketHandler(metaclass=RequestHandlerMeta):
return
except zlib.error:
if self._got_packet:
logger.warning('Encountered zlib error during packet handling, disconnecting client: %s',
self.client_address, exc_info=True)
logger.warning(
"Encountered zlib error during packet handling, disconnecting client: %s",
self.client_address,
exc_info=True,
)
else:
logger.info('Potentially wrong protocol (zlib error): %s: %r', self.client_address, self._initial_tag,
exc_info=True)
logger.info(
"Potentially wrong protocol (zlib error): %s: %r",
self.client_address,
self._initial_tag,
exc_info=True,
)
except socket.timeout:
if self._got_packet:
logger.info('Socket timed out: %s', self.client_address)
logger.info("Socket timed out: %s", self.client_address)
self.on_timeout()
else:
logger.info('Potentially wrong protocol: %s: %r', self.client_address, self._initial_tag)
logger.info(
"Potentially wrong protocol: %s: %r",
self.client_address,
self._initial_tag,
)
except socket.error as e:
# When a gevent socket is shutdown, gevent cancels all waits, causing recv to raise cancel_wait_ex.
if e.__class__.__name__ == 'cancel_wait_ex':
if e.__class__.__name__ == "cancel_wait_ex":
return
raise
def send(self, data):
compressed = zlib.compress(data.encode('utf-8'))
compressed = zlib.compress(data.encode("utf-8"))
self.request.sendall(size_pack.pack(len(compressed)) + compressed)
def close(self):

View file

@ -11,7 +11,7 @@ from judge.bridge.judge_list import JudgeList
from judge.bridge.server import Server
from judge.models import Judge, Submission
logger = logging.getLogger('judge.bridge')
logger = logging.getLogger("judge.bridge")
def reset_judges():
@ -20,12 +20,17 @@ def reset_judges():
def judge_daemon():
reset_judges()
Submission.objects.filter(status__in=Submission.IN_PROGRESS_GRADING_STATUS) \
.update(status='IE', result='IE', error=None)
Submission.objects.filter(status__in=Submission.IN_PROGRESS_GRADING_STATUS).update(
status="IE", result="IE", error=None
)
judges = JudgeList()
judge_server = Server(settings.BRIDGED_JUDGE_ADDRESS, partial(JudgeHandler, judges=judges))
django_server = Server(settings.BRIDGED_DJANGO_ADDRESS, partial(DjangoHandler, judges=judges))
judge_server = Server(
settings.BRIDGED_JUDGE_ADDRESS, partial(JudgeHandler, judges=judges)
)
django_server = Server(
settings.BRIDGED_DJANGO_ADDRESS, partial(DjangoHandler, judges=judges)
)
threading.Thread(target=django_server.serve_forever).start()
threading.Thread(target=judge_server.serve_forever).start()
@ -33,7 +38,7 @@ def judge_daemon():
stop = threading.Event()
def signal_handler(signum, _):
logger.info('Exiting due to %s', signal.Signals(signum).name)
logger.info("Exiting due to %s", signal.Signals(signum).name)
stop.set()
signal.signal(signal.SIGINT, signal_handler)

View file

@ -4,8 +4,8 @@ import struct
from judge.bridge.base_handler import Disconnect, ZlibPacketHandler
logger = logging.getLogger('judge.bridge')
size_pack = struct.Struct('!I')
logger = logging.getLogger("judge.bridge")
size_pack = struct.Struct("!I")
class DjangoHandler(ZlibPacketHandler):
@ -13,47 +13,52 @@ class DjangoHandler(ZlibPacketHandler):
super().__init__(request, client_address, server)
self.handlers = {
'submission-request': self.on_submission,
'terminate-submission': self.on_termination,
'disconnect-judge': self.on_disconnect_request,
"submission-request": self.on_submission,
"terminate-submission": self.on_termination,
"disconnect-judge": self.on_disconnect_request,
}
self.judges = judges
def send(self, data):
super().send(json.dumps(data, separators=(',', ':')))
super().send(json.dumps(data, separators=(",", ":")))
def on_packet(self, packet):
packet = json.loads(packet)
try:
result = self.handlers.get(packet.get('name', None), self.on_malformed)(packet)
result = self.handlers.get(packet.get("name", None), self.on_malformed)(
packet
)
except Exception:
logger.exception('Error in packet handling (Django-facing)')
result = {'name': 'bad-request'}
logger.exception("Error in packet handling (Django-facing)")
result = {"name": "bad-request"}
self.send(result)
raise Disconnect()
def on_submission(self, data):
id = data['submission-id']
problem = data['problem-id']
language = data['language']
source = data['source']
judge_id = data['judge-id']
priority = data['priority']
id = data["submission-id"]
problem = data["problem-id"]
language = data["language"]
source = data["source"]
judge_id = data["judge-id"]
priority = data["priority"]
if not self.judges.check_priority(priority):
return {'name': 'bad-request'}
return {"name": "bad-request"}
self.judges.judge(id, problem, language, source, judge_id, priority)
return {'name': 'submission-received', 'submission-id': id}
return {"name": "submission-received", "submission-id": id}
def on_termination(self, data):
return {'name': 'submission-received', 'judge-aborted': self.judges.abort(data['submission-id'])}
return {
"name": "submission-received",
"judge-aborted": self.judges.abort(data["submission-id"]),
}
def on_disconnect_request(self, data):
judge_id = data['judge-id']
force = data['force']
judge_id = data["judge-id"]
force = data["force"]
self.judges.disconnect(judge_id, force=force)
def on_malformed(self, packet):
logger.error('Malformed packet: %s', packet)
logger.error("Malformed packet: %s", packet)
def on_close(self):
self._to_kill = False

View file

@ -4,7 +4,7 @@ import struct
import time
import zlib
size_pack = struct.Struct('!I')
size_pack = struct.Struct("!I")
def open_connection():
@ -13,69 +13,70 @@ def open_connection():
def zlibify(data):
data = zlib.compress(data.encode('utf-8'))
data = zlib.compress(data.encode("utf-8"))
return size_pack.pack(len(data)) + data
def dezlibify(data, skip_head=True):
if skip_head:
data = data[size_pack.size:]
return zlib.decompress(data).decode('utf-8')
data = data[size_pack.size :]
return zlib.decompress(data).decode("utf-8")
def main():
global host, port
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-l', '--host', default='localhost')
parser.add_argument('-p', '--port', default=9999, type=int)
parser.add_argument("-l", "--host", default="localhost")
parser.add_argument("-p", "--port", default=9999, type=int)
args = parser.parse_args()
host, port = args.host, args.port
print('Opening idle connection:', end=' ')
print("Opening idle connection:", end=" ")
s1 = open_connection()
print('Success')
print('Opening hello world connection:', end=' ')
print("Success")
print("Opening hello world connection:", end=" ")
s2 = open_connection()
print('Success')
print('Sending Hello, World!', end=' ')
s2.sendall(zlibify('Hello, World!'))
print('Success')
print('Testing blank connection:', end=' ')
print("Success")
print("Sending Hello, World!", end=" ")
s2.sendall(zlibify("Hello, World!"))
print("Success")
print("Testing blank connection:", end=" ")
s3 = open_connection()
s3.close()
print('Success')
print("Success")
result = dezlibify(s2.recv(1024))
assert result == 'Hello, World!'
assert result == "Hello, World!"
print(result)
s2.close()
print('Large random data test:', end=' ')
print("Large random data test:", end=" ")
s4 = open_connection()
data = os.urandom(1000000).decode('iso-8859-1')
print('Generated', end=' ')
data = os.urandom(1000000).decode("iso-8859-1")
print("Generated", end=" ")
s4.sendall(zlibify(data))
print('Sent', end=' ')
result = b''
print("Sent", end=" ")
result = b""
while len(result) < size_pack.size:
result += s4.recv(1024)
size = size_pack.unpack(result[:size_pack.size])[0]
result = result[size_pack.size:]
size = size_pack.unpack(result[: size_pack.size])[0]
result = result[size_pack.size :]
while len(result) < size:
result += s4.recv(1024)
print('Received', end=' ')
print("Received", end=" ")
assert dezlibify(result, False) == data
print('Success')
print("Success")
s4.close()
print('Test malformed connection:', end=' ')
print("Test malformed connection:", end=" ")
s5 = open_connection()
s5.sendall(data[:100000].encode('utf-8'))
s5.sendall(data[:100000].encode("utf-8"))
s5.close()
print('Success')
print('Waiting for timeout to close idle connection:', end=' ')
print("Success")
print("Waiting for timeout to close idle connection:", end=" ")
time.sleep(6)
print('Done')
print("Done")
s1.close()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -3,19 +3,22 @@ from judge.bridge.base_handler import ZlibPacketHandler
class EchoPacketHandler(ZlibPacketHandler):
def on_connect(self):
print('New client:', self.client_address)
print("New client:", self.client_address)
self.timeout = 5
def on_timeout(self):
print('Inactive client:', self.client_address)
print("Inactive client:", self.client_address)
def on_packet(self, data):
self.timeout = None
print('Data from %s: %r' % (self.client_address, data[:30] if len(data) > 30 else data))
print(
"Data from %s: %r"
% (self.client_address, data[:30] if len(data) > 30 else data)
)
self.send(data)
def on_disconnect(self):
print('Closed client:', self.client_address)
print("Closed client:", self.client_address)
def main():
@ -23,9 +26,9 @@ def main():
from judge.bridge.server import Server
parser = argparse.ArgumentParser()
parser.add_argument('-l', '--host', action='append')
parser.add_argument('-p', '--port', type=int, action='append')
parser.add_argument('-P', '--proxy', action='append')
parser.add_argument("-l", "--host", action="append")
parser.add_argument("-p", "--port", type=int, action="append")
parser.add_argument("-P", "--proxy", action="append")
args = parser.parse_args()
class Handler(EchoPacketHandler):
@ -35,5 +38,5 @@ def main():
server.serve_forever()
if __name__ == '__main__':
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load diff

View file

@ -8,9 +8,9 @@ try:
except ImportError:
from pyllist import dllist
logger = logging.getLogger('judge.bridge')
logger = logging.getLogger("judge.bridge")
PriorityMarker = namedtuple('PriorityMarker', 'priority')
PriorityMarker = namedtuple("PriorityMarker", "priority")
class JudgeList(object):
@ -18,7 +18,9 @@ class JudgeList(object):
def __init__(self):
self.queue = dllist()
self.priority = [self.queue.append(PriorityMarker(i)) for i in range(self.priorities)]
self.priority = [
self.queue.append(PriorityMarker(i)) for i in range(self.priorities)
]
self.judges = set()
self.node_map = {}
self.submission_map = {}
@ -32,11 +34,19 @@ class JudgeList(object):
id, problem, language, source, judge_id = node.value
if judge.can_judge(problem, language, judge_id):
self.submission_map[id] = judge
logger.info('Dispatched queued submission %d: %s', id, judge.name)
logger.info(
"Dispatched queued submission %d: %s", id, judge.name
)
try:
judge.submit(id, problem, language, source)
except Exception:
logger.exception('Failed to dispatch %d (%s, %s) to %s', id, problem, language, judge.name)
logger.exception(
"Failed to dispatch %d (%s, %s) to %s",
id,
problem,
language,
judge.name,
)
self.judges.remove(judge)
return
self.queue.remove(node)
@ -76,14 +86,14 @@ class JudgeList(object):
def on_judge_free(self, judge, submission):
with self.lock:
logger.info('Judge available after grading %d: %s', submission, judge.name)
logger.info("Judge available after grading %d: %s", submission, judge.name)
del self.submission_map[submission]
judge._working = False
self._handle_free_judge(judge)
def abort(self, submission):
with self.lock:
logger.info('Abort request: %d', submission)
logger.info("Abort request: %d", submission)
try:
self.submission_map[submission].abort()
return True
@ -108,21 +118,33 @@ class JudgeList(object):
return
candidates = [
judge for judge in self.judges if not judge.working and judge.can_judge(problem, language, judge_id)
judge
for judge in self.judges
if not judge.working and judge.can_judge(problem, language, judge_id)
]
if judge_id:
logger.info('Specified judge %s is%savailable', judge_id, ' ' if candidates else ' not ')
logger.info(
"Specified judge %s is%savailable",
judge_id,
" " if candidates else " not ",
)
else:
logger.info('Free judges: %d', len(candidates))
logger.info("Free judges: %d", len(candidates))
if candidates:
# Schedule the submission on the judge reporting least load.
judge = min(candidates, key=attrgetter('load'))
logger.info('Dispatched submission %d to: %s', id, judge.name)
judge = min(candidates, key=attrgetter("load"))
logger.info("Dispatched submission %d to: %s", id, judge.name)
self.submission_map[id] = judge
try:
judge.submit(id, problem, language, source)
except Exception:
logger.exception('Failed to dispatch %d (%s, %s) to %s', id, problem, language, judge.name)
logger.exception(
"Failed to dispatch %d (%s, %s) to %s",
id,
problem,
language,
judge.name,
)
self.judges.discard(judge)
return self.judge(id, problem, language, source, judge_id, priority)
else:
@ -130,4 +152,4 @@ class JudgeList(object):
(id, problem, language, source, judge_id),
self.priority[priority],
)
logger.info('Queued submission: %d', id)
logger.info("Queued submission: %d", id)

View file

@ -12,7 +12,9 @@ class Server:
self._shutdown = threading.Event()
def serve_forever(self):
threads = [threading.Thread(target=server.serve_forever) for server in self.servers]
threads = [
threading.Thread(target=server.serve_forever) for server in self.servers
]
for thread in threads:
thread.daemon = True
thread.start()