NDOJ/judge/bridge/base_handler.py

197 lines
6.3 KiB
Python
Raw Permalink Normal View History

2020-07-19 21:27:14 +00:00
import logging
import socket
import struct
import zlib
from itertools import chain
from netaddr import IPGlob, IPSet
from judge.utils.unicode import utf8text
# Logger shared by the judge bridge handlers.
logger = logging.getLogger('judge.bridge')

# Length prefix written in front of every packet: a 4-byte unsigned integer
# in network (big-endian) byte order.
size_pack = struct.Struct('!I')
assert size_pack.size == 4

# Largest compressed packet body a peer may announce: 8 MiB.
MAX_ALLOWED_PACKET_SIZE = 8 << 20
def proxy_list(human_readable):
    """Build an ``IPSet`` of trusted proxy addresses from configuration strings.

    Entries containing ``*`` or ``-`` are parsed as netaddr ``IPGlob``
    patterns (e.g. ``192.168.*.*``). Every other entry is passed to
    ``IPSet`` unchanged.

    :param human_readable: iterable of address / glob strings.
    :returns: an ``IPSet`` covering every listed address and glob.
    """
    globs = []
    addrs = []
    for item in human_readable:
        if '*' in item or '-' in item:
            globs.append(IPGlob(item))
        else:
            addrs.append(item)

    # Give IPSet each glob's covering CIDR blocks rather than iterating the
    # glob address by address: a pattern such as '10.*.*.*' spans ~16.7 million
    # addresses, and expanding it individually is extremely slow. The resulting
    # set is identical.
    glob_cidrs = chain.from_iterable(glob.cidrs() for glob in globs)
    return IPSet(chain(glob_cidrs, addrs))
class Disconnect(Exception):
    """Raised to stop packet handling and drop the connection.

    ``ZlibPacketHandler.handle`` catches this exception and simply returns,
    so raising it ends the session without an error being logged.
    """
# socketserver.BaseRequestHandler does all the handling in __init__,
# making it impossible to inherit __init__ sanely. While it lets you
# use setup(), most tools will complain about uninitialized variables.
# This metaclass will allow sane __init__ behaviour while also magically
# calling the methods that handle the request.
class RequestHandlerMeta(type):
    """Metaclass that runs a handler's whole lifecycle at instantiation time.

    ``Handler(*args, **kwargs)`` builds the instance through the normal
    ``__init__``, then calls ``on_connect()``, ``handle()`` and finally
    ``on_disconnect()``. ``on_disconnect()`` runs even when ``handle()``
    raises; the exception is logged and re-raised. The call itself evaluates
    to ``None``, not to the instance.
    """

    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        # Deliberately outside the try: a failing on_connect skips both the
        # logging and on_disconnect.
        instance.on_connect()
        try:
            instance.handle()
        except BaseException:
            # Record the traceback here, then let the exception keep propagating.
            logger.exception('Error in base packet handling')
            raise
        finally:
            instance.on_disconnect()
class ZlibPacketHandler(metaclass=RequestHandlerMeta):
    """Base handler for the bridge's length-prefixed, zlib-compressed protocol.

    Every packet on the wire is a 4-byte big-endian length (``size_pack``)
    followed by that many bytes of zlib-compressed UTF-8 text (see ``send``).
    Subclasses implement :meth:`on_packet` and may override the
    ``on_connect``, ``on_disconnect`` and ``on_timeout`` hooks.

    When the peer's address is in ``proxies`` and the stream starts with the
    bytes ``PROX``, a text PROXY-protocol header is parsed first. The header
    replaces ``client_address`` / ``server_address`` with the addresses it
    carries.
    """

    # Trusted proxy addresses, tested with ``client_address[0] in proxies``.
    # Presumably replaced by the server with the result of proxy_list(); verify.
    proxies = []

    def __init__(self, request, client_address, server):
        self.request = request
        self.server = server
        self.client_address = client_address
        self.server_address = server.server_address
        # First 4 bytes received; logged when the peer seems to speak another protocol.
        self._initial_tag = None
        # Becomes True once any packet has been decompressed and decoded;
        # used to choose log levels for errors.
        self._got_packet = False

    @property
    def timeout(self):
        """Current socket timeout in seconds, or ``None`` when blocking."""
        return self.request.gettimeout()

    @timeout.setter
    def timeout(self, timeout):
        # Any falsy value (None, 0) means "no timeout". Presumably this avoids
        # settimeout(0), which would make the socket non-blocking.
        self.request.settimeout(timeout or None)

    def read_sized_packet(self, size, initial=None):
        """Read a compressed packet body of ``size`` bytes and dispatch it.

        :param size: length of the body, as announced by its length prefix.
        :param initial: bytes of the body that were already received, if any.
        :raises Disconnect: if ``size`` exceeds ``MAX_ALLOWED_PACKET_SIZE`` or
            the peer closes the connection before the whole body arrives.
        """
        if size > MAX_ALLOWED_PACKET_SIZE:
            logger.log(logging.WARNING if self._got_packet else logging.INFO,
                       'Disconnecting client due to too-large message size (%d bytes): %s', size, self.client_address)
            raise Disconnect()

        buffer = []
        remainder = size

        if initial:
            buffer.append(initial)
            remainder -= len(initial)
            assert remainder >= 0

        while remainder:
            data = self.request.recv(remainder)
            if not data:
                # recv() returns b'' once the peer has closed the connection;
                # without this check a truncated packet would spin forever.
                raise Disconnect()
            remainder -= len(data)
            buffer.append(data)
        self._on_packet(b''.join(buffer))

    def parse_proxy_protocol(self, line):
        """Apply the addresses carried by a PROXY header line (CRLF already removed).

        The line is read as ``PROXY <proto> <client-host> <server-host>
        <client-port> <server-port>``, where ``<proto>`` is ``TCP4`` or ``TCP6``.
        ``PROXY UNKNOWN ...`` leaves the addresses untouched. Hosts and ports
        are stored as strings.

        :raises Disconnect: on an unsupported protocol or a malformed line.
        """
        words = line.split()

        if len(words) < 2:
            raise Disconnect()

        protocol = words[1]
        if protocol in (b'TCP4', b'TCP6'):
            # Both variants carry exactly four fields after the protocol. Checking
            # TCP6 too turns a truncated header into a Disconnect instead of an IndexError.
            if len(words) != 6:
                raise Disconnect()
            client_host, server_host, client_port, server_port = (utf8text(word) for word in words[2:6])
            if protocol == b'TCP4':
                self.client_address = (client_host, client_port)
                self.server_address = (server_host, server_port)
            else:
                # IPv6 socket address shape: (host, port, flowinfo, scope_id).
                self.client_address = (client_host, client_port, 0, 0)
                self.server_address = (server_host, server_port, 0, 0)
        elif protocol != b'UNKNOWN':
            raise Disconnect()

    def read_size(self, buffer=b''):
        """Finish reading a 4-byte length prefix and return it as an int.

        :param buffer: prefix bytes already received; only the missing bytes are read.
        :raises Disconnect: if the peer closes the connection first.
        """
        while len(buffer) < size_pack.size:
            recv = self.request.recv(size_pack.size - len(buffer))
            if not recv:
                raise Disconnect()
            buffer += recv
        return size_pack.unpack(buffer)[0]

    def read_proxy_header(self, buffer=b''):
        """Receive until a CRLF has arrived and return everything read so far.

        The result starts with ``buffer`` and may extend past the CRLF into
        packet data, which the caller must still process.

        :raises Disconnect: if the line grows too long or the peer closes.
        """
        # Max line length for PROXY protocol is 107, and we received 4 already.
        while b'\r\n' not in buffer:
            if len(buffer) > 107:
                raise Disconnect()
            data = self.request.recv(107)
            if not data:
                raise Disconnect()
            buffer += data
        return buffer

    def _on_packet(self, data):
        """Decompress and UTF-8-decode one packet body, then pass it to :meth:`on_packet`."""
        # NOTE(review): zlib.decompress has no output limit, so a body within
        # MAX_ALLOWED_PACKET_SIZE can still inflate to a far larger string.
        decompressed = zlib.decompress(data).decode('utf-8')
        self._got_packet = True
        self.on_packet(decompressed)

    def on_packet(self, data):
        """Handle one decoded packet (``data`` is a str). Subclasses must override."""
        raise NotImplementedError()

    def on_connect(self):
        """Hook run before :meth:`handle`; does nothing by default."""
        pass

    def on_disconnect(self):
        """Hook run after :meth:`handle` finishes or raises; does nothing by default."""
        pass

    def on_timeout(self):
        """Hook run when the socket times out after at least one good packet."""
        pass

    def handle(self):
        """Read and dispatch packets until the connection ends.

        ``Disconnect`` ends the session silently. zlib errors and socket
        timeouts are logged and end it too, as does gevent's
        ``cancel_wait_ex``. Any other socket error propagates.
        """
        try:
            tag = self.read_size()
            self._initial_tag = size_pack.pack(tag)
            if self.client_address[0] in self.proxies and self._initial_tag == b'PROX':
                proxy, _, remainder = self.read_proxy_header(self._initial_tag).partition(b'\r\n')
                self.parse_proxy_protocol(proxy)

                # Reading the header may have pulled in packet data past the CRLF.
                # Dispatch whatever is buffered before reading from the socket again.
                while remainder:
                    if len(remainder) < size_pack.size:
                        # Only part of a length prefix is buffered: complete it and
                        # its body from the socket; nothing is left over afterwards.
                        # (This must exit the outer loop; unpacking a short buffer
                        # below would raise struct.error.)
                        self.read_sized_packet(self.read_size(remainder))
                        break

                    size = size_pack.unpack(remainder[:size_pack.size])[0]
                    remainder = remainder[size_pack.size:]
                    if len(remainder) <= size:
                        # The buffer ends inside (or exactly at the end of) this body.
                        self.read_sized_packet(size, remainder)
                        break

                    self._on_packet(remainder[:size])
                    remainder = remainder[size:]
            else:
                # Not a PROXY header: the first four bytes were a real length prefix.
                self.read_sized_packet(tag)

            while True:
                self.read_sized_packet(self.read_size())
        except Disconnect:
            return
        except zlib.error:
            if self._got_packet:
                logger.warning('Encountered zlib error during packet handling, disconnecting client: %s',
                               self.client_address, exc_info=True)
            else:
                logger.info('Potentially wrong protocol (zlib error): %s: %r', self.client_address, self._initial_tag,
                            exc_info=True)
        except socket.timeout:
            if self._got_packet:
                logger.info('Socket timed out: %s', self.client_address)
                self.on_timeout()
            else:
                logger.info('Potentially wrong protocol: %s: %r', self.client_address, self._initial_tag)
        except socket.error as e:
            # When a gevent socket is shutdown, gevent cancels all waits, causing recv to raise cancel_wait_ex.
            if e.__class__.__name__ == 'cancel_wait_ex':
                return
            raise

    def send(self, data):
        """Compress ``data`` (a str) and send it as one length-prefixed packet."""
        compressed = zlib.compress(data.encode('utf-8'))
        self.request.sendall(size_pack.pack(len(compressed)) + compressed)

    def close(self):
        """Shut down both directions of the underlying socket."""
        self.request.shutdown(socket.SHUT_RDWR)