/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file compile_engine.cc
 * \brief The compile engine.
 */
#include <dmlc/common.h>
#include <tvm/ir.h>
#include <tvm/operation.h>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <nnvm/pass_functions.h>
#include <nnvm/compiler/op_attr_types.h>
#include <mutex>
#include <tuple>
#include <vector>
#include <limits>
#include <unordered_map>
#include "graph_hash.h"
#include "compile_engine.h"

namespace nnvm {
namespace compiler {

using namespace tvm;

/*!
 * \brief Map a TVM type to the integer type flag used by NNVM type inference.
 *
 *  The flag of a type is its position in the table below. A type that is
 *  not listed triggers LOG(FATAL).
 *
 * \param type the tvm type.
 * \return the integer type flag of the given type.
 */
int GetTypeFlag(tvm::Type type) {
  // Position in this table == type flag.
  const tvm::Type flag_order[] = {
    tvm::Float(32),  // 0
    tvm::Float(64),  // 1
    tvm::Float(16),  // 2
    tvm::UInt(8),    // 3
    tvm::Int(32),    // 4
    tvm::Int(8),     // 5
    tvm::Int(64),    // 6
    tvm::Int(16),    // 7
    tvm::UInt(16),   // 8
    tvm::UInt(32),   // 9
    tvm::UInt(64),   // 10
    tvm::UInt(1),    // 11
  };
  int flag = 0;
  for (const tvm::Type& candidate : flag_order) {
    if (type == candidate) {
      return flag;
    }
    ++flag;
  }
  LOG(FATAL) << "cannot convert " << type;
  return 0;
}
/*!
 * \brief Map an integer type flag back to the TVM type it stands for.
 *
 *  A flag outside [0, 11] triggers LOG(FATAL).
 *
 * \param type_flag the integer type flag.
 * \return the tvm type stored at that position of the table below.
 */
Type GetTVMType(int type_flag) {
  // Indexed by type flag.
  const Type type_of_flag[] = {
    tvm::Float(32),  // 0
    tvm::Float(64),  // 1
    tvm::Float(16),  // 2
    tvm::UInt(8),    // 3
    tvm::Int(32),    // 4
    tvm::Int(8),     // 5
    tvm::Int(64),    // 6
    tvm::Int(16),    // 7
    tvm::UInt(16),   // 8
    tvm::UInt(32),   // 9
    tvm::UInt(64),   // 10
    tvm::UInt(1),    // 11
  };
  const int num_flags =
      static_cast<int>(sizeof(type_of_flag) / sizeof(type_of_flag[0]));
  if (type_flag < 0 || type_flag >= num_flags) {
    LOG(FATAL) << "unknown type_flag=" << type_flag;
    return Float(32);
  }
  return type_of_flag[type_flag];
}

/*!
 * \brief Internal compile engine.
 *
 *  Lowers NNVM graphs into GraphFunc objects and keeps the results in a
 *  cache indexed by GraphKey. Lower, ListCacheItems, Find, Set and Clear
 *  acquire mutex_; GetScheduleArgs and DoLower do not acquire it themselves.
 */
class CompileEngine {
 public:
  /*! \return pointer to the function-local static engine instance. */
  static CompileEngine* Global() {
    static CompileEngine inst;
    return &inst;
  }
  /*!
   * \brief Lower a graph, returning the cached function when the key is known.
   *
   *  The cache key is built from (graph, inputs, target); master_idx is not
   *  part of the key. A hit increments the entry's use_count.
   *
   * \param graph The graph to lower.
   * \param inputs The input tensors of the graph.
   * \param target The compilation target string.
   * \param master_idx Node index whose operator provides the schedule.
   * \return The lowered (or cached) graph function.
   */
  GraphFunc Lower(Graph graph,
                  const Array<tvm::Tensor>& inputs,
                  const std::string& target,
                  int master_idx) {
    GraphKey key = GraphKeyNode::make(graph, inputs, target);
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      ++(it->second->use_count);
      return it->second->graph_func;
    }
    // Cache miss: lower while still holding the lock, then record the entry.
    GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx);
    auto n = tvm::make_node<GraphCacheEntryNode>();
    n->graph_func = f;
    n->use_count = 1;
    n->master_idx = master_idx;
    cache_[key] = GraphCacheEntry(n);
    return f;
  }
  /*!
   * \brief List all items in the cache.
   * \return A flat array alternating between a key and a copy of its
   *  cache entry node: [key0, entry0, key1, entry1, ...].
   */
  Array<NodeRef> ListCacheItems() {
    std::lock_guard<std::mutex> lock(mutex_);
    Array<NodeRef> items;
    for (auto& kv : cache_) {
      items.push_back(kv.first);
      // Copy-construct a fresh node so the returned entry is not the cached one.
      auto n = tvm::make_node<GraphCacheEntryNode>(*(kv.second.operator->()));
      items.push_back(GraphCacheEntry(n));
    }
    return items;
  }
  /*!
   * \brief Find the cache entry for a given graph key.
   * \param key The graph key to look up.
   * \return The cached entry, or a default-constructed GraphCacheEntry
   *  when the key is not present.
   */
  GraphCacheEntry Find(const GraphKey& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    } else {
      return GraphCacheEntry();
    }
  }
  /*!
   * \brief Set the given function on the given graph key, replacing any
   *  existing entry.
   * \param key The graph key.
   * \param func The function to associate with the key.
   */
  void Set(const GraphKey& key, GraphFunc func) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto n = tvm::make_node<GraphCacheEntryNode>();
    n->graph_func = func;
    n->use_count = 1;
    // NOTE(review): master_idx is not assigned here, unlike in Lower. Whether
    // GraphCacheEntryNode gives it a default is not visible from this file;
    // confirm before reading master_idx from entries created through Set.
    cache_[key] = GraphCacheEntry(n);
  }
  /*! \brief Clear the function cache. */
  void Clear() {
    std::lock_guard<std::mutex> lock(mutex_);
    cache_.clear();
  }

  /*!
   * \brief Build the compute description of a graph and schedule it.
   *
   *  Runs shape/type inference from the input tensors, calls FTVMCompute
   *  on every operator node in index order, verifies that the computed
   *  output shapes agree with the inferred ones, and finally calls the
   *  FTVMSchedule of the master node on the graph outputs.
   *
   * \param graph The graph to process.
   * \param inputs The input tensors; their shapes must be constant integers.
   * \param target The compilation target string passed to the schedule.
   * \param master_idx Index of the operator node whose FTVMSchedule is used.
   * \param readable_name Optional output: "fuse" followed by "_<op name>"
   *  for every operator node. Ignored when nullptr.
   * \param outputs Optional output: the graph output tensors. Ignored when nullptr.
   * \return Tuple of (schedule, inputs followed by outputs, graph after
   *  shape and type inference).
   */
  std::tuple<Schedule, Array<tvm::Tensor>, Graph>
  GetScheduleArgs(Graph graph,
                  const Array<tvm::Tensor> &inputs,
                  const std::string &target,
                  int master_idx,
                  std::string *readable_name,
                  Array<tvm::Tensor> *outputs) {
    // operator attribute maps, looked up once
    static auto& fcompute =
        nnvm::Op::GetAttr<FTVMCompute>("FTVMCompute");
    static auto& fschedule =
        nnvm::Op::GetAttr<FTVMSchedule>("FTVMSchedule");

    // Collect input shapes and type flags to seed the inference passes.
    std::vector<TShape> ishape;
    std::vector<int> idtype;

    for (const tvm::Tensor t : inputs) {
      std::vector<dim_t> shape;
      for (Expr v : t->shape) {
        // Only constant integer dimensions are supported.
        const tvm::ir::IntImm* dim = v.as<tvm::ir::IntImm>();
        CHECK(dim != nullptr);
        shape.push_back(dim->value);
      }
      ishape.emplace_back(TShape(shape.begin(), shape.end()));
      idtype.emplace_back(GetTypeFlag(t->dtype));
    }
    graph = pass::InferShape(graph, ishape);
    graph = pass::InferType(graph, idtype);

    const ShapeVector& shape_vec = graph.GetAttr<ShapeVector>("shape");
    const DTypeVector& dtype_vec = graph.GetAttr<DTypeVector>("dtype");
    const IndexedGraph& idx = graph.indexed_graph();
    CHECK_EQ(inputs.size(), idx.input_nodes().size());

    // One tensor slot per node entry; input entries are filled up front.
    std::vector<tvm::Tensor> tensor_vec(idx.num_node_entries());
    for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
      uint32_t nid = idx.input_nodes()[i];
      tensor_vec[idx.entry_id(nid, 0)] = inputs[i];
    }

    std::ostringstream readable_name_os;
    readable_name_os << "fuse";
    for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      Array<Tensor> op_inputs, out_info;
      readable_name_os << "_" << inode.source->op()->name;
      // input array
      for (const IndexedGraph::NodeEntry& e : inode.inputs) {
        const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
        CHECK(t.defined());
        op_inputs.push_back(t);
      }
      // output hint: placeholders carrying the inferred shape and dtype
      for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
        Array<Expr> shape;
        for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
          // dimensions are emitted as Int(32) constants, so they must fit
          CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
          shape.push_back(make_const(Int(32), x));
        }
        out_info.push_back(
            placeholder(shape,
                        GetTVMType(dtype_vec[idx.entry_id(nid, i)])));
      }
      // get default
      Array<Tensor> out = fcompute[inode.source->op()](
          inode.source->attrs, op_inputs, out_info);
      CHECK_EQ(out.size(), inode.source->num_outputs());

      // check output dimensions also match
      // This check is to make sure the NNVM operator Infer match with Compute result.
      // Missing this check may pass the build but leads to runtime errors.
      for (uint32_t i = 0; i < out.size(); ++i) {
        CHECK_EQ(out[i].ndim(), out_info[i].ndim()) << inode.source->op()->name;
        tvm::Tensor computed_tensor = out[i];       // produced by FTVMCompute
        tvm::Tensor inferred_tensor = out_info[i];  // produced by shape inference
        for (uint32_t j = 0; j < computed_tensor->shape.size(); ++j) {
          // Only compare dimensions that are constant on both sides.
          if ((as_const_int(computed_tensor->shape[j])) &&
              (as_const_int(inferred_tensor->shape[j]))) {
            CHECK_EQ((*as_const_int(computed_tensor->shape[j])),
                     (*as_const_int(inferred_tensor->shape[j]))) << inode.source->op()->name;
          }
        }
      }

      // schedule on root node, and use master's schedule
      for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
        uint32_t eid = idx.entry_id(nid, index);
        tensor_vec[eid] = out[index];
      }
    }
    // Schedule on final output.
    Array<Tensor> all_args = inputs;
    Array<Tensor> outs;
    for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
      const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
      CHECK(t.defined());
      outs.push_back(t);
      all_args.push_back(t);
    }

    // Validate master_idx before using it as a node index.
    CHECK(master_idx >= 0 &&
          static_cast<size_t>(master_idx) < idx.num_nodes())
        << "master_idx=" << master_idx << " is out of range, graph has "
        << idx.num_nodes() << " nodes";
    const auto& master = idx[static_cast<uint32_t>(master_idx)];
    CHECK(!master.source->is_variable())
        << "master_idx=" << master_idx << " refers to a variable node";
    Schedule sch = fschedule[master.source->op()](
        master.source->attrs, outs, target);

    // store extra return values
    if (readable_name != nullptr) {
      *readable_name = readable_name_os.str();
    }
    if (outputs != nullptr) {
      *outputs = outs;
    }

    return std::make_tuple(sch, all_args, graph);
  }

  /*!
   * \brief Run the actual lowering process (no cache lookup).
   *
   *  Modifies name_map_ through GetUniqeName without taking mutex_ itself;
   *  Lower calls it with the lock already held.
   *
   * \param graph The graph to lower.
   * \param inputs The input tensors of the graph.
   * \param target The compilation target string.
   * \param master_idx Node index whose operator provides the schedule.
   * \return The newly built graph function.
   */
  GraphFunc DoLower(Graph graph,
                    const Array<tvm::Tensor>& inputs,
                    const std::string& target,
                    int master_idx) {
    std::string readable_name;
    Array<tvm::Tensor> all_args;
    Array<tvm::Tensor> outputs;
    Schedule sch;

    std::tie(sch, all_args, graph) = GetScheduleArgs(
        graph, inputs, target, master_idx,
        &readable_name, &outputs);

    auto gf = tvm::make_node<GraphFuncNode>();
    gf->target = target;
    gf->func_name = GetUniqeName(readable_name);
    gf->inputs = inputs;
    gf->outputs = outputs;
    // The lowering itself is delegated to a globally registered packed function.
    static const PackedFunc& flower = GetPackedFunc("nnvm.compiler.lower");
    gf->funcs = flower(sch, all_args, gf->func_name, graph);
    return GraphFunc(gf);
  }

 private:
  /*!
   * \brief Get a name that has not been handed out before.
   *
   *  While the candidate is already in name_map_, "_<counter>" is appended,
   *  the counter of the colliding name is incremented, and the new
   *  candidate is tried again.
   *
   * \param name The preferred name.
   * \return A name not previously recorded in name_map_.
   */
  std::string GetUniqeName(std::string name) {
    while (true) {
      auto it = name_map_.find(name);
      if (it == name_map_.end()) {
        name_map_[name] = 1;
        return name;
      } else {
        std::ostringstream os;
        os << name << "_" << it->second;
        ++(it->second);
        name = os.str();
      }
    }
    return name;
  }

  // global mutex
  std::mutex mutex_;
  // the name map: name -> next suffix to try on collision
  std::unordered_map<std::string, int> name_map_;
  // the compiler cache
  std::unordered_map<GraphKey, GraphCacheEntry,
                     GraphKeyHash, GraphKeyEqual> cache_;
};

/*!
 * \brief Lower a graph through the global compile engine.
 *
 *  Forwards all arguments unchanged to CompileEngine::Lower.
 *
 * \param graph The graph to lower.
 * \param inputs The input tensors of the graph.
 * \param target The compilation target string.
 * \param master_idx Node index whose operator provides the schedule.
 * \return The graph function returned by the engine.
 */
GraphFunc GraphLower(Graph graph,
                     const Array<tvm::Tensor>& inputs,
                     const std::string& target,
                     int master_idx) {
  CompileEngine* engine = CompileEngine::Global();
  return engine->Lower(graph, inputs, target, master_idx);
}

// Front-end access to the cache: returns every (key, entry) pair of the
// global engine as one flat array.
TVM_REGISTER_GLOBAL("nnvm.compiler.ListCacheItems")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
    CompileEngine* engine = CompileEngine::Global();
    *rv = engine->ListCacheItems();
  });

// Drops every entry from the global engine's cache; sets no return value.
TVM_REGISTER_GLOBAL("nnvm.compiler.ClearCache")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
    CompileEngine* engine = CompileEngine::Global();
    engine->Clear();
  });

// Looks up the cache entry for the graph key passed as args[0].
// NOTE: this involves graph lookup and can be slow
TVM_REGISTER_GLOBAL("nnvm.compiler.GetCacheItem")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
    GraphKey key = args[0];
    *rv = CompileEngine::Global()->Find(key);
  });

// Stores the graph function args[1] under the graph key args[0].
TVM_REGISTER_GLOBAL("nnvm.compiler.SetCacheItem")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
    GraphKey key = args[0];
    GraphFunc func = args[1];
    CompileEngine::Global()->Set(key, func);
  });

// Returns the graph stored inside the graph key passed as args[0].
TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
    GraphKey key = args[0];
    *rv = key->graph;
  });

// Exposes GraphKeyNode::make to the front end as a typed global function,
// so graph keys can be constructed outside of C++.
TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
.set_body_typed(GraphKeyNode::make);

// This can be used to extract workloads from nnvm compiler.
//
// args[0]: array whose first element is a GraphKey and whose second element
//          is a GraphCacheEntry (one pair as produced by ListCacheItems).
// args[1]: target string.
// Returns an array [schedule, all_args] obtained by re-running
// CompileEngine::GetScheduleArgs on the cached graph.
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    Array<tvm::NodeRef> item = args[0];
    CHECK_GE(item.size(), 2U)
        << "cache item must contain a GraphKey and a GraphCacheEntry";

    // Use checked downcasts: the item comes from the front end, so its
    // element types must be verified rather than assumed.
    const GraphKeyNode *key = item[0].as<GraphKeyNode>();
    CHECK(key != nullptr) << "cache item[0] is not a GraphKey";
    const GraphCacheEntryNode *value = item[1].as<GraphCacheEntryNode>();
    CHECK(value != nullptr) << "cache item[1] is not a GraphCacheEntry";

    // extract arguments from cached item
    Graph graph = key->graph;
    const Array<tvm::Tensor> &inputs = key->inputs;
    std::string target = args[1];
    int master_idx = value->master_idx;

    Schedule sch;
    Array<tvm::Tensor> all_args;
    std::tie(sch, all_args, graph) =
        CompileEngine::Global()->GetScheduleArgs(
        graph, inputs, target, master_idx, nullptr, nullptr);

    Array<tvm::NodeRef> ret;
    ret.push_back(sch);
    ret.push_back(all_args);
    *rv = ret;
  });

// Printer for GraphFuncNode: emits the function name and the node address.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
    auto& os = p->stream;
    os << "GraphFunc(name=" << op->func_name;
    os << ", addr=" << op << ")";
});

}  // namespace compiler
}  // namespace nnvm