#!/usr/bin/env python3
# encoding: utf-8
"""
cache.py

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import re
import sys
import glob
import time
import signal
import argparse
import itertools
import subprocess

# Resolved (symlink-free) path of this script.
PROGRAM = os.path.realpath(__file__)
# Directory two levels above the one containing this script.
ROOT = os.path.abspath(os.path.join(os.path.dirname(PROGRAM), os.path.join('..', '..')))
# <ROOT>/src
LIBRARY = os.path.join(ROOT, 'src')

# Template printed by Command.explain(); it is %-formatted with a mapping that
# provides the 'client' and 'server' command lines.
EXPLAIN = """
ExaBGP command line
=======================================================

%(client)s


bgp daemon command line
=======================================================

%(server)s


The following extra configuration options could be used
=======================================================

export exabgp_debug_rotate=true
export exabgp_debug_defensive=true
"""


class Color(object):
    """Status prefixes: two ANSI SGR escape sequences followed by a one-character marker.

    CI.display() writes one of these strings in front of each test nickname.
    """

    # SGR 0 twice, marker ' ' : no status yet
    NONE = '\033[0m\033[0m '
    # SGR 0 then SGR 96, marker '~' : test is starting
    STARTING = '\033[0m\033[96m~'
    # SGR 0 then SGR 94, marker '=' : test is ready / pending
    READY = '\033[0m\033[94m='
    # SGR 0 then SGR 91, marker '-' : test failed
    FAIL = '\033[0m\033[91m-'
    # SGR 1 then SGR 92, marker '+' : test succeeded
    SUCCESS = '\033[1m\033[92m+'


class Identifier(dict):
    """Hands out one-character nicknames for test names and maps between the two.

    All state lives on the class; nicknames are taken from ``_listing`` in order.
    """

    _listing = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    _from_name = {}  # test name -> nickname
    _from_nick = {}  # nickname -> test name
    _next = 0  # index in _listing of the next nickname to hand out
    _nl = 3  # identifiers() flags every _nl-th entry

    @classmethod
    def get(cls, name):
        """Allocate the next free nickname for ``name``, record both mappings and return it."""
        position = cls._next
        nickname = cls._listing[position]
        cls._from_name[name], cls._from_nick[nickname] = nickname, name
        cls._next = position + 1
        return nickname

    @classmethod
    def identifiers(cls):
        """Yield ``(nickname, flag)`` for every allocated nickname, in allocation order.

        ``flag`` is True on every ``_nl``-th entry (CI.listing() uses it to break lines).
        """
        for position in range(cls._next):
            yield cls._listing[position], (position + 1) % cls._nl == 0

    @classmethod
    def nick(cls, name):
        """Return the nickname allocated to the test called ``name`` (KeyError if unknown)."""
        return cls._from_name[name]

    @classmethod
    def name(cls, nick):
        """Return the test name registered under ``nick`` (KeyError if unknown)."""
        return cls._from_nick[nick]


class Port(object):
    """Sequential port allocator starting at ``base``."""

    base = 1790

    @classmethod
    def get(cls):
        """Return the current value of ``base`` and advance it by one for the next caller."""
        allocated, cls.base = cls.base, cls.base + 1
        return allocated


class Path(object):
    """Locations, relative to ROOT, that the functional tests rely on."""

    ETC = os.path.join(ROOT, 'etc', 'exabgp')
    EXABGP = os.path.join(ROOT, 'sbin', 'exabgp')
    BGP = os.path.join(ROOT, 'qa', 'sbin', 'bgp')
    CI = os.path.join(ROOT, 'qa', 'ci')
    # Every '*.ci' file of the CI folder, in sorted order (evaluated once, at class creation).
    ALL_CI = sorted(glob.glob(os.path.join(CI, '*.ci')))

    @classmethod
    def validate(cls):
        """Exit the program with an explanatory message on the first missing location."""
        requirements = (
            (os.path.isdir, cls.ETC, 'could not find etc folder'),
            (os.path.isdir, cls.CI, 'could not find tests in the qa/ci folder'),
            (os.path.isfile, cls.EXABGP, 'could not find exabgp'),
            (os.path.isfile, cls.BGP, 'could not find the sequence daemon'),
        )
        for present, location, complaint in requirements:
            if not present(location):
                sys.exit(complaint)


class CI(dict):
    """Catalogue of the functional tests and of their display status.

    Everything is stored at class level:
      * ``_content`` maps a test nickname to a dict with the keys
        'name', 'confs', 'ci', 'msg' and 'port';
      * ``_status`` maps a nickname to one of the ``Color`` strings;
      * ``_tests`` is the sorted list of nicknames, used for display.
    """

    # A configuration line of the form "run <command>;" - group 1 is the command.
    API = re.compile(r'^\s*run\s+(.*)\s*?;\s*?$')
    _content = {}
    _status = {}
    _tests = []

    @classmethod
    def make(cls):
        """Register every file of ``Path.ALL_CI`` as a test.

        The first line of a .ci file is split on whitespace; each word is taken
        as a configuration file name relative to ``Path.ETC``. Each test gets a
        nickname from ``Identifier`` and a port from ``Port``.
        """
        for filename in Path.ALL_CI:
            name, _ = os.path.splitext(os.path.basename(filename))
            nick = Identifier.get(name)
            with open(filename, 'r') as reader:
                content = reader.readline()
            cls._content[nick] = {
                'name': name,
                'confs': [os.path.join(Path.ETC, _) for _ in content.split()],
                'ci': os.path.join(Path.CI, name) + '.ci',
                'msg': os.path.join(Path.CI, name) + '.msg',
                'port': Port.get(),
            }
        cls._tests.extend(sorted(cls._content.keys()))

    @classmethod
    def get(cls, k):
        """Return the description dict of the test nicknamed ``k``, or None if unknown."""
        return cls._content.get(k, None)

    @classmethod
    def state(cls, name):
        """Advance the status of ``name`` one step: unset -> NONE -> STARTING -> READY.

        Any other status (READY, FAIL, SUCCESS) is left untouched.
        """
        if name not in cls._status:
            cls._status[name] = Color.NONE
        elif cls._status[name] == Color.NONE:
            cls._status[name] = Color.STARTING
        elif cls._status[name] == Color.STARTING:
            cls._status[name] = Color.READY

    @classmethod
    def color(cls, name):
        """Return the status string of ``name`` (``Color.NONE`` when it has none yet)."""
        return cls._status.get(name, Color.NONE)

    @classmethod
    def reset(cls, name):
        """Force the status of ``name`` back to ``Color.NONE``."""
        cls._status[name] = Color.NONE

    @classmethod
    def passed(cls, name):
        """Mark ``name`` as successful."""
        cls._status[name] = Color.SUCCESS

    @classmethod
    def failed(cls, name):
        """Mark ``name`` as failed."""
        cls._status[name] = Color.FAIL

    @classmethod
    def files(cls, k):
        """Return the existing files that make up the test nicknamed ``k``.

        The list holds the .msg file, each configuration file, and every program
        named by a "run ...;" line of those configurations (relative names are
        resolved against ``Path.ETC``). Entries that are not regular files are
        dropped, and an unknown nickname yields an empty list.
        """
        test = cls._content.get(k, None)
        if not test:
            return []
        files = [
            test['msg'],
        ]
        for conf in test['confs']:
            files.append(conf)
            try:
                reader = open(conf)
            except OSError:
                # A missing/unreadable configuration is tolerated: like any other
                # non-existing entry it is filtered out of the result below.
                continue
            with reader:
                for line in reader:
                    found = cls.API.match(line)
                    if not found:
                        continue
                    name = found.group(1)
                    if not name.startswith('/'):
                        name = os.path.abspath(os.path.join(Path.ETC, name))
                    if name not in files:
                        files.append(name)
        return [f for f in files if os.path.isfile(f)]

    @classmethod
    def display(cls):
        """Write the status of every test on one line, then return the cursor to column 0."""
        # sys.stdout.write('\r')
        for k in cls._tests:
            sys.stdout.write('%s%s ' % (CI.color(k), k))
        sys.stdout.write(Color.NONE)
        # same line printing now buggy
        sys.stdout.write('\r')
        sys.stdout.flush()

    @classmethod
    def listing(cls):
        """Print the nickname and name of every test, padded into columns."""
        sys.stdout.write('\n')
        sys.stdout.write('The available functional tests are:\n')
        sys.stdout.write('\n')
        for index, nl in Identifier.identifiers():
            name = cls._content[index]['name']
            sys.stdout.write(' %-2s %s%s' % (index, name, ' ' * (25 - len(name))))
            # Identifier.identifiers() flags the entries after which a row ends.
            sys.stdout.write('\n' if nl else '')
        sys.stdout.write('\n')
        sys.stdout.write('\n')
        sys.stdout.write('\n')
        sys.stdout.write('checking\n')
        sys.stdout.write('\n')
        sys.stdout.flush()


# Executed when the module is loaded, before any command line parsing:
# Path.validate() exits the program if an expected folder/file is missing,
# then CI.make() registers every qa/ci/*.ci test.
# NOTE: CI.make() draws each test's port from Port here, so assigning
# Port.base afterwards does not change the ports of the registered tests.
Path.validate()
CI.make()
# CI.display()


class Alarm(Exception):
    """Exception raised by alarm_handler()."""


def alarm_handler(number, frame):  # pylint: disable=W0613
    """Raise Alarm; both arguments are ignored.

    The (number, frame) signature is the one signal.signal() expects of a handler.
    NOTE(review): no signal.signal(...) registration is visible in this file -
    confirm where (or whether) this handler is installed.
    """
    raise Alarm


class Process(object):
    """Registry of the child processes started for each test and of their captured output.

    Both tables live on the class and are indexed by test name first, then by
    side (the callers in this file use 'server' and 'client').
    """

    _running = {}  # name -> side -> process object (a subprocess.Popen in this file)
    _result = {}  # name -> side -> {'in' | 'out' | 'err': bytes}

    @classmethod
    def add(cls, name, side, process):
        """Register ``process`` as the ``side`` of test ``name`` and clear its captured output."""
        cls._running.setdefault(name, {})[side] = process
        captured = cls._result.setdefault(name, {}).setdefault(side, {})
        # 'err' has to exist from the start: error() can be asked for it before
        # collect() has stored anything.
        for std in ('in', 'out', 'err'):
            captured[std] = b''

    @classmethod
    def success(cls, name):
        """Return True when the collected server output of ``name`` contains b'successful'."""
        return b'successful' in cls._result[name]['server']['out']

    @classmethod
    def _ready(cls, name, side):
        """Return True once the ``side`` process of test ``name`` is no longer running.

        An IOError/OSError/ValueError raised while polling is also reported as
        "ready". ``poll()`` does not block, so no timeout guard is needed here.
        """
        try:
            polled = cls._running[name][side].poll()
        except (IOError, OSError, ValueError):
            return True
        return polled is not None

    @classmethod
    def collect(cls, name, side):
        """Store the stdout/stderr of the ``side`` process of ``name``, waiting at most one second.

        Nothing is stored if the process has not finished within that second or
        if its pipes are already closed.
        """
        try:
            # A communicate() timeout is used instead of signal.alarm(): an
            # alarm without an installed SIGALRM handler kills the whole runner.
            stdout, stderr = cls._running[name][side].communicate(timeout=1)
        except subprocess.TimeoutExpired:
            return
        except ValueError:  # I/O operation on closed file
            return
        cls._result[name][side]['out'] = stdout
        cls._result[name][side]['err'] = stderr

    @classmethod
    def output(cls, name, side):
        """Return what was captured from the standard output of ``side`` for test ``name``."""
        return cls._result[name][side]['out']

    @classmethod
    def error(cls, name, side):
        """Return what was captured from the standard error of ``side`` for test ``name``."""
        return cls._result[name][side]['err']

    @classmethod
    def _terminate(cls, name, side):
        """Send SIGTERM to the ``side`` process of ``name``; a vanished process is ignored."""
        try:
            cls._running[name][side].send_signal(signal.SIGTERM)
        except OSError:  # No such process, Errno 3
            pass

    @classmethod
    def terminate(cls):
        """Terminate, then collect, every registered process that produced no output so far."""
        for name in cls._running:
            for side in cls._running[name]:
                # Truthiness rather than a comparison with '': the captured
                # values are bytes, and b'' != '' is always True in Python 3.
                if cls.output(name, side) or cls.error(name, side):
                    continue
                cls._terminate(name, side)
                cls.collect(name, side)


class Command(dict):
    """Builds the client/server command lines of a test and runs batches of tests."""

    @staticmethod
    def execute(cmd):
        """Run ``cmd`` and yield its output line by line: all of stdout, then all of stderr.

        Raises subprocess.CalledProcessError once the output is exhausted if the
        command exited with a non-zero status.
        """
        print('starting: %s' % ' '.join(cmd))
        popen = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
        try:
            for line in itertools.chain(iter(popen.stdout.readline, ''), iter(popen.stderr.readline, '')):
                yield line
        finally:
            # Both pipes are released, including when the consumer stops early.
            popen.stdout.close()
            popen.stderr.close()
        return_code = popen.wait()
        if return_code:
            raise subprocess.CalledProcessError(return_code, cmd)

    @classmethod
    def explain(cls, index):
        """Print the client and server command lines of test ``index``, then exit with status 1."""
        print(
            EXPLAIN % {'client': cls.client(index), 'server': cls.server(index),}
        )
        sys.exit(1)

    @staticmethod
    def _test(index):
        """Return the description of test ``index``, exiting if it is unknown or not runnable."""
        test = CI.get(index)
        if not test:
            sys.exit("can not find any test called '%s'" % index)

        # Ports below 1024 are the privileged ones; report the port actually used.
        if os.getuid() and os.getgid() and test['port'] < 1024:
            sys.exit('you need to have root privileges to bind to port %d' % test['port'])
        return test

    @staticmethod
    def client(index):
        """Return the shell command line starting exabgp for test ``index``."""
        test = Command._test(index)

        config = {
            'env': ' \\\n  '.join(
                [
                    'exabgp_tcp_once=true',
                    'exabgp_api_cli=false',
                    'exabgp_debug_rotate=true',
                    'exabgp_debug_configuration=true',
                    'exabgp_tcp_bind=\'\'',
                    'exabgp_tcp_port=%d' % test['port'],
                    'INTERPRETER=%s ' % os.environ.get('__PYVENV_LAUNCHER__', sys.executable),
                ]
            ),
            'exabgp': Path.EXABGP,
            'confs': ' \\\n    '.join(test['confs']),
        }
        return 'env \\\n  %(env)s \\\n %(exabgp)s -d -p \\\n    %(confs)s' % config

    @staticmethod
    def server(index):
        """Return the shell command line starting the bgp daemon for test ``index``."""
        test = Command._test(index)

        config = {
            'env': ' \\\n  '.join(['exabgp_tcp_port=%d' % test['port'],]),
            'interpreter': os.environ.get('__PYVENV_LAUNCHER__', sys.executable),
            'bgp': Path.BGP,
            'msg': test['msg'],
        }

        return 'env \\\n  %(env)s \\\n  %(interpreter)s %(bgp)s \\\n    %(msg)s' % config

    @staticmethod
    def dispatch(running, timeout):
        """Run the tests nicknamed in ``running`` concurrently, for at most ``timeout`` seconds.

        For every test this script is re-invoked twice, as 'server' and as
        'client'. A test is over when its server process has exited; it passed if
        Process.success() says so. Returns True only when every test passed -
        a test still running at the timeout counts as failed.
        """
        completed = True
        names = []
        for name in running:
            if CI.get(name) is None:
                sys.exit("can not find any test called '%s'" % name)
            CI.state(name)
            names.append(name)

        # All the servers are started before any client.
        for side in ['server', 'client']:
            for name in running:
                process = subprocess.Popen(
                    [sys.argv[0], side, name, '--port', str(CI.get(name)['port'])],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                Process.add(name, side, process)
                CI.state(name)
                CI.display()
                time.sleep(0.02)

        exit_time = time.time() + timeout

        while names and time.time() < exit_time:
            CI.display()
            # Iterate over a copy: finished tests are removed from names below.
            for name in list(names):
                for side in ('server', 'client'):
                    if not Process._ready(name, side):
                        continue

                    Process.collect(name, side)

                    if side == 'server':
                        names.remove(name)

                    if Process.success(name):
                        CI.passed(name)
                    else:
                        CI.failed(name)
                        completed = False

                    CI.display()
            time.sleep(0.2)

        Process.terminate()

        # Whatever is still listed never saw its server exit before the timeout:
        # that is a failure, and the captured output is dumped to help diagnose it.
        for name in names:
            CI.failed(name)
            completed = False
            print('\nserver stdout\n------\n%s' % str(Process.output(name, 'server')).replace('\\n', '\n'))
            print('\nclient stdout\n------\n%s' % str(Process.output(name, 'client')).replace('\\n', '\n'))

        CI.display()
        return completed


def _run(to_run, chunk, timeout):
    success = True
    while to_run and success:
        running, to_run = to_run[:chunk], to_run[chunk:]
        success = Command.dispatch(running, timeout)
    sys.stdout.write('\n')

    sys.exit(0 if success else 1)


if __name__ == '__main__':
    # Command line entry point: one sub-command per action, each one bound to a
    # local function through set_defaults(func=...).
    parser = argparse.ArgumentParser(description='The BGP swiss army knife of networking functional testing tool')
    subparsers = parser.add_subparsers()

    def _exit_status(status):
        """Turn the wait status returned by os.system() into a sys.exit() value.

        The raw status must not be given to sys.exit(): an exit code of 1 is
        encoded as 256, which the operating system truncates to 0 (success).
        """
        if os.WIFEXITED(status):
            return os.WEXITSTATUS(status)
        if os.WIFSIGNALED(status):
            # Shell convention for a command killed by a signal.
            return 128 + os.WTERMSIG(status)
        return 1

    def run_all(parsed):
        """Run every registered test, one at a time."""
        to_run = [index for index, _ in Identifier.identifiers()]
        chunk = 1
        _run(to_run, chunk, parsed.timeout)

    sub = subparsers.add_parser('all', help='run all available test')
    sub.add_argument('--timeout', help='timeout for test failure', type=int, default=60)
    sub.add_argument('--port', help='base port to use', type=int, default=1790)
    sub.set_defaults(func=run_all)

    def run(parsed):
        """Run one test, or all of them in batches of --steps (all at once by default)."""
        # NOTE(review): CI.make() already drew the port of every test from
        # Port.base when the module was loaded, so this assignment does not
        # change the ports of the registered tests.
        Port.base = parsed.port
        if parsed.test:
            to_run = [
                parsed.test,
            ]
        else:
            to_run = [index for index, _ in Identifier.identifiers()]
        chunk = len(to_run) if not parsed.steps else parsed.steps
        _run(to_run, chunk, parsed.timeout)

    sub = subparsers.add_parser('run', help='run a particular test')
    sub.add_argument('test', help='name of the test to run', nargs='?', default=None)
    sub.add_argument('--timeout', help='timeout for test failure', type=int, default=60)
    sub.add_argument('--port', help='base port to use', type=int, default=1790)
    sub.add_argument('--steps', help='number of test to run simultaneously', type=int, default=0)
    sub.set_defaults(func=run)

    def client(parsed):
        """Show, and unless --dry run, the client command of a test; exit with its status."""
        command = Command.client(parsed.test)
        print(f'> {command}')
        if not parsed.dry:
            sys.exit(_exit_status(os.system(command)))
        sys.exit(0)

    sub = subparsers.add_parser('client', help='start the client for a specific test')
    sub.add_argument('test', help='name of the test to run')
    sub.add_argument('-d', '--dry', help='show what command would be run but does nothing', action='store_true')
    sub.add_argument('--timeout', help='timeout for test failure', type=int, default=60)
    sub.add_argument('--port', help='base port to use', type=int, default=1790)
    sub.set_defaults(func=client)

    def server(parsed):
        """Show, and unless --dry run, the server command of a test; exit with its status."""
        command = Command.server(parsed.test)
        print(f'> {command}')
        if not parsed.dry:
            sys.exit(_exit_status(os.system(command)))
        sys.exit(0)

    sub = subparsers.add_parser('server', help='start the server for a specific test')
    sub.add_argument('test', help='name of the test to run')
    sub.add_argument('-d', '--dry', help='show what command would be run but does nothing', action='store_true')
    sub.add_argument('--timeout', help='timeout for test failure', type=int, default=60)
    sub.add_argument('--port', help='base port to use', type=int, default=1790)
    sub.set_defaults(func=server)

    def explain(parsed):
        """Print the client and server command lines of a test."""
        # Command.explain() exits by itself; the exit below is only a safety net.
        Command.explain(parsed.test)
        sys.exit(0)

    sub = subparsers.add_parser('explain', help='show what command for a test are run')
    sub.add_argument('test', help='name of the test to explain')
    sub.add_argument('--timeout', help='timeout for test failure', type=int, default=60)
    sub.add_argument('--port', help='base port to use', type=int, default=1790)
    sub.set_defaults(func=explain)

    def edit(parsed):
        """Open every existing file of a test in $EDITOR (vi by default)."""
        files = CI.files(parsed.test)
        if not files:
            sys.exit('no such test')
        editor = os.environ.get('EDITOR', 'vi')
        os.system('%s %s' % (editor, ' '.join(files)))
        sys.exit(0)

    sub = subparsers.add_parser('edit', help='start $EDITOR to edit a specific test')
    sub.add_argument('test', help='name of the test to edit')
    sub.set_defaults(func=edit)

    def decode(parsed):
        """Ask exabgp to decode a hexadecimal payload using the first configuration of a test."""
        test = CI.get(parsed.test)
        if not test:
            sys.exit("can not find any test called '%s'" % parsed.test)
        if not test['confs']:
            sys.exit("the test '%s' does not reference any configuration file" % parsed.test)
        command = '%s decode %s  "%s"' % (Path.EXABGP, test['confs'][0], ''.join(parsed.payload))
        print('> %s' % command)
        os.system(command)
        sys.exit(0)

    sub = subparsers.add_parser('decode', help='use the test configuration to decode a packet')
    sub.add_argument('test', help='name of the test to use to know the BGP configuration')
    sub.add_argument('payload', nargs='+', help='the hexadecimal representation of the packet')
    sub.set_defaults(func=decode)

    sub = subparsers.add_parser('listing', help='list all functional test available')
    sub.set_defaults(func=lambda _: CI.listing())

    parsed = parser.parse_args()
    # Without a sub-command the namespace is empty (no 'func' was set).
    if vars(parsed):
        parsed.func(parsed)
    else:
        parser.print_help()