# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Methods and data structures to support dumping HalideIR to Hybrid Script.
This allows users to quickly hack on the generated HalideIR and convert it
back into TVM modules.

To enable this feature, you need to build with -DUSE_HYBRID_DUMP=ON.
"""

import ast

from tvm.contrib import util
from .util import _internal_assert
from .util import _is_tvm_arg_types
from .parser import source_to_op


class HybridModule(object):
    """A callable module backed by Hybrid Script source text.

    The usage of a Hybrid Module is very similar to a conventional TVM module,
    but a conventional TVM module requires a function body which is already
    fully lowered. This contradicts the fact that Hybrid Module is originally a
    text format for Phase 0 HalideIR. Thus, a totally separate module is defined.
    """

    # Instance attributes (all None until a source has been loaded):
    #   src_  : full text of the loaded python file
    #   name  : module name; defaults to the name of the function found in src_
    #   func_ : the object produced by executing src_ (the decorated function)
    #   root_ : the ast.FunctionDef node of that function

    def __init__(self, src=None, name=None):
        """Construct a hybrid module, optionally from source text.

        Parameters
        ----------
        src : str, optional
            The source code of this module. The text is written to a temporary
            file right after an ``import tvm`` line and a
            ``@tvm.te.hybrid.script`` decorator line, so it is expected to
            start with the function definition itself. When omitted, an empty
            module is created and `load` can be called later.

        name : str, optional
            The name of this module. When omitted, the name of the function
            found in the source is used.
        """
        self.src_ = self.name = self.func_ = self.root_ = None
        if src is not None:
            temp = util.tempdir()
            dst = temp.relpath("script.py")
            with open(dst, 'w') as f:
                f.write("import tvm\n@tvm.te.hybrid.script\n%s" % src)

            if name is not None:
                self.name = name
            self.load(dst)

    def __call__(self, *args):
        """Invoke the module on `args`.

        When `_is_tvm_arg_types(args)` holds, the parsed function definition is
        passed to `source_to_op` together with the arguments; otherwise the
        python function obtained by executing the source is called directly.
        """
        if _is_tvm_arg_types(args):
            return source_to_op(self.root_, args, globals(), {})
        return self.func_(*args)

    def get_source(self):
        """Return the text of the loaded python file, or None if nothing has
        been loaded yet."""
        return self.src_

    def save(self, path):
        """Write the source of this module to a python file.

        Parameters
        ----------
        path : str
            Destination path; ".py" is appended when it is missing.

        Raises
        ------
        ValueError
            If the module holds no source yet. The check happens before the
            destination is opened, so an existing file is never truncated
            by a save that has nothing to write.
        """
        if self.src_ is None:
            raise ValueError("Cannot save a hybrid module that has no source loaded!")
        if not path.endswith('.py'):
            path = path + '.py'
        with open(path, 'w') as f:
            f.write(self.src_)

    def load(self, path):
        """Load the module from a python file

        The file must contain exactly one function definition reachable by the
        AST visitor, and executing it must define the name ``tvm`` plus exactly
        one other top-level name, which becomes the callable of this module.

        The state of this module is only updated once the whole file has been
        parsed, checked and executed successfully; a failing load leaves the
        previously loaded content untouched.

        Parameters
        ----------
        path : str
            Path to the given python file

        Raises
        ------
        SyntaxError
            If the file is not valid python.
        KeyError
            If executing the file does not define the name ``tvm``.
        AssertionError
            If executing the file defines more than one name besides ``tvm``.
        """
        with open(path, 'r') as f:
            src = f.read()

        class FindFunc(ast.NodeVisitor):
            """ Find the single function defined in the module to be loaded. """
            #pylint: disable=invalid-name
            def __init__(self):
                self.name = None
                self.root = None

            def visit_FunctionDef(self, node):
                # The visitor does not descend into the function body, so
                # functions nested inside the found one are not counted.
                _internal_assert(self.name is None, "For now, only one function supported!")
                self.name = node.name
                _internal_assert(self.root is None, "For now, only one function supported!")
                self.root = node

        finder = FindFunc()
        finder.visit(ast.parse(src))
        _internal_assert(finder.name is not None and finder.root is not None, \
                         "No function found!")

        # NOTE: the file content is executed as python code, so only files from
        # a trusted origin should be loaded.
        # Distinct globals/locals dicts make every top-level name bound by the
        # script land in `namespace`, which is inspected right below.
        global_ns, namespace = {}, {}
        exec(src, global_ns, namespace) #pylint: disable=exec-used
        namespace.pop('tvm')
        assert len(namespace) == 1, "Expect exactly one definition besides `tvm`!"
        func = list(namespace.values())[0]

        # Everything succeeded: commit the new state in one go.
        self.src_ = src
        if self.name is None:
            self.name = finder.name
        self.root_ = finder.root
        self.func_ = func