Cloned DMOJ
This commit is contained in:
parent
f623974b58
commit
49dc9ff10c
513 changed files with 132349 additions and 39 deletions
11
event_socket_server/__init__.py
Normal file
11
event_socket_server/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from .base_server import BaseServer
|
||||
from .engines import *
|
||||
from .handler import Handler
|
||||
from .helpers import ProxyProtocolMixin, SizedPacketHandler, ZlibPacketHandler
|
||||
|
||||
|
||||
def get_preferred_engine(choices=('epoll', 'poll', 'select')):
|
||||
for choice in choices:
|
||||
if choice in engines:
|
||||
return engines[choice]
|
||||
return engines['select']
|
169
event_socket_server/base_server.py
Normal file
169
event_socket_server/base_server.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
import logging
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from functools import total_ordering
|
||||
from heapq import heappop, heappush
|
||||
|
||||
logger = logging.getLogger('event_socket_server')
|
||||
|
||||
|
||||
class SendMessage(object):
|
||||
__slots__ = ('data', 'callback')
|
||||
|
||||
def __init__(self, data, callback):
|
||||
self.data = data
|
||||
self.callback = callback
|
||||
|
||||
|
||||
@total_ordering
|
||||
class ScheduledJob(object):
|
||||
__slots__ = ('time', 'func', 'args', 'kwargs', 'cancel', 'dispatched')
|
||||
|
||||
def __init__(self, time, func, args, kwargs):
|
||||
self.time = time
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.cancel = False
|
||||
self.dispatched = False
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.time == other.time
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.time < other.time
|
||||
|
||||
|
||||
class BaseServer(object):
|
||||
def __init__(self, addresses, client):
|
||||
self._servers = set()
|
||||
for address, port in addresses:
|
||||
info = socket.getaddrinfo(address, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
|
||||
for af, socktype, proto, canonname, sa in info:
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
sock.setblocking(0)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(sa)
|
||||
self._servers.add(sock)
|
||||
|
||||
self._stop = threading.Event()
|
||||
self._clients = set()
|
||||
self._ClientClass = client
|
||||
self._send_queue = defaultdict(deque)
|
||||
self._job_queue = []
|
||||
self._job_queue_lock = threading.Lock()
|
||||
|
||||
def _serve(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _accept(self, sock):
|
||||
conn, address = sock.accept()
|
||||
conn.setblocking(0)
|
||||
client = self._ClientClass(self, conn)
|
||||
self._clients.add(client)
|
||||
return client
|
||||
|
||||
def schedule(self, delay, func, *args, **kwargs):
|
||||
with self._job_queue_lock:
|
||||
job = ScheduledJob(time.time() + delay, func, args, kwargs)
|
||||
heappush(self._job_queue, job)
|
||||
return job
|
||||
|
||||
def unschedule(self, job):
|
||||
with self._job_queue_lock:
|
||||
if job.dispatched or job.cancel:
|
||||
return False
|
||||
job.cancel = True
|
||||
return True
|
||||
|
||||
def _register_write(self, client):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _register_read(self, client):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _clean_up_client(self, client, finalize=False):
|
||||
try:
|
||||
del self._send_queue[client.fileno()]
|
||||
except KeyError:
|
||||
pass
|
||||
client.on_close()
|
||||
client._socket.close()
|
||||
if not finalize:
|
||||
self._clients.remove(client)
|
||||
|
||||
def _dispatch_event(self):
|
||||
t = time.time()
|
||||
tasks = []
|
||||
with self._job_queue_lock:
|
||||
while True:
|
||||
dt = self._job_queue[0].time - t if self._job_queue else 1
|
||||
if dt > 0:
|
||||
break
|
||||
task = heappop(self._job_queue)
|
||||
task.dispatched = True
|
||||
if not task.cancel:
|
||||
tasks.append(task)
|
||||
for task in tasks:
|
||||
logger.debug('Dispatching event: %r(*%r, **%r)', task.func, task.args, task.kwargs)
|
||||
task.func(*task.args, **task.kwargs)
|
||||
if not self._job_queue or dt > 1:
|
||||
dt = 1
|
||||
return dt
|
||||
|
||||
def _nonblock_read(self, client):
|
||||
try:
|
||||
data = client._socket.recv(1024)
|
||||
except socket.error:
|
||||
self._clean_up_client(client)
|
||||
else:
|
||||
logger.debug('Read from %s: %d bytes', client.client_address, len(data))
|
||||
if not data:
|
||||
self._clean_up_client(client)
|
||||
else:
|
||||
try:
|
||||
client._recv_data(data)
|
||||
except Exception:
|
||||
logger.exception('Client recv_data failure')
|
||||
self._clean_up_client(client)
|
||||
|
||||
def _nonblock_write(self, client):
|
||||
fd = client.fileno()
|
||||
queue = self._send_queue[fd]
|
||||
try:
|
||||
top = queue[0]
|
||||
cb = client._socket.send(top.data)
|
||||
top.data = top.data[cb:]
|
||||
logger.debug('Send to %s: %d bytes', client.client_address, cb)
|
||||
if not top.data:
|
||||
logger.debug('Finished sending: %s', client.client_address)
|
||||
if top.callback is not None:
|
||||
logger.debug('Calling callback: %s: %r', client.client_address, top.callback)
|
||||
try:
|
||||
top.callback()
|
||||
except Exception:
|
||||
logger.exception('Client write callback failure')
|
||||
self._clean_up_client(client)
|
||||
return
|
||||
queue.popleft()
|
||||
if not queue:
|
||||
self._register_read(client)
|
||||
del self._send_queue[fd]
|
||||
except socket.error:
|
||||
self._clean_up_client(client)
|
||||
|
||||
def send(self, client, data, callback=None):
|
||||
logger.debug('Writing %d bytes to client %s, callback: %s', len(data), client.client_address, callback)
|
||||
self._send_queue[client.fileno()].append(SendMessage(data, callback))
|
||||
self._register_write(client)
|
||||
|
||||
def stop(self):
|
||||
self._stop.set()
|
||||
|
||||
def serve_forever(self):
|
||||
self._serve()
|
||||
|
||||
def on_shutdown(self):
|
||||
pass
|
17
event_socket_server/engines/__init__.py
Normal file
17
event_socket_server/engines/__init__.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
import select
|
||||
|
||||
__author__ = 'Quantum'
|
||||
engines = {}
|
||||
|
||||
from .select_server import SelectServer # noqa: E402, import not at top for consistency
|
||||
engines['select'] = SelectServer
|
||||
|
||||
if hasattr(select, 'poll'):
|
||||
from .poll_server import PollServer
|
||||
engines['poll'] = PollServer
|
||||
|
||||
if hasattr(select, 'epoll'):
|
||||
from .epoll_server import EpollServer
|
||||
engines['epoll'] = EpollServer
|
||||
|
||||
del select
|
17
event_socket_server/engines/epoll_server.py
Normal file
17
event_socket_server/engines/epoll_server.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
import select
|
||||
__author__ = 'Quantum'
|
||||
|
||||
if not hasattr(select, 'epoll'):
|
||||
raise ImportError('System does not support epoll')
|
||||
|
||||
from .poll_server import PollServer # noqa: E402, must be imported here
|
||||
|
||||
|
||||
class EpollServer(PollServer):
|
||||
poll = select.epoll
|
||||
WRITE = select.EPOLLIN | select.EPOLLOUT | select.EPOLLERR | select.EPOLLHUP
|
||||
READ = select.EPOLLIN | select.EPOLLERR | select.EPOLLHUP
|
||||
POLLIN = select.EPOLLIN
|
||||
POLLOUT = select.EPOLLOUT
|
||||
POLL_CLOSE = select.EPOLLHUP | select.EPOLLERR
|
||||
NEED_CLOSE = True
|
97
event_socket_server/engines/poll_server.py
Normal file
97
event_socket_server/engines/poll_server.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
import errno
|
||||
import logging
|
||||
import select
|
||||
import threading
|
||||
|
||||
from ..base_server import BaseServer
|
||||
|
||||
logger = logging.getLogger('event_socket_server')
|
||||
|
||||
if not hasattr(select, 'poll'):
|
||||
raise ImportError('System does not support poll')
|
||||
|
||||
|
||||
class PollServer(BaseServer):
|
||||
poll = select.poll
|
||||
WRITE = select.POLLIN | select.POLLOUT | select.POLLERR | select.POLLHUP
|
||||
READ = select.POLLIN | select.POLLERR | select.POLLHUP
|
||||
POLLIN = select.POLLIN
|
||||
POLLOUT = select.POLLOUT
|
||||
POLL_CLOSE = select.POLLERR | select.POLLHUP
|
||||
NEED_CLOSE = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(PollServer, self).__init__(*args, **kwargs)
|
||||
self._poll = self.poll()
|
||||
self._fdmap = {}
|
||||
self._server_fds = {sock.fileno(): sock for sock in self._servers}
|
||||
self._close_lock = threading.RLock()
|
||||
|
||||
def _register_write(self, client):
|
||||
logger.debug('On write mode: %s', client.client_address)
|
||||
self._poll.modify(client.fileno(), self.WRITE)
|
||||
|
||||
def _register_read(self, client):
|
||||
logger.debug('On read mode: %s', client.client_address)
|
||||
self._poll.modify(client.fileno(), self.READ)
|
||||
|
||||
def _clean_up_client(self, client, finalize=False):
|
||||
logger.debug('Taking close lock: cleanup')
|
||||
with self._close_lock:
|
||||
logger.debug('Cleaning up client: %s, finalize: %d', client.client_address, finalize)
|
||||
fd = client.fileno()
|
||||
try:
|
||||
self._poll.unregister(fd)
|
||||
except IOError as e:
|
||||
if e.errno != errno.ENOENT:
|
||||
raise
|
||||
except KeyError:
|
||||
pass
|
||||
del self._fdmap[fd]
|
||||
super(PollServer, self)._clean_up_client(client, finalize)
|
||||
|
||||
def _serve(self):
|
||||
for fd, sock in self._server_fds.items():
|
||||
self._poll.register(fd, self.POLLIN)
|
||||
sock.listen(16)
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
for fd, event in self._poll.poll(self._dispatch_event()):
|
||||
if fd in self._server_fds:
|
||||
client = self._accept(self._server_fds[fd])
|
||||
logger.debug('Accepting: %s', client.client_address)
|
||||
fd = client.fileno()
|
||||
self._poll.register(fd, self.READ)
|
||||
self._fdmap[fd] = client
|
||||
elif event & self.POLL_CLOSE:
|
||||
logger.debug('Client closed: %s', self._fdmap[fd].client_address)
|
||||
self._clean_up_client(self._fdmap[fd])
|
||||
else:
|
||||
logger.debug('Taking close lock: event loop')
|
||||
with self._close_lock:
|
||||
try:
|
||||
client = self._fdmap[fd]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
logger.debug('Client active: %s, read: %d, write: %d',
|
||||
client.client_address,
|
||||
event & self.POLLIN,
|
||||
event & self.POLLOUT)
|
||||
if event & self.POLLIN:
|
||||
logger.debug('Non-blocking read on client: %s', client.client_address)
|
||||
self._nonblock_read(client)
|
||||
# Might be closed in the read handler.
|
||||
if event & self.POLLOUT and fd in self._fdmap:
|
||||
logger.debug('Non-blocking write on client: %s', client.client_address)
|
||||
self._nonblock_write(client)
|
||||
finally:
|
||||
logger.info('Shutting down server')
|
||||
self.on_shutdown()
|
||||
for client in self._clients:
|
||||
self._clean_up_client(client, True)
|
||||
for fd, sock in self._server_fds.items():
|
||||
self._poll.unregister(fd)
|
||||
sock.close()
|
||||
if self.NEED_CLOSE:
|
||||
self._poll.close()
|
49
event_socket_server/engines/select_server.py
Normal file
49
event_socket_server/engines/select_server.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import select
|
||||
|
||||
from ..base_server import BaseServer
|
||||
|
||||
|
||||
class SelectServer(BaseServer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SelectServer, self).__init__(*args, **kwargs)
|
||||
self._reads = set(self._servers)
|
||||
self._writes = set()
|
||||
|
||||
def _register_write(self, client):
|
||||
self._writes.add(client)
|
||||
|
||||
def _register_read(self, client):
|
||||
self._writes.remove(client)
|
||||
|
||||
def _clean_up_client(self, client, finalize=False):
|
||||
self._writes.discard(client)
|
||||
self._reads.remove(client)
|
||||
super(SelectServer, self)._clean_up_client(client, finalize)
|
||||
|
||||
def _serve(self, select=select.select):
|
||||
for server in self._servers:
|
||||
server.listen(16)
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
r, w, x = select(self._reads, self._writes, self._reads, self._dispatch_event())
|
||||
for s in r:
|
||||
if s in self._servers:
|
||||
self._reads.add(self._accept(s))
|
||||
else:
|
||||
self._nonblock_read(s)
|
||||
|
||||
for client in w:
|
||||
self._nonblock_write(client)
|
||||
|
||||
for s in x:
|
||||
s.close()
|
||||
if s in self._servers:
|
||||
raise RuntimeError('Server is in exceptional condition')
|
||||
else:
|
||||
self._clean_up_client(s)
|
||||
finally:
|
||||
self.on_shutdown()
|
||||
for client in self._clients:
|
||||
self._clean_up_client(client, True)
|
||||
for server in self._servers:
|
||||
server.close()
|
27
event_socket_server/handler.py
Normal file
27
event_socket_server/handler.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
__author__ = 'Quantum'
|
||||
|
||||
|
||||
class Handler(object):
|
||||
def __init__(self, server, socket):
|
||||
self._socket = socket
|
||||
self.server = server
|
||||
self.client_address = socket.getpeername()
|
||||
|
||||
def fileno(self):
|
||||
return self._socket.fileno()
|
||||
|
||||
def _recv_data(self, data):
|
||||
raise NotImplementedError
|
||||
|
||||
def _send(self, data, callback=None):
|
||||
return self.server.send(self, data, callback)
|
||||
|
||||
def close(self):
|
||||
self.server._clean_up_client(self)
|
||||
|
||||
def on_close(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def socket(self):
|
||||
return self._socket
|
125
event_socket_server/helpers.py
Normal file
125
event_socket_server/helpers.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
import struct
|
||||
import zlib
|
||||
|
||||
from judge.utils.unicode import utf8text
|
||||
from .handler import Handler
|
||||
|
||||
size_pack = struct.Struct('!I')
|
||||
|
||||
|
||||
class SizedPacketHandler(Handler):
|
||||
def __init__(self, server, socket):
|
||||
super(SizedPacketHandler, self).__init__(server, socket)
|
||||
self._buffer = b''
|
||||
self._packetlen = 0
|
||||
|
||||
def _packet(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _format_send(self, data):
|
||||
return data
|
||||
|
||||
def _recv_data(self, data):
|
||||
self._buffer += data
|
||||
while len(self._buffer) >= self._packetlen if self._packetlen else len(self._buffer) >= size_pack.size:
|
||||
if self._packetlen:
|
||||
data = self._buffer[:self._packetlen]
|
||||
self._buffer = self._buffer[self._packetlen:]
|
||||
self._packetlen = 0
|
||||
self._packet(data)
|
||||
else:
|
||||
data = self._buffer[:size_pack.size]
|
||||
self._buffer = self._buffer[size_pack.size:]
|
||||
self._packetlen = size_pack.unpack(data)[0]
|
||||
|
||||
def send(self, data, callback=None):
|
||||
data = self._format_send(data)
|
||||
self._send(size_pack.pack(len(data)) + data, callback)
|
||||
|
||||
|
||||
class ZlibPacketHandler(SizedPacketHandler):
|
||||
def _format_send(self, data):
|
||||
return zlib.compress(data.encode('utf-8'))
|
||||
|
||||
def packet(self, data):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _packet(self, data):
|
||||
try:
|
||||
self.packet(zlib.decompress(data).decode('utf-8'))
|
||||
except zlib.error as e:
|
||||
self.malformed_packet(e)
|
||||
|
||||
def malformed_packet(self, exception):
|
||||
self.close()
|
||||
|
||||
|
||||
class ProxyProtocolMixin(object):
|
||||
__UNKNOWN_TYPE = 0
|
||||
__PROXY1 = 1
|
||||
__PROXY2 = 2
|
||||
__DATA = 3
|
||||
|
||||
__HEADER2 = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
|
||||
__HEADER2_LEN = len(__HEADER2)
|
||||
|
||||
_REAL_IP_SET = None
|
||||
|
||||
@classmethod
|
||||
def with_proxy_set(cls, ranges):
|
||||
from netaddr import IPSet, IPGlob
|
||||
from itertools import chain
|
||||
|
||||
globs = []
|
||||
addrs = []
|
||||
for item in ranges:
|
||||
if '*' in item or '-' in item:
|
||||
globs.append(IPGlob(item))
|
||||
else:
|
||||
addrs.append(item)
|
||||
ipset = IPSet(chain(chain.from_iterable(globs), addrs))
|
||||
return type(cls.__name__, (cls,), {'_REAL_IP_SET': ipset})
|
||||
|
||||
def __init__(self, server, socket):
|
||||
super(ProxyProtocolMixin, self).__init__(server, socket)
|
||||
self.__buffer = b''
|
||||
self.__type = (self.__UNKNOWN_TYPE if self._REAL_IP_SET and
|
||||
self.client_address[0] in self._REAL_IP_SET else self.__DATA)
|
||||
|
||||
def __parse_proxy1(self, data):
|
||||
self.__buffer += data
|
||||
index = self.__buffer.find(b'\r\n')
|
||||
if 0 <= index < 106:
|
||||
proxy = data[:index].split()
|
||||
if len(proxy) < 2:
|
||||
return self.close()
|
||||
if proxy[1] == b'TCP4':
|
||||
if len(proxy) != 6:
|
||||
return self.close()
|
||||
self.client_address = (utf8text(proxy[2]), utf8text(proxy[4]))
|
||||
self.server_address = (utf8text(proxy[3]), utf8text(proxy[5]))
|
||||
elif proxy[1] == b'TCP6':
|
||||
self.client_address = (utf8text(proxy[2]), utf8text(proxy[4]), 0, 0)
|
||||
self.server_address = (utf8text(proxy[3]), utf8text(proxy[5]), 0, 0)
|
||||
elif proxy[1] != b'UNKNOWN':
|
||||
return self.close()
|
||||
|
||||
self.__type = self.__DATA
|
||||
super(ProxyProtocolMixin, self)._recv_data(data[index + 2:])
|
||||
elif len(self.__buffer) > 107 or index > 105:
|
||||
self.close()
|
||||
|
||||
def _recv_data(self, data):
|
||||
if self.__type == self.__DATA:
|
||||
super(ProxyProtocolMixin, self)._recv_data(data)
|
||||
elif self.__type == self.__UNKNOWN_TYPE:
|
||||
if len(data) >= 16 and data[:self.__HEADER2_LEN] == self.__HEADER2:
|
||||
self.close()
|
||||
elif len(data) >= 8 and data[:5] == b'PROXY':
|
||||
self.__type = self.__PROXY1
|
||||
self.__parse_proxy1(data)
|
||||
else:
|
||||
self.__type = self.__DATA
|
||||
super(ProxyProtocolMixin, self)._recv_data(data)
|
||||
else:
|
||||
self.__parse_proxy1(data)
|
94
event_socket_server/test_client.py
Normal file
94
event_socket_server/test_client.py
Normal file
|
@ -0,0 +1,94 @@
|
|||
import ctypes
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
import zlib
|
||||
|
||||
size_pack = struct.Struct('!I')
|
||||
try:
|
||||
RtlGenRandom = ctypes.windll.advapi32.SystemFunction036
|
||||
except AttributeError:
|
||||
RtlGenRandom = None
|
||||
|
||||
|
||||
def open_connection():
|
||||
sock = socket.create_connection((host, port))
|
||||
return sock
|
||||
|
||||
|
||||
def zlibify(data):
|
||||
data = zlib.compress(data.encode('utf-8'))
|
||||
return size_pack.pack(len(data)) + data
|
||||
|
||||
|
||||
def dezlibify(data, skip_head=True):
|
||||
if skip_head:
|
||||
data = data[size_pack.size:]
|
||||
return zlib.decompress(data).decode('utf-8')
|
||||
|
||||
|
||||
def random(length):
|
||||
if RtlGenRandom is None:
|
||||
with open('/dev/urandom') as f:
|
||||
return f.read(length)
|
||||
buf = ctypes.create_string_buffer(length)
|
||||
RtlGenRandom(buf, length)
|
||||
return buf.raw
|
||||
|
||||
|
||||
def main():
|
||||
global host, port
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-l', '--host', default='localhost')
|
||||
parser.add_argument('-p', '--port', default=9999, type=int)
|
||||
args = parser.parse_args()
|
||||
host, port = args.host, args.port
|
||||
|
||||
print('Opening idle connection:', end=' ')
|
||||
s1 = open_connection()
|
||||
print('Success')
|
||||
print('Opening hello world connection:', end=' ')
|
||||
s2 = open_connection()
|
||||
print('Success')
|
||||
print('Sending Hello, World!', end=' ')
|
||||
s2.sendall(zlibify('Hello, World!'))
|
||||
print('Success')
|
||||
print('Testing blank connection:', end=' ')
|
||||
s3 = open_connection()
|
||||
s3.close()
|
||||
print('Success')
|
||||
result = dezlibify(s2.recv(1024))
|
||||
assert result == 'Hello, World!'
|
||||
print(result)
|
||||
s2.close()
|
||||
print('Large random data test:', end=' ')
|
||||
s4 = open_connection()
|
||||
data = random(1000000)
|
||||
print('Generated', end=' ')
|
||||
s4.sendall(zlibify(data))
|
||||
print('Sent', end=' ')
|
||||
result = ''
|
||||
while len(result) < size_pack.size:
|
||||
result += s4.recv(1024)
|
||||
size = size_pack.unpack(result[:size_pack.size])[0]
|
||||
result = result[size_pack.size:]
|
||||
while len(result) < size:
|
||||
result += s4.recv(1024)
|
||||
print('Received', end=' ')
|
||||
assert dezlibify(result, False) == data
|
||||
print('Success')
|
||||
s4.close()
|
||||
print('Test malformed connection:', end=' ')
|
||||
s5 = open_connection()
|
||||
s5.sendall(data[:100000])
|
||||
s5.close()
|
||||
print('Success')
|
||||
print('Waiting for timeout to close idle connection:', end=' ')
|
||||
time.sleep(6)
|
||||
print('Done')
|
||||
s1.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
54
event_socket_server/test_server.py
Normal file
54
event_socket_server/test_server.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from .engines import engines
|
||||
from .helpers import ProxyProtocolMixin, ZlibPacketHandler
|
||||
|
||||
|
||||
class EchoPacketHandler(ProxyProtocolMixin, ZlibPacketHandler):
|
||||
def __init__(self, server, socket):
|
||||
super(EchoPacketHandler, self).__init__(server, socket)
|
||||
self._gotdata = False
|
||||
self.server.schedule(5, self._kill_if_no_data)
|
||||
|
||||
def _kill_if_no_data(self):
|
||||
if not self._gotdata:
|
||||
print('Inactive client:', self._socket.getpeername())
|
||||
self.close()
|
||||
|
||||
def packet(self, data):
|
||||
self._gotdata = True
|
||||
print('Data from %s: %r' % (self._socket.getpeername(), data[:30] if len(data) > 30 else data))
|
||||
self.send(data)
|
||||
|
||||
def on_close(self):
|
||||
self._gotdata = True
|
||||
print('Closed client:', self._socket.getpeername())
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-l', '--host', action='append')
|
||||
parser.add_argument('-p', '--port', type=int, action='append')
|
||||
parser.add_argument('-e', '--engine', default='select', choices=sorted(engines.keys()))
|
||||
try:
|
||||
import netaddr
|
||||
except ImportError:
|
||||
netaddr = None
|
||||
else:
|
||||
parser.add_argument('-P', '--proxy', action='append')
|
||||
args = parser.parse_args()
|
||||
|
||||
class TestServer(engines[args.engine]):
|
||||
def _accept(self, sock):
|
||||
client = super(TestServer, self)._accept(sock)
|
||||
print('New connection:', client.socket.getpeername())
|
||||
return client
|
||||
|
||||
handler = EchoPacketHandler
|
||||
if netaddr is not None and args.proxy:
|
||||
handler = handler.with_proxy_set(args.proxy)
|
||||
server = TestServer(list(zip(args.host, args.port)), handler)
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Add table
Add a link
Reference in a new issue