NDOJ/event_socket_server/helpers.py
2020-01-21 15:35:58 +09:00

125 lines
4.1 KiB
Python

import struct
import zlib
from judge.utils.unicode import utf8text
from .handler import Handler
# Header preceding every packet on the wire: one unsigned 32-bit payload
# length in network (big-endian) byte order, i.e. 4 bytes.
size_pack = struct.Struct('!I')


class SizedPacketHandler(Handler):
    """Handler that frames the byte stream into length-prefixed packets.

    Each packet on the wire is a ``size_pack`` header (4-byte big-endian
    unsigned length) followed by exactly that many payload bytes.
    Subclasses implement :meth:`_packet` to receive each complete payload
    and may override :meth:`_format_send` to transform outgoing data.
    """

    def __init__(self, server, socket):
        super(SizedPacketHandler, self).__init__(server, socket)
        # bytearray: appends and front-deletions are amortised cheap, unlike
        # repeated ``bytes`` concatenation/slicing, which is quadratic for
        # large packets arriving in many chunks.
        self._buffer = bytearray()
        # Payload length announced by the last header, or None while waiting
        # for a header. None rather than 0 so that zero-length packets are
        # delivered instead of being silently skipped.
        self._packetlen = None

    def _packet(self, data):
        """Handle one complete packet payload (``bytes``). Must be overridden."""
        raise NotImplementedError()

    def _format_send(self, data):
        """Return the bytes to frame for outgoing ``data``; identity by default."""
        return data

    def _recv_data(self, data):
        """Append received bytes and dispatch every packet that is now complete."""
        self._buffer += data
        while True:
            if self._packetlen is None:
                # Waiting for a length header.
                if len(self._buffer) < size_pack.size:
                    break
                self._packetlen = size_pack.unpack_from(self._buffer)[0]
                del self._buffer[:size_pack.size]
            elif len(self._buffer) >= self._packetlen:
                # Whole payload available: detach it, reset state, dispatch.
                packet = bytes(self._buffer[:self._packetlen])
                del self._buffer[:self._packetlen]
                self._packetlen = None
                self._packet(packet)
            else:
                break

    def send(self, data, callback=None):
        """Frame ``data`` (after :meth:`_format_send`) with its length header
        and hand it, with ``callback``, to ``self._send``."""
        data = self._format_send(data)
        self._send(size_pack.pack(len(data)) + data, callback)
class ZlibPacketHandler(SizedPacketHandler):
    """Sized-packet handler whose payloads are zlib-compressed UTF-8 text.

    Subclasses implement :meth:`packet`, which receives each decoded ``str``.
    Payloads that fail to decompress or to decode as UTF-8 are passed to
    :meth:`malformed_packet`, which closes the connection by default.
    """

    def _format_send(self, data):
        """Encode ``data`` (a ``str``) as UTF-8 and zlib-compress it."""
        return zlib.compress(data.encode('utf-8'))

    def packet(self, data):
        """Handle one decoded text packet. Must be overridden."""
        raise NotImplementedError()

    def _packet(self, data):
        """Decompress and decode a raw payload, then pass it to :meth:`packet`."""
        try:
            # NOTE(review): zlib.decompress has no output limit, so a small
            # hostile payload can expand to a very large string.
            text = zlib.decompress(data).decode('utf-8')
        except (zlib.error, UnicodeDecodeError) as e:
            self.malformed_packet(e)
        else:
            # Called outside the try so errors raised by packet() itself are
            # not mistaken for a malformed payload.
            self.packet(text)

    def malformed_packet(self, exception):
        """Handle an undecodable payload; closes the connection by default.

        ``exception`` is the ``zlib.error`` or ``UnicodeDecodeError`` raised
        while decoding.
        """
        self.close()
class ProxyProtocolMixin(object):
    """Mixin that consumes a PROXY protocol v1 header sent by trusted peers.

    When the socket peer (``client_address[0]``) is in ``_REAL_IP_SET``,
    the start of the stream is inspected for a PROXY header:

    * a v1 (text) header replaces ``client_address`` and
      ``server_address`` with the addresses it carries;
    * a v2 (binary) header is not supported and closes the connection;
    * anything else is passed through unchanged.

    Traffic from peers outside ``_REAL_IP_SET`` is always passed through.
    The class it is combined with must provide ``client_address``,
    ``close()`` and ``_recv_data()``, which receives the bytes following
    any consumed header.
    """

    __UNKNOWN_TYPE = 0  # trusted peer, nothing received yet
    __PROXY1 = 1        # accumulating a (possibly fragmented) v1 header line
    __PROXY2 = 2        # unused: v2 headers are rejected
    __DATA = 3          # header handled or not expected: pass data through
    # 12-byte signature that starts every PROXY protocol v2 header.
    __HEADER2 = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
    __HEADER2_LEN = len(__HEADER2)
    # Per the v1 spec a header line is at most 107 bytes including CRLF,
    # so the CR must appear at index 105 or earlier.
    __PROXY1_MAX_LEN = 107
    __PROXY1_MAX_CR_INDEX = 105

    # Set of trusted proxy addresses, or None to trust nobody.
    _REAL_IP_SET = None

    @classmethod
    def with_proxy_set(cls, ranges):
        """Return a subclass of ``cls`` that accepts PROXY headers from ``ranges``.

        ``ranges`` is an iterable of strings. Entries containing ``*`` or
        ``-`` are parsed as ``netaddr.IPGlob`` patterns; all other entries
        are handed to ``netaddr.IPSet`` as-is.
        """
        from netaddr import IPSet, IPGlob
        from itertools import chain

        globs = []
        addrs = []
        for item in ranges:
            if '*' in item or '-' in item:
                globs.append(IPGlob(item))
            else:
                addrs.append(item)
        # Add each glob as its covering CIDR blocks instead of iterating it,
        # which would materialise every single address in the range.
        glob_cidrs = chain.from_iterable(glob.cidrs() for glob in globs)
        ipset = IPSet(chain(glob_cidrs, addrs))
        return type(cls.__name__, (cls,), {'_REAL_IP_SET': ipset})

    def __init__(self, server, socket):
        super(ProxyProtocolMixin, self).__init__(server, socket)
        self.__buffer = b''
        # Only trusted peers may send a PROXY header; everyone else goes
        # straight to pass-through mode.
        self.__type = (self.__UNKNOWN_TYPE if self._REAL_IP_SET and
                       self.client_address[0] in self._REAL_IP_SET else self.__DATA)

    def __parse_proxy1(self, data):
        """Accumulate ``data`` until the v1 header line is complete, then apply it.

        Once the CRLF-terminated line is available, its addresses are
        stored and any bytes after the CRLF are forwarded to the next
        ``_recv_data``. Malformed or overlong headers close the connection.
        """
        self.__buffer += data
        index = self.__buffer.find(b'\r\n')
        if index < 0:
            # Line incomplete: keep waiting unless it is already too long.
            if len(self.__buffer) > self.__PROXY1_MAX_LEN:
                self.close()
            return
        if index > self.__PROXY1_MAX_CR_INDEX:
            self.close()
            return

        # Parse from the accumulated buffer, not just the latest chunk, so a
        # header split across several reads is handled correctly.
        header, rest = self.__buffer[:index], self.__buffer[index + 2:]
        self.__buffer = b''

        proxy = header.split()
        if len(proxy) < 2 or proxy[0] != b'PROXY':
            self.close()
            return
        protocol = proxy[1]
        if protocol in (b'TCP4', b'TCP6'):
            # PROXY <proto> <src addr> <dst addr> <src port> <dst port>
            if len(proxy) != 6:
                self.close()
                return
            client_host, server_host = utf8text(proxy[2]), utf8text(proxy[3])
            # NOTE(review): ports are kept as text, as before; socket address
            # tuples normally carry an int port -- confirm callers' needs.
            client_port, server_port = utf8text(proxy[4]), utf8text(proxy[5])
            if protocol == b'TCP4':
                self.client_address = (client_host, client_port)
                self.server_address = (server_host, server_port)
            else:
                # IPv6 address tuples are (host, port, flowinfo, scope_id).
                self.client_address = (client_host, client_port, 0, 0)
                self.server_address = (server_host, server_port, 0, 0)
        elif protocol != b'UNKNOWN':
            self.close()
            return

        self.__type = self.__DATA
        super(ProxyProtocolMixin, self)._recv_data(rest)

    def _recv_data(self, data):
        """Route received bytes according to the current PROXY parsing state."""
        if self.__type == self.__DATA:
            super(ProxyProtocolMixin, self)._recv_data(data)
        elif self.__type == self.__UNKNOWN_TYPE:
            # First chunk from a trusted peer: sniff for a PROXY header.
            # NOTE(review): a first chunk shorter than these thresholds is
            # treated as plain data even if it begins a header.
            if len(data) >= 16 and data[:self.__HEADER2_LEN] == self.__HEADER2:
                # PROXY protocol v2 (binary) is not supported.
                self.close()
            elif len(data) >= 8 and data[:5] == b'PROXY':
                self.__type = self.__PROXY1
                self.__parse_proxy1(data)
            else:
                self.__type = self.__DATA
                super(ProxyProtocolMixin, self)._recv_data(data)
        else:
            self.__parse_proxy1(data)