/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file compile_engine.h
 * \brief Internal engine to compile a subgraph fragment and cache compilation.
 */
#ifndef NNVM_COMPILER_COMPILE_ENGINE_H_
#define NNVM_COMPILER_COMPILE_ENGINE_H_

#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include <string>
#include <utility>
#include "graph_hash.h"

namespace nnvm {
namespace compiler {

/*!
 * \brief A TVM Node to represent compiled graph function.
 *
 * Holds the target string, the function name, the input/output tensors and
 * the lowered functions; all five fields are exposed through VisitAttrs.
 */
struct GraphFuncNode : public tvm::Node {
  /*! \brief compiled target */
  std::string target;
  /*! \brief Function name */
  std::string func_name;
  /*! \brief The inputs to the function */
  tvm::Array<Tensor> inputs;
  /*! \brief The outputs to the function */
  tvm::Array<Tensor> outputs;
  /*! \brief The lowered functions */
  tvm::Array<tvm::LoweredFunc> funcs;

  /*!
   * \brief Visit every field of this node under its own name.
   * \param v The attribute visitor that receives each (name, field pointer) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("target", &target);
    v->Visit("func_name", &func_name);
    v->Visit("inputs", &inputs);
    v->Visit("outputs", &outputs);
    v->Visit("funcs", &funcs);
  }

  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "GraphFunc";
  TVM_DECLARE_NODE_TYPE_INFO(GraphFuncNode, tvm::Node);
};

/*! \brief Declares GraphFunc, the reference type for GraphFuncNode (via TVM_DEFINE_NODE_REF). */
TVM_DEFINE_NODE_REF(GraphFunc, GraphFuncNode);

/*!
 * \brief Cache Entry in the graph.
 *
 * Bundles a compiled graph function with its usage counter and the master
 * node index; all three fields are exposed through VisitAttrs.
 */
struct GraphCacheEntryNode : public tvm::Node {
  /*! \brief The graph function */
  GraphFunc graph_func;
  /*! \brief Usage statistics */
  int use_count{0};
  /*!
   * \brief Index of the master node for calling schedule.
   *
   * Zero-initialized (like use_count) so that a freshly constructed entry
   * never exposes an indeterminate value through VisitAttrs; whoever
   * creates the entry is expected to assign the real index.
   */
  int master_idx{0};

  /*!
   * \brief Visit every field of this node under its own name.
   * \param v The attribute visitor that receives each (name, field pointer) pair.
   */
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("graph_func", &graph_func);
    v->Visit("use_count", &use_count);
    v->Visit("master_idx", &master_idx);
  }
  /*! \brief Type key string of this node type. */
  static constexpr const char* _type_key = "GraphCacheEntry";
  TVM_DECLARE_NODE_TYPE_INFO(GraphCacheEntryNode, tvm::Node);
};

class GraphCacheEntry : public ::tvm::NodeRef {
 public:
  GraphCacheEntry() {}
  explicit GraphCacheEntry(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) {}
  GraphCacheEntryNode* operator->() {
    return static_cast<GraphCacheEntryNode*>(node_.get());
  }
  using ContainerType = GraphCacheEntryNode;
};

/*!
 * \brief Call compile engine to lower a graph with given inputs.
 *
 * \param graph The graph to be compiled (taken by value).
 * \param inputs The input tensors that specify the function's inputs.
 * \param target The build target string.
 * \param master_idx The index of master node for calling schedule.
 *
 * \return The GraphFunc produced by lowering the graph.
 */
GraphFunc GraphLower(Graph graph,
                     const Array<tvm::Tensor>& inputs,
                     const std::string& target,
                     int master_idx);

/*!
 * \brief Get type flag from TVM Type
 *
 * \param type the tvm type
 * \return corresponding integer type flag
 */
int GetTypeFlag(tvm::Type type);

/*!
 * \brief Get TVM Type from type flag
 *
 * \param type_flag the integer type flag
 * \return corresponding TVM type
 */
tvm::Type GetTVMType(int type_flag);

}  // namespace compiler
}  // namespace nnvm

#endif  // NNVM_COMPILER_COMPILE_ENGINE_H_