# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""TF: Tensorflow parser"""
from __future__ import absolute_import as _abs
from __future__ import print_function
import os
from tvm.contrib import util


class TFParser(object):
    """
    A Wrapper to handle tensorflow models parsing, TensorFlow is needed

    Parameters
    ----------
    model_dir : str
        Path to a tensorflow frozen pb file, or to a directory that contains
        a saved model or checkpoints.

    Examples
    --------
    .. code-block:: python

        parser = TFParser(model_dir)
        graphdef = parser.parse()
    """

    def __init__(self, model_dir):
        # TensorFlow is imported inside the methods that need it, so merely
        # importing this module does not trigger a TensorFlow import.
        from tensorflow.core.framework import graph_pb2
        # Scratch directory that receives the frozen graph of a saved model.
        self._tmp_dir = util.tempdir()
        self._model_dir = model_dir
        self._graph = graph_pb2.GraphDef()

    def _set_graph(self, graph):
        """Store `graph` as the parser's current graph."""
        self._graph = graph

    def _get_graph(self):
        """Return the parser's current graph."""
        return self._graph

    def _load_pb_file(self):
        """Load a single pb file into the parser's graph object.

        The file at ``self._model_dir`` is opened in binary mode and decoded
        with ``ParseFromString``.

        NOTE: `parse` also routes ``.pbtxt`` paths to this method, but no
        text-format decoding is performed here.

        Returns
        -------
        The graph object returned by `_get_graph`, filled from the file.
        """
        graph = self._get_graph()
        with open(self._model_dir, "rb") as f:
            graph.ParseFromString(f.read())
        return graph

    def _get_tag_set(self):
        """Return the first tag set of the saved model.

        Multiple metagraphs are not supported: only the first tag set
        reported by the saved model reader is used.

        Raises
        ------
        ImportError
            If the saved model reader cannot be imported.
        """
        try:
            from tensorflow.contrib.saved_model.python.saved_model import reader
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import saved_model.reader which is "
                "required to get tag set from saved model.")
        tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
        return tag_sets[0]

    @staticmethod
    def _tensor_to_node_name(tensor_name):
        """Return the node name for a tensor name like ``"node:1"``.

        Everything from the first ``":"`` onwards (the output index) is
        dropped; a name without ``":"`` is returned unchanged.
        """
        return tensor_name.split(":")[0]

    def _get_output_names(self):
        """Return the output node names of the saved model, comma separated.

        The names are collected from the outputs of every signature def of
        the loaded meta graph, reduced to node names (any ``:<index>`` suffix
        is removed, not just ``:0``), de-duplicated and sorted so that the
        result is deterministic.

        Raises
        ------
        ImportError
            If tensorflow cannot be imported.
        """
        try:
            import tensorflow as tf
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")
        tags = self._get_tag_set()
        output_names = set()
        try:
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(sess,
                                                            tags,
                                                            self._model_dir)
                for sig_def in meta_graph_def.signature_def.values():
                    for output_tensor in sig_def.outputs.values():
                        output_names.add(
                            self._tensor_to_node_name(output_tensor.name))
        finally:
            # Reset the default graph even when loading fails, so that a
            # partially imported model is not left behind.
            tf.reset_default_graph()
        return ",".join(sorted(output_names))

    def _load_saved_model(self):
        """Load the tensorflow saved model.

        The saved model at ``self._model_dir`` is frozen into
        ``tf_frozen_model.pb`` inside the parser's temporary directory, the
        frozen file is read back and its training nodes are removed.

        Returns
        -------
        GraphDef of the frozen model without training nodes.

        Raises
        ------
        ImportError
            If the required tensorflow modules cannot be imported.
        """
        try:
            from tensorflow.python.tools import freeze_graph
            from tensorflow.python.framework import ops
            from tensorflow.python.framework import graph_util
            from tensorflow.core.framework import graph_pb2
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.")

        input_saved_model_dir = self._model_dir
        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
        output_node_names = self._get_output_names()
        saved_model_tags = ",".join(self._get_tag_set())

        # The model is read from the saved model directory, so no input
        # graph, saver definition, checkpoint or meta graph is supplied.
        input_graph_filename = None
        input_saver_def_path = False
        input_binary = False
        checkpoint_path = None
        restore_op_name = None
        filename_tensor_name = None
        clear_devices = True
        input_meta_graph = False

        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                                  input_binary, checkpoint_path, output_node_names,
                                  restore_op_name, filename_tensor_name,
                                  output_graph_filename, clear_devices, "", "", "",
                                  input_meta_graph, input_saved_model_dir,
                                  saved_model_tags)

        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_filename, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
            return output_graph_def

    def _load_ckpt(self):
        """TODO: Load checkpoint model.

        Raises
        ------
        RuntimeError
            Always, checkpoint models are not supported yet.
        """
        raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
                           "not supported yet.")

    def parse(self):
        """
        Parse tensorflow models: checkpoints, saved models, and single frozen pb file.

        A directory holding a ``checkpoint`` file is treated as a checkpoint
        model, a directory holding a ``variables`` sub-directory as a saved
        model. A ``.pb``/``.pbtxt`` file is treated as part of a saved model
        when a ``variables`` directory sits next to it (``self._model_dir``
        is then switched to that directory), otherwise as a single pb file.

        Returns
        -------
        GraphDef of the passed model

        Raises
        ------
        RuntimeError
            If the path does not exist, has an unsupported suffix, is a
            directory without a recognizable model, or is a checkpoint model.
        """
        model_path = self._model_dir

        if os.path.isdir(model_path):
            if os.path.isfile(os.path.join(model_path, "checkpoint")):
                graph = self._load_ckpt()
            elif os.path.isdir(os.path.join(model_path, "variables")):
                graph = self._load_saved_model()
            else:
                raise RuntimeError("InputConfiguration: Invalid model path.")
        elif os.path.isfile(model_path):
            # Only .pb or .pbtxt is a valid suffix name.
            if not model_path.endswith((".pb", ".pbtxt")):
                raise RuntimeError("InputConfiguration: Invalid model format.")

            # It is a saved model if `variables` directory is present at the
            # same directory with the pb or pbtxt file.
            parent_dir = os.path.dirname(model_path)
            if os.path.isdir(os.path.join(parent_dir, "variables")):
                self._model_dir = parent_dir
                graph = self._load_saved_model()
            else:
                graph = self._load_pb_file()
        else:
            raise RuntimeError("InputConfiguration: Unrecognized model "
                               "file or path.")

        self._set_graph(graph)
        return graph