"""Base request handler for length-prefixed, zlib-compressed packets.

Every packet on the wire is a 4-byte network-order unsigned length (see
``size_pack``) followed by that many bytes of zlib-compressed UTF-8 text.
Connections coming from an address listed in ``ZlibPacketHandler.proxies``
may additionally start with a textual PROXY protocol header line.
"""
import logging
import socket
import struct
import zlib
from itertools import chain

from netaddr import IPGlob, IPSet

from judge.utils.unicode import utf8text

logger = logging.getLogger("judge.bridge")

# Packet length prefix: one unsigned 32-bit integer in network byte order.
size_pack = struct.Struct("!I")
assert size_pack.size == 4

# Packets announcing a compressed size above this are rejected outright.
MAX_ALLOWED_PACKET_SIZE = 8 * 1024 * 1024


def proxy_list(human_readable):
    """Build an ``IPSet`` from a list of address strings.

    Entries containing ``*`` or ``-`` are parsed as ``IPGlob`` patterns and
    expanded by iterating them; every other entry is handed to ``IPSet``
    unchanged.
    """
    globs = []
    addrs = []
    for item in human_readable:
        if "*" in item or "-" in item:
            globs.append(IPGlob(item))
        else:
            addrs.append(item)
    return IPSet(chain(chain.from_iterable(globs), addrs))


class Disconnect(Exception):
    """Raised internally to stop handling the current connection."""


# socketserver.BaseRequestHandler does all the handling in __init__,
# making it impossible to inherit __init__ sanely. While it lets you
# use setup(), most tools will complain about uninitialized variables.
# This metaclass allows sane __init__ behaviour while also magically
# calling the methods that handle the request.
class RequestHandlerMeta(type):
    def __call__(cls, *args, **kwargs):
        """Instantiate the handler and immediately run its whole lifecycle.

        Calls ``on_connect()``, then ``handle()``, and finally
        ``on_disconnect()`` (even when ``handle()`` raises). Any exception
        escaping ``handle()`` is logged and re-raised. Nothing is returned,
        so "constructing" a handler class evaluates to ``None``.
        """
        handler = super().__call__(*args, **kwargs)
        handler.on_connect()
        try:
            handler.handle()
        except BaseException:
            logger.exception("Error in base packet handling")
            raise
        finally:
            handler.on_disconnect()


class ZlibPacketHandler(metaclass=RequestHandlerMeta):
    """Reads and writes size-prefixed zlib packets on a connected socket.

    Subclasses implement ``on_packet`` and may override the ``on_connect``,
    ``on_disconnect`` and ``on_timeout`` hooks.
    """

    # Addresses allowed to send a PROXY protocol header; tested with ``in``.
    proxies = []

    def __init__(self, request, client_address, server):
        self.request = request
        self.server = server
        self.client_address = client_address
        self.server_address = server.server_address
        # Raw first four bytes received; used for the PROXY check and logging.
        self._initial_tag = None
        # Becomes True once at least one packet has been decoded successfully.
        self._got_packet = False

    @property
    def timeout(self):
        """Current timeout of the underlying socket."""
        return self.request.gettimeout()

    @timeout.setter
    def timeout(self, timeout):
        # A falsy value (0, None) is passed on as None rather than as 0.
        self.request.settimeout(timeout or None)

    def read_sized_packet(self, size, initial=None):
        """Receive a packet of ``size`` compressed bytes and dispatch it.

        ``initial`` is an optional prefix of the packet that was already
        received; only the missing bytes are read from the socket.

        Raises ``Disconnect`` if ``size`` exceeds ``MAX_ALLOWED_PACKET_SIZE``
        or the peer closes the connection before the packet is complete.
        """
        if size > MAX_ALLOWED_PACKET_SIZE:
            logger.log(
                logging.WARNING if self._got_packet else logging.INFO,
                "Disconnecting client due to too-large message size (%d bytes): %s",
                size,
                self.client_address,
            )
            raise Disconnect()

        buffer = []
        remainder = size

        if initial:
            buffer.append(initial)
            remainder -= len(initial)
            assert remainder >= 0

        while remainder:
            data = self.request.recv(remainder)
            if not data:
                # An empty read means the peer closed the connection. Without
                # this check ``remainder`` never shrinks and the loop spins
                # forever; read_size() and read_proxy_header() already treat
                # an empty read the same way.
                raise Disconnect()
            remainder -= len(data)
            buffer.append(data)
        self._on_packet(b"".join(buffer))

    def parse_proxy_protocol(self, line):
        """Apply a PROXY protocol header line to the stored addresses.

        ``line`` is the header as bytes, without the trailing CRLF. For
        ``TCP4`` and ``TCP6`` the client and server addresses are replaced by
        the ones in the header (ports are kept as text); ``UNKNOWN`` leaves
        them untouched.

        Raises ``Disconnect`` for a malformed or unsupported header.
        """
        words = line.split()

        if len(words) < 2:
            raise Disconnect()

        family = words[1]
        if family in (b"TCP4", b"TCP6"):
            # Both families carry exactly: PROXY <family> <src> <dst> <sport> <dport>.
            # Checking this for TCP6 too avoids an IndexError on short headers.
            if len(words) != 6:
                raise Disconnect()
            client = (utf8text(words[2]), utf8text(words[4]))
            server = (utf8text(words[3]), utf8text(words[5]))
            if family == b"TCP6":
                # IPv6 addresses are stored as 4-tuples; the last two fields are zero.
                client += (0, 0)
                server += (0, 0)
            self.client_address = client
            self.server_address = server
        elif family != b"UNKNOWN":
            raise Disconnect()

    def read_size(self, buffer=b""):
        """Read a 4-byte length prefix and return it as an integer.

        ``buffer`` holds bytes of the prefix that were already received.
        Raises ``Disconnect`` if the peer closes the connection first.
        """
        while len(buffer) < size_pack.size:
            recv = self.request.recv(size_pack.size - len(buffer))
            if not recv:
                raise Disconnect()
            buffer += recv
        return size_pack.unpack(buffer)[0]

    def read_proxy_header(self, buffer=b""):
        """Receive until ``buffer`` contains a CRLF and return everything read.

        The result may contain data following the header line. Raises
        ``Disconnect`` if no CRLF shows up within the allowed length or the
        peer closes the connection.
        """
        # Max line length for PROXY protocol is 107, and we received 4 already.
        while b"\r\n" not in buffer:
            if len(buffer) > 107:
                raise Disconnect()
            data = self.request.recv(107)
            if not data:
                raise Disconnect()
            buffer += data
        return buffer

    def _on_packet(self, data):
        """Decompress and decode one raw packet, then pass it to ``on_packet``."""
        decompressed = zlib.decompress(data).decode("utf-8")
        self._got_packet = True
        self.on_packet(decompressed)

    def on_packet(self, data):
        """Handle one decoded packet (a ``str``); must be overridden."""
        raise NotImplementedError()

    def on_connect(self):
        """Hook invoked by the metaclass before ``handle()``."""
        pass

    def on_disconnect(self):
        """Hook invoked by the metaclass after ``handle()`` finishes or raises."""
        pass

    def on_timeout(self):
        """Hook invoked when the socket times out after a packet was received."""
        pass

    def handle(self):
        """Run the receive loop until the connection ends.

        The first four bytes are either the size of the first packet or, for
        connections from ``proxies``, the start of a PROXY protocol header.
        ``Disconnect``, zlib errors, socket timeouts and gevent's
        ``cancel_wait_ex`` end the loop quietly (with logging where
        applicable); any other socket error is re-raised.
        """
        try:
            tag = self.read_size()
            self._initial_tag = size_pack.pack(tag)
            if self.client_address[0] in self.proxies and self._initial_tag == b"PROX":
                proxy, _, remainder = self.read_proxy_header(
                    self._initial_tag
                ).partition(b"\r\n")
                self.parse_proxy_protocol(proxy)

                # Data received along with the header may hold whole packets
                # and/or the beginning of one; consume it before the main loop.
                while remainder:
                    if len(remainder) < size_pack.size:
                        # Only part of a length prefix is buffered: finish
                        # reading it, read that packet, and leave this loop.
                        # Falling through would unpack a short buffer.
                        self.read_sized_packet(self.read_size(remainder))
                        break

                    size = size_pack.unpack(remainder[: size_pack.size])[0]
                    remainder = remainder[size_pack.size :]
                    if len(remainder) <= size:
                        # The buffered data ends inside (or exactly at the end
                        # of) this packet; fetch whatever is still missing.
                        self.read_sized_packet(size, remainder)
                        break

                    self._on_packet(remainder[:size])
                    remainder = remainder[size:]
            else:
                self.read_sized_packet(tag)

            while True:
                self.read_sized_packet(self.read_size())
        except Disconnect:
            return
        except zlib.error:
            if self._got_packet:
                logger.warning(
                    "Encountered zlib error during packet handling, disconnecting client: %s",
                    self.client_address,
                    exc_info=True,
                )
            else:
                logger.info(
                    "Potentially wrong protocol (zlib error): %s: %r",
                    self.client_address,
                    self._initial_tag,
                    exc_info=True,
                )
        except socket.timeout:
            if self._got_packet:
                logger.info("Socket timed out: %s", self.client_address)
                self.on_timeout()
            else:
                logger.info(
                    "Potentially wrong protocol: %s: %r",
                    self.client_address,
                    self._initial_tag,
                )
        except socket.error as e:
            # When a gevent socket is shutdown, gevent cancels all waits, causing recv to raise cancel_wait_ex.
            if e.__class__.__name__ == "cancel_wait_ex":
                return
            raise

    def send(self, data):
        """Compress ``data`` (a ``str``) and send it with its length prefix."""
        compressed = zlib.compress(data.encode("utf-8"))
        self.request.sendall(size_pack.pack(len(compressed)) + compressed)

    def close(self):
        """Shut down both directions of the underlying socket."""
        self.request.shutdown(socket.SHUT_RDWR)