# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable,invalid-name,unused-argument
"""
Decorators for registering tunable templates to TOPI.

These decorators turn a plain TOPI implementation into one that can use different configurations
for different workloads.
We directly use all arguments of the TOPI call as the "workload", so make sure all the arguments
(except tvm.Tensor) in your calls are hashable. Each tvm.Tensor is serialized into a hashable
tuple.
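
For instance, a data tensor with shape (1, 3, 224, 224) and dtype "float32" is serialized to
``('TENSOR', (1, 3, 224, 224), 'float32')``, so a conv2d call might be recorded as a workload
tuple like the following (shapes and values here are illustrative only)::

    ('conv2d',
     ('TENSOR', (1, 3, 224, 224), 'float32'),   # data
     ('TENSOR', (64, 3, 7, 7), 'float32'),      # kernel
     (2, 2), (3, 3), (1, 1), 'NCHW', 'float32')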

See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""

from ... import _api_internal, tensor, placeholder, create_schedule

from .task import args_to_workload, dispatcher, register
from ..util import get_const_tuple

# A table that records all registered dispatchers for all targets
_REGISTERED_DISPATCHER = {}


def serialize_args(args):
    """serialize arguments of a topi function to a hashable tuple.

    Parameters
    ----------
    args: list of hashable or Tensor
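
    Examples
    --------
    A minimal illustration (the shape, dtype and extra arguments below are arbitrary)::

        data = placeholder((1, 3, 224, 224), dtype='float32')
        serialize_args([data, (1, 1), 'NCHW'])
        # -> (('TENSOR', (1, 3, 224, 224), 'float32'), (1, 1), 'NCHW')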
    """
    ret = []
    for t in args:
        if isinstance(t, tensor.Tensor):
            ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
        else:
            ret.append(t)
    return tuple(ret)


def deserialize_args(args):
    """The inverse function of :code:`serialize_args`.

    Parameters
    ----------
    args: list of hashable or Tensor
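
    Examples
    --------
    A serialized tensor entry is turned back into a fresh placeholder, while all other
    arguments pass through unchanged (values are illustrative)::

        out = deserialize_args([('TENSOR', (1, 16, 32, 32), 'float32'), (1, 1)])
        # out[0] is a new placeholder with shape (1, 16, 32, 32); out[1] == (1, 1)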
    """
    ret = []
    for t in args:
        if isinstance(t, tuple) and t[0] == 'TENSOR':
            ret.append(placeholder(shape=t[1], dtype=t[2]))
        else:
            ret.append(t)
    return ret


# Task extractor for nnvm graph, relay program
class TaskExtractEnv:
    """Global environment for extracting tuning tasks from nnvm graph"""
    current = None

    def __init__(self):
        import topi

        # topi compute -> autotvm task name
        self.topi_to_task = {
            topi.nn.conv2d: "topi_nn_conv2d",
            topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
            topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
            topi.nn.dense: "topi_nn_dense",
            topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
            topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
            topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
        }

        self.topi_to_schedule = {
            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
                             topi.generic.schedule_conv2d_nhwc],
            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
                                            topi.generic.schedule_depthwise_conv2d_nhwc],
            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
            topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
            topi.nn.dense: [topi.generic.schedule_dense],
            topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
            topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
            topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
        }

        self._register_tracing()
        self._register_topi_task()
        self.task_collection = []
        self.wanted_topi_funcs = list(self.topi_to_task.keys())

    def _register_tracing(self):
        """Register tracing function to track the topi function call"""
        # register topi compute for "tracing" target
        for topi_compute in self.topi_to_task:
            def _local_scope(compute_func):
                """start a scope to hold the local function in for loop"""

                @compute_func.register("tracing")
                def _tracing_topi_compute(*args, **kwargs):
                    assert not kwargs, "Do not support extracting tuning tasks when " \
                                       "kwargs are used in a TOPI function call. " \
                                       "Please modify it to use only positional args."

                    if compute_func in self.wanted_topi_funcs:  # record this call
                        key = (self.topi_to_task[compute_func], serialize_args(args))
                        if key not in self.task_collection:
                            self.task_collection.append(key)
                    return compute_func.fdefault(*args)
            _local_scope(topi_compute)

        # register topi schedule for "tracing" target
        for topi_compute in self.topi_to_task:
            for topi_schedule in self.topi_to_schedule[topi_compute]:
                def _local_scope_(schedule_func):
                    """start a scope to hold the local function in for loop"""

                    @schedule_func.register("tracing")
                    def _tracing_topi_schedule(outs):
                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
                        return create_schedule([x.op for x in outs])
                _local_scope_(topi_schedule)

    def _register_topi_task(self):
        """register tuning wrapper for topi function"""
        import topi

        # Tuning wrapper for topi functions
        @register("topi_nn_conv2d")
        def _topi_nn_conv2d(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            layout = args[-2]
            assert layout == 'NCHW', "only support NCHW currently"
            C = topi.nn.conv2d(*args, **kwargs)
            s = topi.generic.schedule_conv2d_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_depthwise_conv2d_nchw")
        def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
            s = topi.generic.schedule_depthwise_conv2d_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_group_conv2d_nchw")
        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
            s = topi.generic.schedule_group_conv2d_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_conv2d_transpose_nchw")
        def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
            s = topi.generic.schedule_conv2d_transpose_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_dense")
        def _topi_nn_dense(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            data, weight, bias = args
            C = topi.nn.dense(*args, **kwargs)
            s = topi.generic.schedule_dense([C])
            if bias is not None:
                return s, [data, weight, bias, C]
            return s, [data, weight, C]

        @register("topi_nn_bitserial_conv2d_nhwc")
        def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
            args = deserialize_args(args)
            C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
            s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
            data = args[0]
            kernel = args[1]
            return s, [data, kernel, C]

        @register("topi_nn_bitserial_conv2d_nchw")
        def _topi_bitserial_conv2d_nchw(*args, **kwargs):
            args = deserialize_args(args)
            C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
            s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
            data = args[0]
            kernel = args[1]
            return s, [data, kernel, C]

        @register("topi_nn_deformable_conv2d_nchw")
        def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, Offset, W = args[:3]
            C = topi.nn.deformable_conv2d_nchw(*args, **kwargs)
            s = topi.generic.schedule_deformable_conv2d_nchw([C])
            return s, [A, Offset, W, C]

        @register("topi_nn_conv2d_NCHWc")
        def _topi_nn_conv2d_NCHWc(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.conv2d_NCHWc(*args, **kwargs)
            s = topi.generic.schedule_conv2d_NCHWc([C])
            return s, [A, W, C]

    def reset(self, wanted_topi_funcs):
        """Reset task collections

        Parameters
        ----------
        wanted_topi_funcs: List of function
            The topi functions whose tuning tasks should be extracted
        """
        self.task_collection = []
        self.wanted_topi_funcs = wanted_topi_funcs

    def get_tasks(self):
        """Get collected tasks

        Returns
        -------
        tasks: List of tuple(name, args)
            A list of tasks extracted from the nnvm graph or relay program
        """
        return self.task_collection

    @staticmethod
    def get():
        """Get the single instance of TaskExtractEnv

        Returns
        -------
        env: TaskExtractEnv
            The single instance of TaskExtractEnv
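
        Examples
        --------
        A usage sketch; it assumes the model is later built with the "tracing" target so
        topi calls are recorded, and the compute function listed is illustrative::

            env = TaskExtractEnv.get()
            env.reset([topi.nn.conv2d])
            # ... compile the model with target "tracing" so topi calls are traced ...
            tasks = env.get_tasks()  # list of (task_name, serialized_args) tuples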
        """
        if not TaskExtractEnv.current:
            TaskExtractEnv.current = TaskExtractEnv()
        return TaskExtractEnv.current


def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
    """Register a tunable template for a topi compute function.

    After the registration, this topi compute will become a configuration dispatcher. It uses
    all its arguments as the workload and dispatches configurations according to the input workload.

    It also stores this "workload" to its final ComputeOp, which can be used to reconstruct
    "workload" in the following topi_schedule call.

    Parameters
    ----------
    topi_compute: GenericFunc
        The topi compute function that will be overloaded
    target_keys: str or list of str
        The compilation target. The same as the argument of GenericFunc.register.
    template_keys: str or list of str
        The template key.
        We might have several strategies for a single operator (e.g. direct, im2col, winograd).
        The template key is used to identify the algorithm strategy.
        Every operator must have a "direct" template, which is used by default.
    func: None or callable
        If it is None, return a decorator.
        If it is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Examples
    --------
    See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
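
    A minimal sketch of the intended usage (the function name, target key, and argument list
    below are illustrative assumptions, not taken from this file)::

        @register_topi_compute(topi.nn.depthwise_conv2d_nchw, 'arm_cpu', 'direct')
        def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
            # 'cfg' is the configuration dispatched for the current workload;
            # define tuning knobs on it and build the compute accordingly.
            ...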
    """
    def _decorator(f):
        targets = [target_keys] if isinstance(target_keys, str) else target_keys
        for target_key in targets:
            if target_key not in _REGISTERED_DISPATCHER:
                _REGISTERED_DISPATCHER[target_key] = {}
            if topi_compute not in _REGISTERED_DISPATCHER[target_key]:
                @topi_compute.register(target_key)
                @dispatcher
                def config_dispatcher(*args, **kwargs):
                    """override topi call as a config dispatcher"""
                    assert not kwargs, "Do not support kwargs in template function call"
                    return args_to_workload(args, topi_compute)
                _REGISTERED_DISPATCHER[target_key][topi_compute] = config_dispatcher

            config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute]

            @config_dispatcher.register(template_keys)
            def template_call(cfg, *args, **kwargs):
                """call the topi func and attach workload to compute node"""
                assert not kwargs, "Do not support kwargs in template function call"

                if f == topi_compute.fdefault:
                    node = f(*args, **kwargs)
                else:
                    node = f(cfg, *args, **kwargs)

                # attach workload to return op
                op = node.op
                attrs = {}
                for k, v in node.op.attrs.items():
                    attrs[k] = v
                attrs['workload'] = args_to_workload(args, topi_compute)
                if isinstance(op, tensor.ComputeOp):
                    op = _api_internal._ComputeOp(
                        op.name, op.tag, attrs, op.axis, op.body)
                elif isinstance(op, tensor.ExternOp):
                    op = _api_internal._ExternOp(
                        op.name, op.tag, attrs,
                        op.inputs, op.input_placeholders,
                        op.output_placeholders, op.body)
                else:
                    raise RuntimeError("Unsupported op type: " + str(type(op)))

                if isinstance(node, tensor.Tensor):
                    return op.output(0)
                return [op.output(i) for i in range(len(node))]

        return f

    if func:
        _decorator(func)

    return _decorator


def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None):
    """Register a tunable template for a topi schedule function.

    After the registration, this topi schedule becomes a configuration dispatcher. It dispatches
    configurations according to the input workload.

    Note that this function will try to find "workload" from all the ComputeOp in the input.
    You can attach "workload" to your compute op by using :any:`register_topi_compute`.

    Parameters
    ----------
    topi_schedule: GenericFunc
        The topi schedule function that will be overloaded
    target_keys: str or list of str
        The compilation target
    template_keys: str or list of str
        The template key.
        We might have several strategies for a single operator (e.g. direct, im2col, winograd).
        The template key is used to identify the algorithm strategy.
        Every operator must have a "direct" template, which is used by default.
    func: None or callable
        If it is None, return a decorator.
        If it is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Examples
    --------
    See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
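
    A minimal sketch of the intended usage (the names below are illustrative assumptions, not
    taken from this file)::

        @register_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct')
        def schedule_depthwise_conv2d_arm_cpu(cfg, outs):
            # 'cfg' carries the configuration selected for the workload that was
            # attached by the matching register_topi_compute template.
            s = create_schedule([x.op for x in outs])
            # ... apply cfg-driven scheduling on s ...
            return s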
    """
    def _decorator(f):
        targets = [target_keys] if isinstance(target_keys, str) else target_keys
        for target_key in targets:
            if target_key not in _REGISTERED_DISPATCHER:
                _REGISTERED_DISPATCHER[target_key] = {}
            if topi_schedule not in _REGISTERED_DISPATCHER[target_key]:
                @topi_schedule.register(target_key)
                @dispatcher
                def config_dispatcher(outs, *args, **kwargs):
                    """override topi call as a workload dispatcher"""
                    def traverse(tensors):
                        """traverse all ops to find attached workload"""
                        for t in tensors:
                            op = t.op
                            if 'workload' in op.attrs:
                                return op.attrs['workload']
                            wkl = traverse(op.input_tensors)
                            if wkl:
                                return wkl
                        return None

                    outs = [outs] if isinstance(outs, tensor.Tensor) else outs
                    workload = traverse(outs)

                    if workload is None:
                        raise RuntimeError("Cannot find workload in attribute of this schedule")

                    return args_to_workload(workload)

                _REGISTERED_DISPATCHER[target_key][topi_schedule] = config_dispatcher

            config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule]

            @config_dispatcher.register(template_keys)
            def template_call(cfg, outs, *args, **kwargs):
                """call the schedule func"""
                if f == topi_schedule.fdefault:
                    return f(outs, *args, **kwargs)
                return f(cfg, outs, *args, **kwargs)

        return f

    if func:
        _decorator(func)

    return _decorator