#!/usr/bin/env python

import os
import sys
import pwd
try:
    import asyncore
except:
    import py_asyncore as asyncore
import socket
import errno
from struct import unpack
import time
from exabgp.util.od import od


class BGPHandler(asyncore.dispatcher_with_send):
    """Server side of one accepted BGP session, used as a simple test peer.

    The first message read is answered with a copy of itself in which the
    last byte of the router-id is incremented, followed by a KEEPALIVE
    (see handle_open). Every later message is answered with a KEEPALIVE
    (see handle_keepalive). Each message is either appended to a record
    file or, for UPDATEs when not recording, summarised on stdout.

    Headers and bodies are byte strings; they are only sliced or converted
    through bytearray so the decoding does not depend on how the
    interpreter indexes byte strings.
    """

    # when the 'wire' environment variable is non-empty, dump every chunk read
    wire = bool(os.environ.get('wire', ''))
    update = True

    # 16 marker bytes, a two-byte length of 19 and the KEEPALIVE type octet
    keepalive = b'\xff' * 16 + b'\x00\x13\x04'

    # message type octet (offset 18 of the header) to printable name
    _name = {
        b'\x01': 'OPEN',
        b'\x02': 'UPDATE',
        b'\x03': 'NOTIFICATION',
        b'\x04': 'KEEPALIVE',
    }

    def isupdate(self, header):
        """Return True when the type octet of header marks an UPDATE."""
        return header[18:19] == b'\x02'

    def name(self, header):
        """Return the printable name of the message type found in header."""
        return self._name.get(header[18:19], 'SOME WEIRD RFC PACKET')

    @staticmethod
    def _prefixes(data):
        """Yield 'a.b.c.d/len' for every length-prefixed IPv4 prefix in data.

        Each prefix is one length octet (in bits) followed by just enough
        octets to hold that many bits, capped at four. A prefix cut short by
        the end of data is padded with zeros instead of raising.
        """
        octets = bytearray(data)
        offset = 0
        while offset < len(octets):
            mask = octets[offset]
            offset += 1
            size = min((mask + 7) // 8, 4)
            address = list(octets[offset : offset + size])
            address.extend([0] * (4 - len(address)))
            offset += size
            yield '.'.join(str(_) for _ in address) + '/' + str(mask)

    def routes(self, body):
        """Yield one description per IPv4 route carried in an UPDATE body.

        The body is read as: two-byte withdrawn length, withdrawn prefixes,
        two-byte attribute length, attributes (skipped), announced prefixes.
        A placeholder string is yielded for a section without any prefix.
        """
        len_w = unpack('!H', body[0:2])[0]
        withdrawn = body[2 : 2 + len_w]

        if not withdrawn:
            yield 'no ipv4 withdrawal'
        for prefix in self._prefixes(withdrawn):
            yield 'withdraw ' + prefix

        start_a = 2 + len_w
        len_a = unpack('!H', body[start_a : start_a + 2])[0]
        announced = body[start_a + 2 + len_a :]

        if not announced:
            yield 'no ipv4 announcement'
        for prefix in self._prefixes(announced):
            yield 'announce ' + prefix

    def announce(self, *args):
        """Print args on one line, prefixed with the peer ip and port."""
        message = ' '.join(str(_) for _ in args) if len(args) > 1 else args[0]
        print('%s %s %s' % (self.ip, self.port, message))

    def setup(self, record, ip, port):
        """Attach the peer address and reset the session state.

        record: when true, every message read is appended to a new file
                named 'bgp-<UTC timestamp>' in the current directory.
        ip, port: peer address, only used to prefix the printed lines.

        Returns self so the call can be chained on the constructor.
        """
        self.ip = ip
        self.port = port
        now = time.strftime("%a-%d-%b-%Y-%H:%M:%S", time.gmtime())
        # raw wire data is written, hence the binary mode
        self.record = open("%s-%s" % ('bgp', now), 'wb') if record else None
        self.handle_read = self.handle_open
        self.update_count = 0
        self.time = time.time()
        return self

    def _disconnect(self):
        """Report the loss of the TCP session and release file and socket."""
        self.announce("TCP connection closed")
        if self.record:
            self.record.close()
            self.record = None
        self.close()

    def _read_exact(self, size, label, pending):
        """Read exactly size bytes from the socket.

        label and pending are the prefixes of the wire-debug lines showing
        the data read so far and the number of bytes still expected.

        Returns the bytes read, or None once the peer closed the connection
        (the handler is then closed too). The socket is non-blocking, so the
        read is retried in a tight loop while no data is available.
        """
        data = b''
        while len(data) < size:
            try:
                chunk = self.recv(size - len(data))
            except socket.error as exc:
                if exc.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
                    continue
                raise
            data += chunk
            if self.wire:
                self.announce(label, od(data))
                if len(data) != size:
                    self.announce(pending, size - len(data))
            if not chunk:
                # an empty read means the TCP session is gone; without this
                # check the loop would spin forever on empty reads
                self._disconnect()
                return None
        return data

    def read_message(self):
        """Read one complete BGP message from the socket.

        Returns (header, body), or (None, None) when the connection was
        closed before a whole message could be read. Exits the process when
        the length field of the header is below 19 or above 4096.
        """
        header = self._read_exact(19, "HEADER ", "left")
        if header is None:
            return None, None

        self.announce("read", self.name(header))

        length = unpack('!H', header[16:18])[0] - 19
        if self.wire:
            self.announce("waiting for", length, "bytes")

        if length < 0 or length > 4096 - 19:
            print("packet")
            print(od(header))
            print("Invalid length for packet %s" % length)
            sys.exit(1)

        body = self._read_exact(length, "BODY   ", "missing")
        if body is None:
            return None, None

        # counts every message read, not only UPDATEs
        self.update_count += 1

        if self.record:
            self.record.write(header + body)

        elif self.isupdate(header):
            self.announce(
                "received %-6d updates (%6d/sec) " % (self.update_count, self.update_count / (time.time() - self.time)),
                ', '.join(self.routes(body)),
            )

        return header, body

    def handle_open(self):
        """Answer the first message with a modified copy and a KEEPALIVE."""
        # reply with a IBGP response with the same capability (just changing routerID)
        header, body = self.read_message()
        if header is None:
            # connection closed before the message was complete
            return
        reply = bytearray(body)
        reply[8] = (reply[8] + 1) & 0xFF
        self.announce("sending open")
        self.send(header + bytes(reply))
        self.announce("sending keepalive")
        self.send(self.keepalive)
        self.handle_read = self.handle_keepalive

    def handle_keepalive(self):
        """Answer any further message with a KEEPALIVE."""
        header, body = self.read_message()
        if header is not None:
            self.announce("sending keepalive")
            self.send(self.keepalive)


class BGPServer(asyncore.dispatcher):
    """Listening TCP socket which starts a BGPHandler per accepted connection."""

    def __init__(self, host, port, record):
        """Bind to (host, port) and start listening.

        record is kept and handed to the setup() of every handler created.
        """
        self.record = record
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        address = (host, port)
        self.bind(address)
        self.listen(5)

    def handle_accept(self):
        """Accept the pending connection, if any, and attach a handler to it."""
        accepted = self.accept()
        if accepted is None:
            # nothing to unpack, the connection is not available
            return
        connection, peer = accepted  # pylint: disable=W0633
        print(' '.join(("new BGP connection from", str(peer))))
        handler = BGPHandler(connection)
        handler.setup(self.record, *peer)


def drop():
    """Give up root privileges by switching to an unprivileged account.

    Does nothing when neither the user id nor the group id is 0. Otherwise
    the first account of the candidate list found in the password database
    provides the replacement ids; only the ids which are currently 0 are
    changed.

    Raises SystemExit when none of the candidate accounts exists, as the
    process can not safely keep running as root.
    """
    uid = os.getuid()
    gid = os.getgid()

    if uid and gid:
        return

    user = None
    for name in [
        'nobody',
    ]:
        try:
            user = pwd.getpwnam(name)
        except KeyError:
            continue
        break

    if user is None:
        raise SystemExit('can not drop privileges: no unprivileged user found')

    # the group is changed first, while the process is still allowed to do so
    if not gid:
        os.setgid(int(user.pw_gid))
    if not uid:
        os.setuid(int(user.pw_uid))


# Script entry point: read the listening port, bind address and record flag
# from the environment (dotted names first, underscore names as fallback),
# start the server, drop root privileges and run the asyncore loop.
try:
    if os.environ.get('exabgp.tcp.port', '').isdigit():
        port = int(os.environ.get('exabgp.tcp.port'))
    elif os.environ.get('exabgp_tcp_port', '').isdigit():
        port = int(os.environ.get('exabgp_tcp_port'))
    else:
        port = 179

    bind = os.environ.get('exabgp.tcp.bind', os.environ.get('exabgp_tcp_bind', 'localhost'))
    # any non-empty value enables recording
    record = bool(os.environ.get('exabgp.wire.record', os.environ.get('exabgp_wire_record', False)))
    server = BGPServer(bind, port, record)
    drop()
    asyncore.loop()
except socket.error as exc:
    # report the port really used and the underlying error next to the hint
    print('need root right to bind to port %d (%s)' % (port, exc))
