/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of DSL API
 * \file dsl_api.cc
 */
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <vector>
#include <string>
#include <exception>
#include "../runtime/dsl_api.h"

namespace tvm {
namespace runtime {
/*! \brief Per-thread scratch entry that holds values returned through the API. */
struct TVMAPIThreadLocalEntry {
  /*! \brief result holder for returning a list of strings */
  std::vector<std::string> ret_vec_str;
  /*! \brief result holder for returning the matching list of C string pointers */
  std::vector<const char *> ret_vec_charp;
  /*! \brief result holder for returning a single string */
  std::string ret_str;
};

/*! \brief Thread-local store whose entry keeps returned values alive for the caller. */
using TVMAPIThreadLocalStore = dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry>;

/*! \brief Concrete type behind a NodeHandle: handles are cast to TVMAPINode*. */
using TVMAPINode = NodePtr<Node>;

/*!
 * \brief Attribute visitor that fetches the value of a single named attribute.
 *
 *  The caller fills in skey (the attribute name) and ret (the destination)
 *  before handing the visitor to VisitAttrs. Every Visit overload compares
 *  the visited key against skey and, on a match, stores the value in *ret.
 *  Fields whose key does not match are left untouched and never validated.
 */
struct APIAttrGetter : public AttrVisitor {
  /*! \brief name of the attribute being looked up */
  std::string skey;
  /*! \brief destination that receives the value of the matching attribute */
  TVMRetValue* ret;
  /*!
   * \brief set to true when the matching attribute is a NodeRef or NDArray.
   *  Lets the caller tell "matched a reference field" apart from "no match";
   *  presumably needed because a null reference still reports kNull.
   */
  bool found_ref_object{false};

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    if (skey == key) {
      // Values are returned as int64_t, so reject anything that does not fit.
      // The check is done only for the requested field: an out-of-range
      // value in an unrelated field must not abort this lookup.
      CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
          << "cannot return too big constant";
      *ret = static_cast<int64_t>(value[0]);
    }
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, void** value) final {
    if (skey == key) *ret = static_cast<void*>(value[0]);
  }
  void Visit(const char* key, Type* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, NodeRef* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
};

/*!
 * \brief Attribute visitor that records the key of every visited field.
 *
 *  The value argument of each overload is ignored; only the key is appended,
 *  in visiting order, to the vector that names points to.
 */
struct APIAttrDir : public AttrVisitor {
  /*! \brief output list of attribute names; dereferenced without a null check */
  std::vector<std::string>* names;

  // scalar fields
  void Visit(const char* key, double*) final { names->emplace_back(key); }
  void Visit(const char* key, int64_t*) final { names->emplace_back(key); }
  void Visit(const char* key, uint64_t*) final { names->emplace_back(key); }
  void Visit(const char* key, bool*) final { names->emplace_back(key); }
  void Visit(const char* key, int*) final { names->emplace_back(key); }
  // pointer, type and string fields
  void Visit(const char* key, void**) final { names->emplace_back(key); }
  void Visit(const char* key, Type*) final { names->emplace_back(key); }
  void Visit(const char* key, std::string*) final { names->emplace_back(key); }
  // reference fields
  void Visit(const char* key, NodeRef*) final { names->emplace_back(key); }
  void Visit(const char* key, runtime::NDArray*) final { names->emplace_back(key); }
};

class DSLAPIImpl : public DSLAPI {
 public:
  void NodeFree(NodeHandle handle) const final {
    delete static_cast<TVMAPINode*>(handle);
  }
  void NodeTypeKey2Index(const char* type_key,
                        int* out_index) const final {
    *out_index = static_cast<int>(Node::TypeKey2Index(type_key));
  }
  void NodeGetTypeIndex(NodeHandle handle,
                        int* out_index) const final {
    *out_index = static_cast<int>(
        (*static_cast<TVMAPINode*>(handle))->type_index());
  }
  void NodeGetAttr(NodeHandle handle,
                   const char* key,
                   TVMValue* ret_val,
                   int* ret_type_code,
                   int* ret_success) const final {
    TVMRetValue rv;
    APIAttrGetter getter;
    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
    getter.skey = key;
    getter.ret = &rv;
    if (getter.skey == "type_key") {
      ret_val->v_str = (*tnode)->type_key();
      *ret_type_code = kStr;
      *ret_success = 1;
      return;
    } else if (!(*tnode)->is_type<DictAttrsNode>()) {
      (*tnode)->VisitAttrs(&getter);
      *ret_success = getter.found_ref_object || rv.type_code() != kNull;
    } else {
      // specially handle dict attr
      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
      auto it = dnode->dict.find(key);
      if (it != dnode->dict.end()) {
        *ret_success = 1;
        rv = (*it).second;
      } else {
        *ret_success = 0;
      }
    }
    if (*ret_success) {
      if (rv.type_code() == kStr ||
          rv.type_code() == kTVMType) {
        TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
        e->ret_str = rv.operator std::string();
        *ret_type_code = kStr;
        ret_val->v_str = e->ret_str.c_str();
      } else {
        rv.MoveToCHost(ret_val, ret_type_code);
      }
    }
  }
  void NodeListAttrNames(NodeHandle handle,
                        int *out_size,
                        const char*** out_array) const final {
    TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
    ret->ret_vec_str.clear();
    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
    APIAttrDir dir;
    dir.names = &(ret->ret_vec_str);

    if (!(*tnode)->is_type<DictAttrsNode>()) {
      (*tnode)->VisitAttrs(&dir);
    } else {
      // specially handle dict attr
      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
      for (const auto& kv : dnode->dict) {
        ret->ret_vec_str.push_back(kv.first);
      }
    }
    ret->ret_vec_charp.clear();
    for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
      ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
    }
    *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
    *out_size = static_cast<int>(ret->ret_vec_str.size());
  }
};

// Registers the global function "dsl_api.singleton", which returns the
// address of the single DSLAPIImpl instance as an opaque pointer.
TVM_REGISTER_GLOBAL("dsl_api.singleton")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // function-local static: one instance shared by every call
    static DSLAPIImpl impl;
    *rv = static_cast<void*>(&impl);
  });
}  // namespace runtime
}  // namespace tvm