from telnetlib import Telnet
from xmljson import BadgerFish, Parker as X2J
from xml.etree.ElementTree import fromstring
import json
from time import sleep
from falcon_rest.logger import log
from datetime import datetime, timedelta
from lxml import etree

import threading


class CallApiException(Exception):
    """Error raised by the DNL switch telnet API helpers in this module."""


# Timeout, in seconds, for reading a switch reply (telnetlib read/expect timeout).
READ_TIMEOUT = 30
# Minutes a cached result stays valid (used for the _sip_profile expiry).
STATS_CACHE_DELAY_MINUTES = 2

# Live DnlSwitchSession objects keyed by (ip, port); filled in by the constructor.
_sessions = dict()
# NOTE(review): not referenced in this module view — confirm before removing.
_connection = dict()
# Per-switch-IP result caches.  Each entry is
# {'time': <expiry datetime>, 'value': <cached result>}, read by the
# DnlSwitchSession *_cache classmethods.
_system_call_stat = dict()
_system_peak_stat = dict()
_license_limit = dict()
_disk_info = dict()
_sip_profile = dict()


class DnlSwitchSession(object):
    """Telnet session to one DNL switch's text command interface.

    Each instance registers itself in the module-level ``_sessions`` dict
    under ``(ip, port)``.  Commands are sent as CRLF-terminated ASCII lines,
    and the reply is read with ``Telnet.expect`` until a command-specific
    terminator appears.  :meth:`_api_call` serializes requests with a
    re-entrant lock.  When the server closes the connection (EOF), it
    reconnects, logs in again and retries the command once.
    """

    # Reply terminators passed to ``Telnet.expect``.  telnetlib compiles
    # them as regular expressions, so only ``)`` needs escaping.  Backslashes
    # before other punctuation would be invalid string escapes that newer
    # Pythons warn about.
    _DEFAULT_END = (b'\r\n',)
    _CALL_SIMULATION_END = b'</Call Simulation Test progress>\r\n'
    _LICENSE_LIMIT_END = b'</expire>\r\n'
    _DISK_INFO_END = b'pcap/\r\n'
    _TABLE_END = b'rows\\)\r\n'

    def __init__(self, ip, port, timeout=READ_TIMEOUT):
        """Register the session and open the telnet connection (no login yet).

        Args:
            ip: switch host.
            port: switch telnet port.
            timeout: seconds passed to ``telnetlib.Telnet`` when connecting.
        """
        self.conn = None
        self.logged_in = False
        log.debug('DnlSwitch try connect host {} port {} '.format(ip, port))
        self.ip = ip
        self.port = port
        self.timeout = timeout
        self.lock = threading.RLock()
        # Registered before connecting: a session whose connect() fails stays
        # in the registry with conn=None, and get_dnl_switch_session repairs it.
        _sessions[self.addr()] = self
        self.connect()

    def connect(self):
        """(Re)open the telnet connection, closing any previous one first.

        Any earlier login is void afterwards, so ``logged_in`` is reset.
        """
        previous = getattr(self, 'conn', None)
        if previous is not None:
            try:
                previous.close()
            except Exception as err:  # best effort: the old socket is being discarded
                log.debug('DnlSwitch close before reconnect failed: {}'.format(err))
        self.logged_in = False
        self.conn = Telnet(self.ip, self.port, self.timeout)
        # CLAS6-10749 - API for call simulation is very slow.
        # Waiting for the banner with read_until did not return the expected
        # response, so a short pause is used instead.
        sleep(0.1)

    def addr(self):
        """Return the ``(ip, port)`` key used in ``_sessions``."""
        return self.ip, self.port

    def __del__(self):
        """Close the telnet connection, logging (not raising) any error."""
        conn = getattr(self, 'conn', None)
        if conn is None:
            return
        try:
            self.logged_in = False
            conn.close()
        except Exception as e:
            log.error('DnlSwitch: {}'.format(e))

    def _api_write(self, command):
        """Send ``command`` as one CRLF-terminated line.

        If the write fails because the connection is unusable, the method
        reconnects, logs in again and resends once.  A second failure
        propagates.

        Raises:
            UnicodeEncodeError: ``command`` is not pure ASCII.  This is raised
                before any network activity.
        """
        line = bytes(command, encoding='ascii') + b"\r\n"
        log.debug('DnlSwitch host: {} port: {} command: {}'.format(self.ip, self.port, command))
        try:
            self.conn.write(line)
        except (OSError, AttributeError):
            # AttributeError: conn is None, or telnetlib has already dropped its socket.
            log.debug(
                'DnlSwitch NOT CONNECTED host: {} port: {} command: {}'.format(self.ip, self.port, command))
            self.connect()
            sleep(0.1)
            self.login()
            sleep(0.1)
            self.conn.write(line)

    def _api_read(self):
        """Return whatever is already buffered on the connection, decoded as text.

        Raises:
            EOFError: the connection is closed (used as a liveness probe).
        """
        response = self.conn.read_very_eager()
        log.debug('DnlSwitch response: {}'.format(response))
        return response.decode('utf-8', errors='replace').strip('\n')

    def _send_and_expect(self, command, patterns):
        """Write ``command`` and return the raw bytes read up to the first matching terminator."""
        self._api_write(command)
        sleep(0.1)
        # expect() returns (index, match, text); on timeout, text is whatever arrived.
        return self.conn.expect(patterns, READ_TIMEOUT)[2]

    def _api_call(self, command, expect=None):
        """Run ``command`` under the session lock and return the raw reply bytes.

        Args:
            command: ASCII command line, without CRLF.
            expect: terminator regexes (bytes) for ``Telnet.expect``.
                Defaults to a single CRLF.

        On EOF the session reconnects, logs in again and retries once.
        """
        patterns = list(self._DEFAULT_END if expect is None else expect)
        with self.lock:
            try:
                return self._send_and_expect(command, patterns)
            except EOFError:
                self.connect()
                self.login()
                sleep(0.5)
                result = self._send_and_expect(command, patterns)
                log.debug('DnlSwitch response: {}'.format(result))
                return result

    def api_call(self, command):
        """Run an arbitrary command, picking the reply terminator from its name."""
        if 'get_disk_info' in command:
            return self._api_call(command, [self._DISK_INFO_END])
        if 'call_simulation' in command:
            return self._api_call(command, [self._CALL_SIMULATION_END])
        return self._api_call(command)

    def login(self, attempts=10):
        """Log in, reconnecting between failed attempts.

        Args:
            attempts: how many login attempts to make before giving up.

        Returns:
            True once a login succeeds.

        Raises:
            CallApiException: every attempt failed.
        """
        for attempt in range(attempts):
            if self._login():
                log.debug('DnlSwitch login SUCCESS')
                return True
            sleep(0.01)
            self.connect()
            log.debug('DnlSwitch reconnect retry {} after delay'.format(attempt))
        message = 'DnlSwitch Can\'t login. after {} retries'.format(attempts)
        log.error(message)
        raise CallApiException(message)

    def _login(self):
        """Send one ``login`` command.

        Returns:
            True when the reply contains ``Welcome``.  False on a missing
            greeting or EOF.
        """
        try:
            self._api_write('login')
            greeting = self.conn.read_until(b'switch', timeout=READ_TIMEOUT)
            greeting += self.conn.read_eager()
        except EOFError:
            return False
        if b'Welcome' not in greeting:
            log.debug('DnlSwitch not logged in: {}'.format(greeting))
            return False
        self.logged_in = True
        return True

    def logout(self):
        """Send ``logout`` and return the raw reply.

        Raises:
            CallApiException: the reply lacks ``Goodbye``.  This includes the
                case where ``!`` does not arrive within READ_TIMEOUT.
        """
        self._api_write('logout')
        farewell = self.conn.read_until(b'!', timeout=READ_TIMEOUT)
        farewell += self.conn.read_eager()
        if b'Goodbye' not in farewell:
            log.error('DnlSwitch: {}'.format(farewell))
            raise CallApiException('Bad logout.')
        self.logged_in = False
        return farewell

    @staticmethod
    def _wrap_payload(text):
        """Wrap raw switch output in one ``<payload>`` root and remove CRLFs so it parses as XML."""
        return ('<payload>' + text + '</payload>').replace('\r\n', '')

    @staticmethod
    def _xml_to_dict(xml_text):
        """Parse ``xml_text`` strictly and convert it with xmljson's Parker convention, keeping the root tag."""
        x2j = X2J(xml_fromstring=False, dict_type=dict)
        return x2j.data(fromstring(xml_text), preserve_root=True)

    def call_simulation(self, caller_ip, caller_port, ani, dnis, include_blocked_route):
        """Run ``call_simulation`` and return its XML reply as a dict with ``-`` in keys replaced by ``_``.

        The arguments are sent comma-separated, converted with ``str``.

        Raises:
            CallApiException: the reply cannot be parsed.  Its argument is a
                dict with ``error`` and the raw ``result`` text.
        """
        command = 'call_simulation {},{},{},{},{}'.format(
            caller_ip, str(caller_port), str(ani), str(dnis), str(include_blocked_route))
        response = self._api_call(command, [self._CALL_SIMULATION_END])
        text = response.decode('utf-8')
        # The tag name contains spaces, so it is not valid XML; rename it before parsing.
        resp = self._wrap_payload(text).replace('Call Simulation Test progress', 'Simulation-progress')
        log.debug('DnlSwitch response: {}'.format(resp))
        x2j = X2J(xml_fromstring=False, dict_type=dict)
        # lxml parser in recover mode tolerates malformed switch output.
        parser = etree.XMLParser(recover=True)
        try:
            data = x2j.data(fromstring(resp, parser=parser), preserve_root=True)
        except Exception as err:
            data = {'error': 'cannot parse switch telnet output', 'result': text}
            raise CallApiException(data) from err
        return fix_ident(data)

    def sip_profile_start(self):
        """Run ``sip_profile_start`` and return the parsed reply as ``{'payload': ...}``."""
        response = self._api_call('sip_profile_start')
        resp = self._wrap_payload(response.decode('utf-8'))
        log.debug('DnlSwitch response: {}'.format(resp))
        return self._xml_to_dict(resp)

    @staticmethod
    def _read_cache(store, ip):
        """Return ``store[ip]['value']`` if present and not expired, else None."""
        entry = store.get(ip)
        if entry is not None and entry['time'] > datetime.now():
            log.debug('return from cache')
            return entry['value']
        return None

    @classmethod
    def get_license_limit_from_cache(cls, ip):
        """Cached license-limit result for ``ip``, or None.

        NOTE(review): this class never writes ``_license_limit``; presumably
        callers do — confirm.
        """
        return cls._read_cache(_license_limit, ip)

    def get_license_limit(self):
        """Run ``get_license_limit`` and return the parsed payload dict with ``result='success'`` added."""
        response = self._api_call('get_license_limit', [self._LICENSE_LIMIT_END])
        resp = self._wrap_payload(response.decode('utf-8'))
        log.debug('DnlSwitch response: {}'.format(resp))
        data = self._xml_to_dict(resp)
        data['payload']['result'] = 'success'
        return data['payload']

    @classmethod
    def get_system_call_stat_from_cache(cls, ip):
        """Cached system call statistics for ``ip``, or None."""
        return cls._read_cache(_system_call_stat, ip)

    @staticmethod
    def _parse_int_pairs(text, sep):
        """Map ``key<sep>value`` lines to ``{key: int(value)}``.

        Lines that do not split into exactly two parts are skipped; a later
        duplicate key wins.

        Raises:
            ValueError: a value is not an integer.
        """
        pairs = {}
        for line in text.split('\n'):
            parts = line.split(sep)
            if line and len(parts) == 2:
                pairs[parts[0]] = int(parts[1])
        return pairs

    def get_system_call_stat(self):
        """Run ``get_system_call_statistics`` and return its ``name=value`` lines as a dict of ints."""
        response = self._api_call('get_system_call_statistics').decode().replace('\r', '')
        log.debug('DnlSwitch response: {}'.format(response))
        return self._parse_int_pairs(response, '=')

    @classmethod
    def get_system_peak_stat_from_cache(cls, ip):
        """Cached system peak statistics for ``ip``, or None."""
        return cls._read_cache(_system_peak_stat, ip)

    def get_system_peak_stat(self):
        """Run ``get_system_peak_statistics`` and return its lines as a dict of ints.

        Lines are parsed as ``name<TAB>value``.  If that yields nothing, or a
        value is not an integer, they are re-parsed as ``name=value``.
        """
        response = self._api_call('get_system_peak_statistics').decode().replace('\r', '')
        log.debug('DnlSwitch response: {}'.format(response))
        try:
            ret = self._parse_int_pairs(response, '\t')
        except ValueError:
            ret = {}
        if not ret:
            ret = self._parse_int_pairs(response, '=')
        return ret

    @classmethod
    def get_disk_info_from_cache(cls, ip):
        """Cached disk info for ``ip``, or None."""
        return cls._read_cache(_disk_info, ip)

    def get_disk_info(self):
        """Run ``get_disk_info`` and return ``{'items': [...]}``.

        The reply is read up to ``pcap/``.  Each reply line containing ``/``
        and at least five whitespace-separated columns becomes
        ``{size, used, avail, use, path}``.  Shorter lines are skipped.
        """
        response = self._api_call('get_disk_info', [self._DISK_INFO_END]).decode().replace('\r', '')
        log.debug('DnlSwitch response: {}'.format(response))
        items = []
        for line in response.split('\n'):
            if '/' not in line:
                continue
            cols = line.split()
            if len(cols) < 5:
                continue
            items.append(dict(size=cols[0], used=cols[1], avail=cols[2], use=cols[3], path=cols[4]))
        return {'items': items}

    @classmethod
    def get_sip_profile_cache(cls, ip):
        """Cached ``sip_profile_show`` result for ``ip``, or None."""
        return cls._read_cache(_sip_profile, ip)

    @staticmethod
    def _table_rows(text):
        """Split a ``|``-separated table into rows of cells.

        Spaces are removed, lines without ``|`` are ignored, and the first
        such line (the header) is dropped.
        """
        lines = [line.replace(' ', '') for line in text.split('\n') if '|' in line]
        return [line.split('|') for line in lines[1:]]

    def sip_profile_show(self):
        """Run ``sip_profile_show``, cache the result for this IP and return ``{'items': [...]}``.

        Rows with fewer than four columns are skipped.
        """
        response = self._api_call('sip_profile_show', [self._TABLE_END]).decode().replace('\r', '')
        log.debug('DnlSwitch response: {}'.format(response))
        items = [dict(name=cols[0], status=cols[1], sip_ip=cols[2], sip_port=cols[3])
                 for cols in self._table_rows(response) if len(cols) >= 4]
        ret = {'items': items}
        _sip_profile[self.ip] = {'time': datetime.now() + timedelta(minutes=STATS_CACHE_DELAY_MINUTES), 'value': ret}
        return ret

    def sip_channel_dump(self):
        """Run ``sip_channel_dump`` and return ``{'items': [...]}``, one dict per channel row.

        Header and ``---`` separator rows, and rows with fewer than seven
        columns, are skipped.
        """
        response = self._api_call('sip_channel_dump', [self._TABLE_END]).decode().replace('\r', '')
        log.debug('DnlSwitch response: {}'.format(response))
        items = []
        for cols in self._table_rows(response):
            if 'UUID' in cols[0] or '---' in cols[0] or len(cols) < 7:
                continue
            # NOTE: 'parner_uuid' (sic) is the established output key; kept for callers.
            items.append(dict(uuid=cols[0], status=cols[1], type=cols[2], start_time=cols[3], ip=cols[4],
                              port=cols[5], parner_uuid=cols[6]))
        return {'items': items}

    def kill_channel(self, uuid):
        """Send ``kill_channel <uuid>`` and return the first reply line with CRLF removed."""
        response = self._api_call('kill_channel {}'.format(uuid))
        return response.decode().replace('\r\n', '')

    def kill_all_channels(self):
        """Kill every channel listed by :meth:`sip_channel_dump`.

        Returns:
            ``{'items': [...]}`` with one ``'uuid:<id>,error:<reply or exception>'``
            string per channel.  A failed kill is recorded, not raised.
        """
        ret = []
        for channel in self.sip_channel_dump()['items']:
            if 'uuid' not in channel:
                ret.append('no uuid in sip channels dump')
                continue
            try:
                outcome = self.kill_channel(channel['uuid'])
            except Exception as e:
                outcome = e
            ret.append('uuid:{},error:{}'.format(channel['uuid'], outcome))
        return {'items': ret}


def fix_ident(d):
    """Return a copy of ``d`` with ``-`` replaced by ``_`` in every key.

    Nested mappings (any ``dict`` instance, subclasses included) are converted
    recursively and returned as plain dicts.  Other values, lists included,
    are kept as-is.  This turns XML tag names such as ``Simulation-progress``
    into identifier-friendly keys.

    If two keys collide after conversion (e.g. ``a-b`` and ``a_b``), the one
    iterated last wins.
    """
    fixed = {}
    for key, value in d.items():
        if isinstance(value, dict):
            value = fix_ident(value)
        fixed[key.replace('-', '_') if '-' in key else key] = value
    return fixed


# NOTE(review): not read or written anywhere in this module view; possibly
# used by other modules — confirm before removing.
_login_tested = dict()


class CallApi(object):
    """Facade that obtains a logged-in switch session and runs one API call.

    Every failure is logged with its traceback (debug level) and re-raised as
    :class:`CallApiException`, chained to the original error.
    """

    @staticmethod
    def test_call(ip, port, caller_ip, caller_port, ani, dnis, include_blocked_route):
        """Run a call simulation on the switch at ``ip:port`` and return its parsed result.

        Raises:
            CallApiException: wrapping any error from session setup or the simulation.
        """
        try:
            session = get_dnl_switch_session(ip, port)
            return session.call_simulation(caller_ip, caller_port, ani, dnis, include_blocked_route)
        except Exception as e:
            import traceback
            log.debug('CallApiException:{} traceback: {}'.format(e, traceback.format_exc()))
            raise CallApiException(e) from e

    @staticmethod
    def sip_profile_start(ip, port):
        """Run ``sip_profile_start`` on the switch at ``ip:port`` and return the parsed reply.

        Raises:
            CallApiException: wrapping any error from session setup or the command.
        """
        try:
            session = get_dnl_switch_session(ip, port)
            return session.sip_profile_start()
        except Exception as e:
            import traceback
            log.debug('CallApiException:{} traceback: {}'.format(e, traceback.format_exc()))
            raise CallApiException(e) from e


def get_dnl_switch_session(ip, port):
    """Return a logged-in ``DnlSwitchSession`` for ``(ip, port)``.

    A session already in ``_sessions`` is reused as-is if it has a connection,
    is marked logged in, and a non-blocking read does not fail.  Otherwise it
    is reconnected and logged in again.  An unknown address gets a new
    session, which registers itself in ``_sessions``, and is then logged in.

    Raises:
        CallApiException: wrapping (and chained to) any error raised while
            connecting or logging in.
    """
    try:
        sess = _sessions.get((ip, port))
        if sess is None:
            sess = DnlSwitchSession(ip, port)
            sess.login()
            return sess
        if getattr(sess, 'conn', None) and getattr(sess, 'logged_in', False):
            try:
                sess._api_read()  # liveness probe: raises if the peer closed the connection
                return sess
            except Exception as e:
                log.debug('DnlSwitch session cannot read: {}'.format(e))
        sess.connect()
        sess.login()
        return sess
    except Exception as e:
        raise CallApiException('Switch telnet interface error. Server IP {} Port {}. Detailed error {}'.format(
            ip,
            port,
            e
        )) from e


if __name__ == '__main__':
    # Manual smoke test against a switch API listening on 127.0.0.1:4320.
    # Call-simulation variant:
    #   CallApi.test_call('127.0.0.1', '4320', '1.2.3.4', 5555, 1111, 2222, 0)
    result = CallApi.sip_profile_start('127.0.0.1', '4320')
    print(result)
