/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file  module.cc
 * \brief The global module in Relay.
 */
#include <tvm/relay/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <sstream>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace runtime;

Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
                        tvm::Map<GlobalTypeVar, TypeData> global_type_defs) {
  auto n = make_node<ModuleNode>();
  n->functions = std::move(global_funcs);
  n->type_definitions = std::move(global_type_defs);

  for (const auto& kv : n->functions) {
    // set global var map
    CHECK(!n->global_var_map_.count(kv.first->name_hint))
      << "Duplicate global function name " << kv.first->name_hint;
    n->global_var_map_.Set(kv.first->name_hint, kv.first);
  }

  for (const auto& kv : n->type_definitions) {
    // set global typevar map
    CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
      << "Duplicate global type definition name " << kv.first->var->name_hint;
    n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
    n->RegisterConstructors(kv.first, kv.second);
  }

  return Module(n);
}

bool ModuleNode::ContainGlobalVar(const std::string& name) const {
  return global_var_map_.find(name) != global_var_map_.end();
}

GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
  auto it = global_var_map_.find(name);
  CHECK(it != global_var_map_.end())
    << "Cannot find global var " << name << " in the Module";
  return (*it).second;
}

void ModuleNode::AddUnchecked(const GlobalVar& var,
                              const Function& func) {
  auto mod = GetRef<Module>(this);
  this->functions.Set(var, func);

  auto it = global_var_map_.find(var->name_hint);
  if (it != global_var_map_.end()) {
    CHECK_EQ((*it).second, var);
  } else {
    CHECK(!global_var_map_.count(var->name_hint))
        << "Duplicate global function name " << var->name_hint;
  }

  global_var_map_.Set(var->name_hint, var);
}

GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
  auto it = global_type_var_map_.find(name);
  CHECK(it != global_type_var_map_.end())
    << "Cannot find global type var " << name << " in the Module";
  return (*it).second;
}

template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
  tvm::Array<T> ret(l);
  for (const T& t : r) {
    ret.push_back(t);
  }
  return ret;
}

void ModuleNode::Add(const GlobalVar& var,
                     const Function& f,
                     bool update) {
  Function func = Downcast<Function>(DeDup(f));
  // Type check the item before we add it to the module.
  auto mod = GetRef<Module>(this);
  auto fv = FreeVars(func);
  auto ftv = FreeTypeVars(func, mod);
  if (fv.size() != 0) {
    LOG(WARNING)
      << "There are free variables: "
      << fv
      << " in function: "
      << AsText(func, false)
      << std::endl;
  }
  if (ftv.size() != 0) {
    LOG(WARNING)
      << "There are free type variables: "
      << ftv
      << " in function: "
      << AsText(func, false)
      << std::endl;
  }
  func =
    FunctionNode::make(concat(func->params, fv),
                       func->body,
                       func->ret_type,
                       concat(func->type_params, ftv),
                       func->attrs);
  // Type check the item before we add it to the module.
  Function checked_func = InferType(func, mod, var);
  auto type = checked_func->checked_type();
  CHECK(type.as<IncompleteTypeNode>() == nullptr);
  if (functions.find(var) != functions.end()) {
    CHECK(update)
        << "Already have definition for " << var->name_hint;
    auto old_type = functions[var].as<FunctionNode>()->checked_type();
    CHECK(AlphaEqual(type, old_type))
        << "Module#update changes type, not possible in this mode.";
  }
  var->checked_type_ = type;
  AddUnchecked(var, checked_func);
}

void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
  // We hash the global type var name to use as a globally unique prefix for tags.
  // The hash will be used as the most significant byte of the tag, with the index of
  // the constructor in the less significant bytes
  size_t hash = std::hash<std::string>()(var->var->name_hint);
  int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
  for (size_t i = 0; i < type->constructors.size(); ++i) {
    type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
    constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
  }
}

void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
  this->type_definitions.Set(var, type);
  // set global type var map
  CHECK(!global_type_var_map_.count(var->var->name_hint))
    << "Duplicate global type definition name " << var->var->name_hint;
  global_type_var_map_.Set(var->var->name_hint, var);
  RegisterConstructors(var, type);

  // need to kind check at the end because the check can look up
  // a definition potentially
  CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
    << "Invalid or malformed typedata given to module: " << type;
}

void ModuleNode::Update(const GlobalVar& var, const Function& func) {
  this->Add(var, func, true);
}

void ModuleNode::Remove(const GlobalVar& var) {
  auto functions_node = this->functions.CopyOnWrite();
  functions_node->data.erase(var.node_);
  auto gvar_node = global_var_map_.CopyOnWrite();
  gvar_node->data.erase(var->name_hint);
}

Function ModuleNode::Lookup(const GlobalVar& var) const {
  auto it = functions.find(var);
  CHECK(it != functions.end())
      << "There is no definition of " << var->name_hint;
  return (*it).second;
}

Function ModuleNode::Lookup(const std::string& name) const {
  GlobalVar id = this->GetGlobalVar(name);
  return this->Lookup(id);
}

TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
  auto it = type_definitions.find(var);
  CHECK(it != type_definitions.end())
    << "There is no definition of " << var->var->name_hint;
  return (*it).second;
}

TypeData ModuleNode::LookupDef(const std::string& name) const {
  GlobalTypeVar id = this->GetGlobalTypeVar(name);
  return this->LookupDef(id);
}

Constructor ModuleNode::LookupTag(const int32_t tag) {
  auto it = constructor_tag_map_.find(tag);
  CHECK(it != constructor_tag_map_.end())
    << "There is no constructor with the tag " << tag;
  return (*it).second;
}

void ModuleNode::Update(const Module& mod) {
  for (auto pair : mod->functions) {
    this->Update(pair.first, pair.second);
  }
}

Module ModuleNode::FromExpr(
  const Expr& expr,
  const tvm::Map<GlobalVar, Function>& global_funcs,
  const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
  auto mod = ModuleNode::make(global_funcs, type_definitions);
  auto func_node = expr.as<FunctionNode>();
  Function func;
  if (func_node) {
    func = GetRef<Function>(func_node);
  } else {
    func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, mod), {});
  }
  auto main_gv = GlobalVarNode::make("main");
  mod->Add(main_gv, func);
  return mod;
}

TVM_REGISTER_NODE_TYPE(ModuleNode);

TVM_REGISTER_API("relay._make.Module")
.set_body_typed(ModuleNode::make);

TVM_REGISTER_API("relay._module.Module_Add")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Module mod = args[0];
  GlobalVar var = args[1];
  NodeRef val = args[2];
  bool update = args[3];
  CHECK(val->derived_from<ExprNode>());
  if (val->derived_from<FunctionNode>()) {
    mod->Add(var, Downcast<Function>(val), update);
  } else if (val->derived_from<GlobalVarNode>()) {
    GlobalVar gv = Downcast<GlobalVar>(val);
    auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
    mod_copy = transform::EtaExpand()(mod_copy);
    auto func = mod_copy->Lookup(gv->name_hint);
    mod->Add(var, Downcast<Function>(func), update);
  } else {
    auto func = FunctionNode::make({}, Downcast<Expr>(val), Type(nullptr), {});
    mod->Add(var, func, update);
  }
  *ret = mod;
});

TVM_REGISTER_API("relay._module.Module_AddDef")
.set_body_method<Module>(&ModuleNode::AddDef);

TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar);

TVM_REGISTER_API("relay._module.Module_ContainGlobalVar")
.set_body_method<Module>(&ModuleNode::ContainGlobalVar);

TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);

TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
  return mod->Lookup(var);
});

TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
  return mod->Lookup(var);
});

TVM_REGISTER_API("relay._module.Module_LookupDef")
.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
  return mod->LookupDef(var);
});

TVM_REGISTER_API("relay._module.Module_LookupDef_str")
.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
  return mod->LookupDef(var);
});

TVM_REGISTER_API("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
    return mod->LookupTag(tag);
  });

TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<
  Module(Expr,
         tvm::Map<GlobalVar, Function>,
         tvm::Map<GlobalTypeVar, TypeData>)>([](Expr e,
                                                tvm::Map<GlobalVar, Function> funcs,
                                                tvm::Map<GlobalTypeVar, TypeData> type_defs) {
                                               return ModuleNode::FromExpr(e, funcs, type_defs);
                                             });

TVM_REGISTER_API("relay._module.Module_Update")
.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
  mod->Update(from);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ModuleNode>(
  [](const ModuleNode *node, tvm::IRPrinter *p) {
    p->stream << "ModuleNode( " << node->functions << ")";
});

}  // namespace relay
}  // namespace tvm