###############################################################################
#
# The MIT License (MIT)
#
# Copyright (c) Crossbar.io Technologies GmbH
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
###############################################################################


from __future__ import absolute_import, print_function

import six
import ssl  # XXX what Python version is this always available at?
import signal
from functools import partial, wraps

try:
    import asyncio
except ImportError:
    # Trollius >= 0.3 was renamed to asyncio
    # noinspection PyUnresolvedReferences
    import trollius as asyncio

import txaio
txaio.use_asyncio()  # noqa

from autobahn.asyncio.websocket import WampWebSocketClientFactory
from autobahn.asyncio.rawsocket import WampRawSocketClientFactory

from autobahn.wamp import component
from autobahn.wamp.exception import TransportLost

from autobahn.asyncio.wamp import Session
from autobahn.wamp.serializer import create_transport_serializers, create_transport_serializer


__all__ = ('Component',)


def _unique_list(seq):
    """
    Return a list with unique elements from sequence, preserving order.
    """
    seen = set()
    return [x for x in seq if x not in seen and not seen.add(x)]


def _camel_case_from_snake_case(snake):
    parts = snake.split('_')
    return parts[0] + ''.join([s.capitalize() for s in parts[1:]])


def _create_transport_factory(loop, transport, session_factory):
    """
    Create a WAMP-over-XXX transport factory.
    """
    if transport.type == u'websocket':
        serializers = create_transport_serializers(transport)
        factory = WampWebSocketClientFactory(
            session_factory,
            url=transport.url,
            serializers=serializers,
            proxy=transport.proxy,  # either None or a dict with host, port
        )

    elif transport.type == u'rawsocket':
        serializer = create_transport_serializer(transport.serializers[0])
        factory = WampRawSocketClientFactory(session_factory, serializer=serializer)

    else:
        assert(False), 'should not arrive here'

    # set the options one at a time so we can give user better feedback
    for k, v in transport.options.items():
        try:
            factory.setProtocolOptions(**{k: v})
        except (TypeError, KeyError):
            # this allows us to document options as snake_case
            # until everything internally is upgraded from
            # camelCase
            try:
                factory.setProtocolOptions(
                    **{_camel_case_from_snake_case(k): v}
                )
            except (TypeError, KeyError):
                raise ValueError(
                    "Unknown {} transport option: {}={}".format(transport.type, k, v)
                )
    return factory


class Component(component.Component):
    """
    A component establishes a transport and attaches a session
    to a realm using the transport for communication.

    The transports a component tries to use can be configured,
    as well as the auto-reconnect strategy.
    """

    log = txaio.make_logger()

    session_factory = Session
    """
    The factory of the session we will instantiate.
    """

    def _is_ssl_error(self, e):
        """
        Internal helper.

        :returns: True if ``e`` is an :class:`ssl.SSLError` (or subclass).
        """
        return isinstance(e, ssl.SSLError)

    def _check_native_endpoint(self, endpoint):
        """
        Validate a user-supplied ``endpoint`` configuration.

        The asyncio implementation only supports ``dict`` endpoint
        configurations. If the dict has a ``tls`` key, its value must be a
        dict, a bool or an :class:`ssl.SSLContext`.

        :raises ValueError: if the configuration is not acceptable.
        """
        if not isinstance(endpoint, dict):
            # asyncio has no equivalent of Twisted's IStreamClientEndpoint,
            # so only dict-based configuration is accepted here
            raise ValueError(
                "'endpoint' configuration must be a dict"
            )
        if u'tls' in endpoint:
            tls = endpoint[u'tls']
            if not isinstance(tls, (dict, bool, ssl.SSLContext)):
                raise ValueError(
                    "'tls' configuration must be a dict, bool or "
                    "SSLContext instance"
                )

    # async function
    def _connect_transport(self, loop, transport, session_factory, done):
        """
        Create and connect a WAMP-over-XXX transport.

        :param loop: the asyncio event loop to connect on.
        :param transport: the transport configuration (``proxy``,
            ``endpoint`` and whatever ``_create_transport_factory`` reads).
        :param session_factory: callable producing the WAMP session.
        :param done: a txaio Future that is rejected if the connection
            attempt fails, or if the connection is lost before ``done``
            has been resolved.

        :returns: the Future of the (timeout-bounded) connection attempt.

        :raises ValueError: for an invalid or unsupported endpoint
            configuration.
        """
        factory = _create_transport_factory(loop, transport, session_factory)

        if transport.proxy:
            return self._connect_via_proxy(loop, transport, factory, done)

        endpoint_type = transport.endpoint[u'type']
        if endpoint_type == u'tcp':
            return self._connect_tcp(loop, transport, factory, done)
        elif endpoint_type == u'unix':
            return self._connect_unix(loop, transport, factory, done)

        # explicit raise instead of ``assert``: under ``python -O`` an
        # assert is stripped and this method would silently return None
        raise ValueError(
            "unknown endpoint type {} in client endpoint configuration".format(endpoint_type)
        )

    @staticmethod
    def _endpoint_timeout(endpoint):
        """
        Return the connect timeout (in seconds) from ``endpoint``; default 10.

        :raises ValueError: if the configured value is not an integer type.
        """
        timeout = endpoint.get(u'timeout', 10)  # in seconds
        if type(timeout) not in six.integer_types:
            raise ValueError('invalid type {} for timeout in client endpoint configuration'.format(type(timeout)))
        return timeout

    def _start_connection(self, transport, done, conn_coro, timeout):
        """
        Bound the connection coroutine ``conn_coro`` by ``timeout`` seconds,
        schedule it, and hook the resulting Future up to ``done``.
        """
        time_f = asyncio.ensure_future(asyncio.wait_for(conn_coro, timeout=timeout))
        return self._wrap_connection_future(transport, done, time_f)

    def _connect_via_proxy(self, loop, transport, factory, done):
        """
        Open a plain TCP connection to the configured proxy host/port.
        """
        timeout = self._endpoint_timeout(transport.endpoint)
        # XXX do we support HTTPS proxies?
        f = loop.create_connection(
            protocol_factory=factory,
            host=transport.proxy['host'],
            port=transport.proxy['port'],
        )
        return self._start_connection(transport, done, f, timeout)

    def _connect_tcp(self, loop, transport, factory, done):
        """
        Open a (possibly TLS-wrapped) TCP connection to the endpoint.

        :raises ValueError: on invalid version/host/port/timeout/tls values.
        :raises RuntimeError: if ``tls`` has an unsupported type.
        """
        endpoint = transport.endpoint

        # NOTE: the IP version is validated but not used to pick the
        # address family of the connection
        version = endpoint.get(u'version', 4)
        if version not in [4, 6]:
            raise ValueError('invalid IP version {} in client endpoint configuration'.format(version))

        host = endpoint[u'host']
        if type(host) != six.text_type:
            raise ValueError('invalid type {} for host in client endpoint configuration'.format(type(host)))

        port = endpoint[u'port']
        if type(port) not in six.integer_types:
            raise ValueError('invalid type {} for port in client endpoint configuration'.format(type(port)))

        timeout = self._endpoint_timeout(endpoint)

        tls, tls_hostname = self._resolve_tls(endpoint.get(u'tls', None), host)

        f = loop.create_connection(
            protocol_factory=factory,
            host=host,
            port=port,
            ssl=tls,
            server_hostname=tls_hostname,
        )
        return self._start_connection(transport, done, f, timeout)

    def _resolve_tls(self, tls, host):
        """
        Translate the ``tls`` endpoint setting into the ``(ssl,
        server_hostname)`` pair passed to ``loop.create_connection``.

        - falsy: returned unchanged, with no server hostname;
        - dict: only ``hostname`` (default: ``host``) and ``trust_root``
          keys are allowed; ``trust_root`` builds an SSLContext with that
          CA file, otherwise ``True`` (default context) is used;
        - SSLContext / True: used as-is, verifying against ``host``.

        :raises ValueError: on unknown dict keys or a non-text hostname.
        :raises RuntimeError: if ``tls`` has an unsupported type.
        """
        if not tls:
            return tls, None

        if isinstance(tls, dict):
            for k in tls.keys():
                if k not in [u"hostname", u"trust_root"]:
                    raise ValueError("Invalid key '{}' in 'tls' config".format(k))
            hostname = tls.get(u'hostname', host)
            if type(hostname) != six.text_type:
                raise ValueError('invalid type {} for hostname in TLS client endpoint configuration'.format(type(hostname)))
            cert_fname = tls.get(u'trust_root', None)
            if cert_fname is None:
                return True, hostname
            context = ssl.create_default_context(
                purpose=ssl.Purpose.SERVER_AUTH,
                cafile=cert_fname,
            )
            return context, hostname

        if isinstance(tls, ssl.SSLContext):
            # tls=<an SSLContext> is valid
            return tls, host

        if tls in [False, True]:
            # only truthy values reach this point
            return tls, host

        raise RuntimeError('unknown type {} for "tls" configuration in transport'.format(type(tls)))

    def _connect_unix(self, loop, transport, factory, done):
        """
        Open a Unix domain socket connection to ``endpoint['path']``.
        """
        path = transport.endpoint[u'path']
        # unlike TCP, the timeout is coerced with int() rather than type-checked
        timeout = int(transport.endpoint.get(u'timeout', 10))  # in seconds

        f = loop.create_unix_connection(
            protocol_factory=factory,
            path=path,
        )
        return self._start_connection(transport, done, f, timeout)

    def _wrap_connection_future(self, transport, done, conn_f):
        """
        Link the connection-attempt Future ``conn_f`` to ``done``.

        On success, the protocol's ``connection_lost`` is wrapped so that
        ``done`` is rejected if the connection drops before ``done`` has
        been resolved. On failure, ``transport.connect_failures`` is
        incremented and ``done`` is rejected with the error.

        :returns: ``conn_f``.
        """

        def on_connect_success(result):
            # async connect call returns a 2-tuple of (asyncio transport,
            # protocol); only the protocol is needed here
            _, proto = result

            # if e.g. an SSL handshake fails, we will have
            # successfully connected (i.e. get here) but need to
            # 'listen' for the "connection_lost" from the underlying
            # protocol in case of handshake failure .. so we wrap
            # it. Also, we don't increment transport.success_count
            # here on purpose (because we might not succeed).

            # XXX double-check that asyncio behavior on TLS handshake
            # failures is in fact as described above
            orig = proto.connection_lost

            @wraps(orig)
            def lost(fail):
                rtn = orig(fail)
                if not txaio.is_called(done):
                    # asyncio will call connection_lost(None) in case of
                    # a transport failure, in which case we create an
                    # appropriate exception
                    if fail is None:
                        fail = TransportLost("failed to complete connection")
                    txaio.reject(done, fail)
                return rtn
            proto.connection_lost = lost

        def on_connect_failure(err):
            transport.connect_failures += 1
            # failed to establish a connection in the first place
            txaio.reject(done, err)

        txaio.add_callbacks(conn_f, on_connect_success, None)
        # the errback is added as a second step so it gets called if
        # there is an error in on_connect_success itself.
        # NOTE(review): whether that actually happens depends on how
        # txaio's asyncio add_callbacks chains callbacks -- confirm.
        txaio.add_callbacks(conn_f, None, on_connect_failure)
        return conn_f

    # async function
    def start(self, loop=None):
        """
        This starts the Component, which means it will start connecting
        (and re-connecting) to its configured transports. A Component
        runs until it is "done", which means one of:
        - There was a "main" function defined, and it completed successfully;
        - Something called ``.leave()`` on our session, and we left successfully;
        - ``.stop()`` was called, and completed successfully;
        - none of our transports were able to connect successfully (failure);
        :returns: a Future which will resolve (to ``None``) when we are
            "done" or with an error if something went wrong.
        """

        if loop is None:
            self.log.warn("Using default loop")
            loop = asyncio.get_event_loop()

        return self._start(loop=loop)
        return self._start(loop=loop)


def run(components, log_level='info'):
    """
    High-level API to run a series of components.

    This will only return once all the components have stopped
    (including, possibly, after all re-connections have failed if you
    have re-connections enabled). Under the hood, this runs the current
    asyncio event loop (a fresh one is created if the current loop is
    already closed) with ``loop.run_forever()`` and closes the loop once
    it stops. Where the platform supports it, SIGINT and SIGTERM cancel
    all tasks and stop the loop.

    If you wish to manage the event loop yourself, use the
    :meth:`autobahn.asyncio.component.Component.start` method to start
    each component yourself.

    :param components: the Component(s) you wish to run
    :type components: Component or list of Components

    :param log_level: a valid log-level (or None to avoid calling start_logging)
    :type log_level: string
    """

    # actually, should we even let people "not start" the logging? I'm
    # not sure that's wise... (double-check: if they already called
    # txaio.start_logging() what happens if we call it again?)
    if log_level is not None:
        txaio.start_logging(level=log_level)
    loop = asyncio.get_event_loop()
    if loop.is_closed():
        asyncio.set_event_loop(asyncio.new_event_loop())
        loop = asyncio.get_event_loop()
        txaio.config.loop = loop
    log = txaio.make_logger()

    # see https://github.com/python/asyncio/issues/341 asyncio has
    # "odd" handling of KeyboardInterrupt when using Tasks (as
    # run_until_complete does). Another option is to just restore
    # default SIGINT handling, which is to exit:
    #   import signal
    #   signal.signal(signal.SIGINT, signal.SIG_DFL)

    # asyncio.all_tasks() superseded Task.all_tasks(), which was removed
    # in Python 3.9; fall back to the latter on older Pythons / trollius.
    all_tasks = getattr(asyncio, 'all_tasks', None) or asyncio.Task.all_tasks

    def nicely_exit(signame):
        log.info("Shutting down due to {signal}", signal=signame)
        for task in all_tasks(loop):
            task.cancel()
        # Stop via call_soon (not inline): callbacks run in registration
        # order, so the wake-ups triggered by the cancellations above get
        # a chance to run before the loop stops. This replaces a
        # ``@asyncio.coroutine`` helper (that decorator was removed in 3.11).
        loop.call_soon(loop.stop)

    try:
        loop.add_signal_handler(signal.SIGINT, partial(nicely_exit, 'SIGINT'))
        loop.add_signal_handler(signal.SIGTERM, partial(nicely_exit, 'SIGTERM'))
    except NotImplementedError:
        # signals are not available on Windows
        pass

    def done_callback(loop, arg):
        loop.stop()

    # returns a future; could run_until_complete() but see below
    component._run(loop, components, done_callback)

    try:
        loop.run_forever()
        # this is probably more-correct, but then you always get
        # "Event loop stopped before Future completed":
        # loop.run_until_complete(f)
    except asyncio.CancelledError:
        pass
    finally:
        # Close the event loop at the end (even if run_forever raised),
        # otherwise an exception is thrown. https://bugs.python.org/issue23548
        loop.close()
