/*!
 * Copyright (c) 2018 by Contributors
 * \file graph_fuse.h
 * \brief Definition of structs used by graph fusion
*/
#ifndef NNVM_COMPILER_GRAPH_FUSE_H_
#define NNVM_COMPILER_GRAPH_FUSE_H_

#include <nnvm/graph.h>
#include <unordered_map>
#include <vector>

#include "compile_engine.h"

namespace nnvm {
namespace compiler {

/*!
 * \brief The fuse rule attached to a single node.
 *
 * NOTE(review): the exact meaning of each rule is defined by the fusion
 * pass, which is not visible in this header; the per-enumerator notes
 * below are assumptions to be confirmed against it.
 */
enum class FuseRule {
  /*!
   * \brief Presumably the "no rule decided" state.
   *  The spelling of the identifier is kept as-is since it is part of
   *  the public interface.
   */
  kUknown = 0,
  /*! \brief Presumably: fuse the node into its master node (confirm). */
  kFuseToMaster = 1,
  /*! \brief Presumably: realize the node instead of fusing it (confirm). */
  kRealize = 2
};

/*!
 * \brief Convert a dtype flag into the matching DLDataType.
 *
 * The flag is first mapped through GetTVMType, and the result is then
 * converted with tvm::Type2TVMType.
 *
 * \param type_flag The data type flag.
 * \return The corresponding DLDataType.
 */
inline DLDataType GetDLType(int type_flag) {
  const auto tvm_type = GetTVMType(type_flag);
  return tvm::Type2TVMType(tvm_type);
}

struct INodeEntryHash {
  size_t operator()(const IndexedGraph::NodeEntry& e) const {
    return e.node_id;
  }
};

/*!
 * \brief Equality predicate for IndexedGraph::NodeEntry keys.
 *
 * Two entries are considered equal when both their node_id and their index
 * match; no other field is compared. Used as the key-equality functor of
 * FuseEntry::imap together with INodeEntryHash.
 */
struct INodeEntryEqual {
  /*!
   * \param a The left-hand entry.
   * \param b The right-hand entry.
   * \return true if a and b refer to the same node_id and index.
   */
  bool operator()(const IndexedGraph::NodeEntry &a,
                  const IndexedGraph::NodeEntry &b) const {
    // A predicate yields bool; the previous size_t return type only worked
    // through implicit conversion.
    return a.node_id == b.node_id && a.index == b.index;
  }
};

/*!
 * \brief Auxiliary data structure for representing a fused op.
 */
struct FuseEntry {
  /*! \brief Subgraph of the fragment. */
  Graph subgraph;
  /*!
   * \brief The input map, keyed by indexed-graph entry.
   *  Keys are hashed with INodeEntryHash and compared with INodeEntryEqual.
   */
  std::unordered_map<IndexedGraph::NodeEntry, nnvm::NodeEntry, INodeEntryHash,
                     INodeEntryEqual>
      imap;
  /*! \brief Reverse map to the old input entry. */
  std::unordered_map<const Node *, IndexedGraph::NodeEntry> reverse_imap;
  /*! \brief TVM placeholder for inputs. */
  std::unordered_map<const Node *, Tensor> input_info;
  /*!
   * \brief Whether we can flatten data.
   *  Defaults to false so that a default-initialized FuseEntry never
   *  exposes an indeterminate value.
   */
  bool flatten_data{false};
  /*! \brief The corresponding function. */
  GraphFunc compiled_func;
};

// GroupVec stores the root node ids of the fused nodes.
using GroupVec = std::vector<int>;

// MasterVec stores the master node ids of the fused groups.
using MasterVec = std::vector<int>;

// FuseEntryVec stores the fused entries (one FuseEntry per element).
using FuseEntryVec = std::vector<FuseEntry>;

// PatternVec stores the operator patterns (TOpPattern values).
using PatternVec = std::vector<TOpPattern>;

}  // namespace compiler
}  // namespace nnvm

#endif  // NNVM_COMPILER_GRAPH_FUSE_H_