import logging
import socket
import struct
import zlib
from itertools import chain

from netaddr import IPGlob, IPSet

from judge.utils.unicode import utf8text

# Logger used by the packet-handling code in this module.
logger = logging.getLogger('judge.bridge')

# Length prefix that precedes every packet on the wire: one unsigned 32-bit
# integer in network (big-endian) byte order.
size_pack = struct.Struct('!I')
assert size_pack.size == 4

# Largest packet size (8 MiB) a peer may announce in its length prefix;
# read_sized_packet() disconnects clients that announce anything bigger.
MAX_ALLOWED_PACKET_SIZE = 8 * 1024 * 1024


def proxy_list(human_readable):
    """Turn an iterable of textual proxy entries into a single ``IPSet``.

    Entries containing ``*`` or ``-`` are parsed as ``IPGlob`` patterns and
    contribute whatever iterating the glob yields; all other entries are
    passed to ``IPSet`` unchanged.
    """
    glob_patterns, plain_entries = [], []
    for entry in human_readable:
        if any(marker in entry for marker in ('*', '-')):
            glob_patterns.append(IPGlob(entry))
            continue
        plain_entries.append(entry)

    from_globs = chain.from_iterable(glob_patterns)
    return IPSet(chain(from_globs, plain_entries))


class Disconnect(Exception):
    """Raised by the packet-reading code to stop handling and drop the connection."""


# socketserver.BaseRequestHandler does all the handling in __init__,
# making it impossible to inherit __init__ sanely. While it lets you
# use setup(), most tools will complain about uninitialized variables.
# This metaclass will allow sane __init__ behaviour while also magically
# calling the methods that handle the request.
class RequestHandlerMeta(type):
    """Metaclass that runs the full request lifecycle when a handler class is called."""

    def __call__(cls, *args, **kwargs):
        """Build the handler, then run on_connect(), handle() and on_disconnect().

        Any exception escaping handle() is logged and re-raised; on_disconnect()
        runs regardless of how handle() ended. Nothing is returned, so calling
        the class evaluates to None rather than to the instance.
        """
        instance = super().__call__(*args, **kwargs)
        instance.on_connect()
        try:
            try:
                instance.handle()
            except BaseException:
                logger.exception('Error in base packet handling')
                raise
        finally:
            instance.on_disconnect()


class ZlibPacketHandler(metaclass=RequestHandlerMeta):
    """Base class for handlers speaking a size-prefixed, zlib-compressed packet protocol.

    Every packet on the wire is a 4-byte network-order unsigned length (``size_pack``)
    followed by that many bytes of zlib-compressed UTF-8 text. Subclasses implement
    ``on_packet`` and may override ``on_connect``, ``on_disconnect`` and ``on_timeout``.

    When the peer's address is in ``proxies`` and the stream begins with ``PROX``, a
    textual PROXY protocol header line is consumed first and ``client_address`` /
    ``server_address`` are replaced with the addresses it advertises.
    """

    # Addresses allowed to send a PROXY protocol header. Only membership tests
    # (``client_address[0] in self.proxies``) are performed on this object.
    proxies = []

    def __init__(self, request, client_address, server):
        """Store the connection objects; ``request`` is the connected socket."""
        self.request = request
        self.server = server
        self.client_address = client_address
        self.server_address = server.server_address
        # Raw first four bytes received, kept only for "wrong protocol" log messages.
        self._initial_tag = None
        # Becomes True once one packet decoded successfully; selects log severity.
        self._got_packet = False

    @property
    def timeout(self):
        """Timeout of the underlying socket, as reported by ``gettimeout()``."""
        return self.request.gettimeout()

    @timeout.setter
    def timeout(self, timeout):
        # Any falsy value (0, None) is mapped to None, i.e. no timeout.
        self.request.settimeout(timeout or None)

    def read_sized_packet(self, size, initial=None):
        """Read one packet body of ``size`` bytes and dispatch it.

        ``initial`` holds bytes of the body that were already received (at most
        ``size`` of them); the rest is read from the socket.

        Raises:
            Disconnect: if ``size`` exceeds ``MAX_ALLOWED_PACKET_SIZE`` or the peer
                closes the connection before the whole body has arrived.
        """
        if size > MAX_ALLOWED_PACKET_SIZE:
            logger.log(logging.WARNING if self._got_packet else logging.INFO,
                       'Disconnecting client due to too-large message size (%d bytes): %s', size, self.client_address)
            raise Disconnect()

        buffer = []
        remainder = size

        if initial:
            buffer.append(initial)
            remainder -= len(initial)
            assert remainder >= 0

        while remainder:
            data = self.request.recv(remainder)
            if not data:
                # recv() returning b'' means the peer closed the connection; without
                # this check the loop would spin forever on a half-received packet.
                raise Disconnect()
            remainder -= len(data)
            buffer.append(data)
        self._on_packet(b''.join(buffer))

    def parse_proxy_protocol(self, line):
        """Parse a PROXY protocol header line (without CRLF) and update the addresses.

        ``TCP4`` yields 2-tuples and ``TCP6`` yields 4-tuples for ``client_address``
        and ``server_address``; ``UNKNOWN`` leaves them untouched.

        Raises:
            Disconnect: if the line is malformed or names an unsupported protocol.
        """
        words = line.split()

        if len(words) < 2:
            raise Disconnect()

        family = words[1]
        if family in (b'TCP4', b'TCP6'):
            # Both families carry exactly: PROXY <family> <src> <dst> <sport> <dport>.
            # Validating here also keeps a short TCP6 line from raising IndexError.
            if len(words) != 6:
                raise Disconnect()
            client = (utf8text(words[2]), utf8text(words[4]))
            server = (utf8text(words[3]), utf8text(words[5]))
            if family == b'TCP6':
                client += (0, 0)
                server += (0, 0)
            self.client_address = client
            self.server_address = server
        elif family != b'UNKNOWN':
            raise Disconnect()

    def read_size(self, buffer=b''):
        """Return the next packet length, reading until 4 prefix bytes are available.

        ``buffer`` may contain the leading bytes of the prefix that were already read.

        Raises:
            Disconnect: if the peer closes the connection before the prefix is complete.
        """
        while len(buffer) < size_pack.size:
            recv = self.request.recv(size_pack.size - len(buffer))
            if not recv:
                raise Disconnect()
            buffer += recv
        return size_pack.unpack(buffer)[0]

    def read_proxy_header(self, buffer=b''):
        """Read from the socket until ``buffer`` contains a CRLF and return everything read.

        The result may include bytes that follow the header line.

        Raises:
            Disconnect: if no CRLF shows up within the allowed length or the peer closes.
        """
        # Max line length for PROXY protocol is 107, and we received 4 already.
        while b'\r\n' not in buffer:
            if len(buffer) > 107:
                raise Disconnect()
            data = self.request.recv(107)
            if not data:
                raise Disconnect()
            buffer += data
        return buffer

    def _on_packet(self, data):
        """Decompress and decode one raw packet, then hand the text to ``on_packet``."""
        decompressed = zlib.decompress(data).decode('utf-8')
        self._got_packet = True
        self.on_packet(decompressed)

    def on_packet(self, data):
        """Handle one decoded packet (a ``str``); must be implemented by subclasses."""
        raise NotImplementedError()

    def on_connect(self):
        """Hook invoked before ``handle``; does nothing by default."""
        pass

    def on_disconnect(self):
        """Hook invoked after ``handle`` finishes or fails; does nothing by default."""
        pass

    def on_timeout(self):
        """Hook invoked when the socket times out after at least one packet was received."""
        pass

    def handle(self):
        """Read and dispatch packets until the peer disconnects or an error occurs.

        ``Disconnect``, zlib errors and socket timeouts end the loop quietly (the latter
        two are logged). Other socket errors propagate, except gevent's ``cancel_wait_ex``.
        """
        try:
            tag = self.read_size()
            self._initial_tag = size_pack.pack(tag)
            if self.client_address[0] in self.proxies and self._initial_tag == b'PROX':
                proxy, _, remainder = self.read_proxy_header(self._initial_tag).partition(b'\r\n')
                self.parse_proxy_protocol(proxy)

                # The recv() calls that fetched the header may also have pulled in the
                # start of the packet stream; drain those bytes before reading normally.
                while remainder:
                    if len(remainder) < size_pack.size:
                        # Only part of a length prefix is buffered: complete it and read
                        # its packet from the socket. Nothing buffered remains afterwards.
                        self.read_sized_packet(self.read_size(remainder))
                        break

                    size = size_pack.unpack(remainder[:size_pack.size])[0]
                    remainder = remainder[size_pack.size:]
                    if len(remainder) <= size:
                        # The packet body is at most partially buffered; read the rest.
                        self.read_sized_packet(size, remainder)
                        break

                    # A whole packet plus extra bytes is buffered; dispatch and continue.
                    self._on_packet(remainder[:size])
                    remainder = remainder[size:]
            else:
                # No PROXY header: the first four bytes were an ordinary length prefix.
                self.read_sized_packet(tag)

            while True:
                self.read_sized_packet(self.read_size())
        except Disconnect:
            return
        except zlib.error:
            if self._got_packet:
                logger.warning('Encountered zlib error during packet handling, disconnecting client: %s',
                               self.client_address, exc_info=True)
            else:
                logger.info('Potentially wrong protocol (zlib error): %s: %r', self.client_address, self._initial_tag,
                            exc_info=True)
        except socket.timeout:
            if self._got_packet:
                logger.info('Socket timed out: %s', self.client_address)
                self.on_timeout()
            else:
                logger.info('Potentially wrong protocol: %s: %r', self.client_address, self._initial_tag)
        except socket.error as e:
            # When a gevent socket is shutdown, gevent cancels all waits, causing recv to raise cancel_wait_ex.
            if e.__class__.__name__ == 'cancel_wait_ex':
                return
            raise

    def send(self, data):
        """Compress the text ``data`` and send it as one length-prefixed packet."""
        compressed = zlib.compress(data.encode('utf-8'))
        self.request.sendall(size_pack.pack(len(compressed)) + compressed)

    def close(self):
        """Shut down both directions of the underlying socket."""
        self.request.shutdown(socket.SHUT_RDWR)