"""Length-prefixed, zlib-compressed packet handler for the judge bridge.

Every packet on the wire is a 4-byte big-endian unsigned length followed by
that many bytes of zlib-compressed UTF-8 text.  Connections arriving from an
address listed in ``ZlibPacketHandler.proxies`` may be preceded by a
text PROXY protocol header line.
"""

import logging
import socket
import struct
import zlib
from itertools import chain

from netaddr import IPGlob, IPSet

from judge.utils.unicode import utf8text

logger = logging.getLogger('judge.bridge')

# Packet length prefix: 4-byte unsigned integer in network byte order.
size_pack = struct.Struct('!I')
assert size_pack.size == 4

# Upper bound on the *compressed* size a peer may announce for one packet.
MAX_ALLOWED_PACKET_SIZE = 8 * 1024 * 1024


def proxy_list(human_readable):
    """Build an ``IPSet`` from an iterable of address strings.

    Items containing ``*`` or ``-`` are passed to ``IPGlob``; every other
    item is handed to ``IPSet`` unchanged.
    """
    globs = []
    addrs = []
    for item in human_readable:
        if '*' in item or '-' in item:
            globs.append(IPGlob(item))
        else:
            addrs.append(item)
    return IPSet(chain(chain.from_iterable(globs), addrs))


class Disconnect(Exception):
    """Raised internally to stop handling a connection; caught in ``handle()``."""
    pass


# socketserver.BaseRequestHandler does all the handling in __init__,
# making it impossible to inherit __init__ sanely. While it lets you
# use setup(), most tools will complain about uninitialized variables.
# This metaclass will allow sane __init__ behaviour while also magically
# calling the methods that handles the request.
class RequestHandlerMeta(type):
    def __call__(cls, *args, **kwargs):
        """Construct the handler, then run its whole connection lifecycle.

        Calls ``on_connect()``, then ``handle()``, and always
        ``on_disconnect()`` afterwards.  Any exception escaping ``handle()``
        is logged and re-raised.  The constructed handler is not returned;
        the class call evaluates to ``None``.
        """
        handler = super().__call__(*args, **kwargs)
        handler.on_connect()
        try:
            handler.handle()
        except BaseException:
            logger.exception('Error in base packet handling')
            raise
        finally:
            handler.on_disconnect()


class ZlibPacketHandler(metaclass=RequestHandlerMeta):
    """Base handler that reads and writes size-prefixed zlib packets.

    Subclasses must implement ``on_packet`` and may override ``on_connect``,
    ``on_disconnect`` and ``on_timeout``.
    """

    # Addresses allowed to send a PROXY protocol header; checked with ``in``
    # against the peer address in handle().
    proxies = []

    def __init__(self, request, client_address, server):
        self.request = request
        self.server = server
        self.client_address = client_address
        self.server_address = server.server_address
        # Raw first four bytes received; only used in diagnostics.
        self._initial_tag = None
        # True once at least one packet has decompressed and decoded; used to
        # pick log levels/messages for failures.
        self._got_packet = False

    @property
    def timeout(self):
        """Timeout currently set on the underlying socket."""
        return self.request.gettimeout()

    @timeout.setter
    def timeout(self, timeout):
        # Any falsy value (0, None, ...) is passed on as None.
        self.request.settimeout(timeout or None)

    def read_sized_packet(self, size, initial=None):
        """Read one packet body of ``size`` bytes and dispatch it.

        ``initial`` holds bytes of the body that were already received; only
        the remaining bytes are read from the socket.

        Raises:
            Disconnect: if ``size`` exceeds ``MAX_ALLOWED_PACKET_SIZE`` or the
                peer closes the connection before the body is complete.
        """
        if size > MAX_ALLOWED_PACKET_SIZE:
            logger.log(logging.WARNING if self._got_packet else logging.INFO,
                       'Disconnecting client due to too-large message size (%d bytes): %s', size, self.client_address)
            raise Disconnect()

        buffer = []
        remainder = size

        if initial:
            buffer.append(initial)
            remainder -= len(initial)
            assert remainder >= 0

        while remainder:
            data = self.request.recv(remainder)
            if not data:
                # recv() returning b'' means the peer closed the connection;
                # without this check the loop would never terminate.
                raise Disconnect()
            remainder -= len(data)
            buffer.append(data)
        self._on_packet(b''.join(buffer))

    def parse_proxy_protocol(self, line):
        """Parse a PROXY protocol header line and update the stored addresses.

        ``line`` is the header as bytes without the trailing CRLF.  For
        ``TCP4``/``TCP6`` the client and server addresses are replaced with
        the ones from the header (ports are stored as returned by
        ``utf8text``, i.e. not converted to ``int``).  ``UNKNOWN`` leaves the
        addresses untouched.

        Raises:
            Disconnect: if the header is malformed or names another protocol.
        """
        words = line.split()

        if len(words) < 2:
            raise Disconnect()

        if words[1] == b'TCP4':
            if len(words) != 6:
                raise Disconnect()
            self.client_address = (utf8text(words[2]), utf8text(words[4]))
            self.server_address = (utf8text(words[3]), utf8text(words[5]))
        elif words[1] == b'TCP6':
            # Same field count as TCP4; reject short headers instead of
            # failing with IndexError below.
            if len(words) != 6:
                raise Disconnect()
            self.client_address = (utf8text(words[2]), utf8text(words[4]), 0, 0)
            self.server_address = (utf8text(words[3]), utf8text(words[5]), 0, 0)
        elif words[1] != b'UNKNOWN':
            raise Disconnect()

    def read_size(self, buffer=b''):
        """Read a 4-byte length prefix and return it as an integer.

        ``buffer`` may carry prefix bytes that were already received.

        Raises:
            Disconnect: if the peer closes the connection first.
        """
        while len(buffer) < size_pack.size:
            recv = self.request.recv(size_pack.size - len(buffer))
            if not recv:
                raise Disconnect()
            buffer += recv
        return size_pack.unpack(buffer)[0]

    def read_proxy_header(self, buffer=b''):
        """Receive until ``buffer`` contains a CRLF and return everything read.

        The result may include bytes after the CRLF; the caller splits them
        off.

        Raises:
            Disconnect: if no CRLF shows up within the length limit or the
                peer closes the connection.
        """
        # Max line length for PROXY protocol is 107, and we received 4 already.
        while b'\r\n' not in buffer:
            if len(buffer) > 107:
                raise Disconnect()
            data = self.request.recv(107)
            if not data:
                raise Disconnect()
            buffer += data
        return buffer

    def _on_packet(self, data):
        """Decompress and decode one raw packet body, then call ``on_packet``."""
        decompressed = zlib.decompress(data).decode('utf-8')
        self._got_packet = True
        self.on_packet(decompressed)

    def on_packet(self, data):
        """Handle one decoded packet (``str``); must be overridden."""
        raise NotImplementedError()

    def on_connect(self):
        """Hook invoked before ``handle()``; no-op by default."""
        pass

    def on_disconnect(self):
        """Hook invoked after ``handle()`` finishes or raises; no-op by default."""
        pass

    def on_timeout(self):
        """Hook invoked on a socket timeout after a packet was received."""
        pass

    def handle(self):
        """Read packets from the connection until it ends.

        The first four bytes are either the first packet's length or, when
        the peer is in ``proxies`` and the bytes are ``PROX``, the start of a
        PROXY protocol header.  ``Disconnect``, zlib errors and socket
        timeouts end the method quietly (with logging); other socket errors
        are re-raised unless they are gevent's ``cancel_wait_ex``.
        """
        try:
            tag = self.read_size()
            self._initial_tag = size_pack.pack(tag)
            if self.client_address[0] in self.proxies and self._initial_tag == b'PROX':
                proxy, _, remainder = self.read_proxy_header(self._initial_tag).partition(b'\r\n')
                self.parse_proxy_protocol(proxy)

                # Bytes received past the header line already belong to the
                # packet stream; drain them before going back to the socket.
                while remainder:
                    if len(remainder) < size_pack.size:
                        # Only part of a length prefix is buffered: finish
                        # reading it, read that packet, and leave this loop
                        # (nothing buffered is left to unpack).
                        self.read_sized_packet(self.read_size(remainder))
                        break

                    size = size_pack.unpack(remainder[:size_pack.size])[0]
                    remainder = remainder[size_pack.size:]
                    if len(remainder) <= size:
                        # The buffer holds at most this one packet; the rest
                        # (if any) comes from the socket.
                        self.read_sized_packet(size, remainder)
                        break

                    self._on_packet(remainder[:size])
                    remainder = remainder[size:]
            else:
                self.read_sized_packet(tag)

            while True:
                self.read_sized_packet(self.read_size())
        except Disconnect:
            return
        except zlib.error:
            if self._got_packet:
                logger.warning('Encountered zlib error during packet handling, disconnecting client: %s',
                               self.client_address, exc_info=True)
            else:
                logger.info('Potentially wrong protocol (zlib error): %s: %r', self.client_address, self._initial_tag,
                            exc_info=True)
        except socket.timeout:
            if self._got_packet:
                logger.info('Socket timed out: %s', self.client_address)
                self.on_timeout()
            else:
                logger.info('Potentially wrong protocol: %s: %r', self.client_address, self._initial_tag)
        except socket.error as e:
            # When a gevent socket is shutdown, gevent cancels all waits, causing recv to raise cancel_wait_ex.
            if e.__class__.__name__ == 'cancel_wait_ex':
                return
            raise

    def send(self, data):
        """Compress ``data`` (``str``) and send it with its length prefix."""
        compressed = zlib.compress(data.encode('utf-8'))
        self.request.sendall(size_pack.pack(len(compressed)) + compressed)

    def close(self):
        """Shut down both directions of the underlying socket."""
        self.request.shutdown(socket.SHUT_RDWR)