/*!
 *  Copyright (c) 2017 by Contributors
 * \file stack_vm_module.cc
 */
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/codegen.h>
#include "./codegen_stack_vm.h"

namespace tvm {
namespace codegen {

class StackVMModuleNode : public runtime::ModuleNode {
 public:
  const char* type_key() const {
    return "stackvm";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    if (name == runtime::symbol::tvm_module_main) {
      return GetFunction(entry_func_, sptr_to_self);
    }
    auto it = fmap_.find(name);
    if (it == fmap_.end()) return PackedFunc();
    const StackVM& vm = it->second;
    // capture sptr_to_self to keep module node alive.
    return PackedFunc([vm, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
        vm(args);
      });
  }

  std::string GetSource(const std::string& format) final {
    std::ostringstream os;
    for (const auto& kv : fmap_) {
      os << "Function: " << kv.first << '\n';
      os << kv.second;
    }
    return os.str();
  }

  static runtime::Module Build(const Array<LoweredFunc>& funcs) {
    CHECK_NE(funcs.size(), 0U);
    std::shared_ptr<StackVMModuleNode> n =
        std::make_shared<StackVMModuleNode>();
    for (LoweredFunc f : funcs) {
      StackVM vm = codegen::CodeGenStackVM().Compile(f);
      CHECK(!n->fmap_.count(f->name))
          << "Function name " << f->name << "already exist in list";
      vm.mod_ctx = n.get();
      n->fmap_[f->name] = std::move(vm);
    }
    n->entry_func_ = funcs[0]->name;
    return runtime::Module(n);
  }

 private:
  // entry function.
  std::string entry_func_;
  // internal function map
  std::unordered_map<std::string, StackVM> fmap_;
};

TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = StackVMModuleNode::Build(args[0]);
  });

}  // namespace codegen
}  // namespace tvm