/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of DSL API
 * \file dsl_api.cc
 */
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/api_registry.h>
#include <tvm/attrs.h>
#include <vector>
#include <string>
#include <exception>
#include <limits>
#include "../runtime/dsl_api.h"

namespace tvm {
namespace runtime {

/*! \brief Entry used to hold values returned through the C API. */
struct TVMAPIThreadLocalEntry {
  /*! \brief result holder for returning strings */
  std::vector<std::string> ret_vec_str;
  /*! \brief result holder for returning string pointers */
  std::vector<const char *> ret_vec_charp;
  /*! \brief result holder for returning a single string */
  std::string ret_str;
};

/*! \brief Thread local store that can be used to hold return values. */
using TVMAPIThreadLocalStore = dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry>;

/*! \brief Concrete type behind the opaque NodeHandle used in this file. */
using TVMAPINode = NodePtr<Node>;

/*!
 * \brief Attribute visitor that copies the field named \p skey into \p ret.
 *
 *  Fields whose key differs from skey are ignored. Integral fields
 *  (uint64_t, int, bool) are stored into ret as int64_t.
 */
struct APIAttrGetter : public AttrVisitor {
  /*! \brief name of the attribute being looked up */
  std::string skey;
  /*! \brief destination of the attribute value; must be set before visiting */
  TVMRetValue* ret;
  /*!
   * \brief set when the matched attribute is a NodeRef or an NDArray.
   *  NodeGetAttr treats the lookup as successful when this is set even if
   *  ret still reports kNull (presumably the null-reference case).
   */
  bool found_ref_object{false};

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    // Only the requested field is range-checked: an unrelated large uint64_t
    // field must not make the lookup of a different attribute fail.
    if (skey != key) return;
    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
        << "cannot return too big constant";
    *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, void** value) final {
    if (skey == key) *ret = static_cast<void*>(value[0]);
  }
  void Visit(const char* key, Type* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, NodeRef* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    if (skey == key) {
      *ret = value[0];
      found_ref_object = true;
    }
  }
};

/*!
 * \brief Attribute visitor that appends the key of every visited field to
 *  \p names, regardless of the field type.
 */
struct APIAttrDir : public AttrVisitor {
  /*! \brief destination list of attribute names; must be set before visiting */
  std::vector<std::string>* names;

  void Visit(const char* key, double* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, int64_t* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, uint64_t* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, bool* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, int* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, void** value) final {
    names->push_back(key);
  }
  void Visit(const char* key, Type* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, std::string* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, NodeRef* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    names->push_back(key);
  }
};

/*!
 * \brief Implementation of the DSLAPI interface.
 *
 *  Every NodeHandle received here is interpreted as a TVMAPINode*
 *  (a pointer to a NodePtr<Node>).
 */
class DSLAPIImpl : public DSLAPI {
 public:
  /*! \brief Delete the TVMAPINode that \p handle points to. */
  void NodeFree(NodeHandle handle) const final {
    delete static_cast<TVMAPINode*>(handle);
  }

  /*! \brief Store the result of Node::TypeKey2Index(type_key) in \p out_index. */
  void NodeTypeKey2Index(const char* type_key,
                         int* out_index) const final {
    *out_index = static_cast<int>(Node::TypeKey2Index(type_key));
  }

  /*! \brief Store the type index of the node behind \p handle in \p out_index. */
  void NodeGetTypeIndex(NodeHandle handle,
                        int* out_index) const final {
    *out_index = static_cast<int>(
        (*static_cast<TVMAPINode*>(handle))->type_index());
  }

  /*!
   * \brief Look up attribute \p key on the node behind \p handle.
   *
   * \param handle The node handle.
   * \param key The attribute name; "type_key" is answered from the node's
   *        type_key() without visiting its attributes.
   * \param ret_val Receives the value when the lookup succeeds.
   * \param ret_type_code Receives the type code when the lookup succeeds.
   * \param ret_success Set to non-zero when the attribute was found,
   *        zero otherwise; the other outputs are left untouched on failure.
   */
  void NodeGetAttr(NodeHandle handle,
                   const char* key,
                   TVMValue* ret_val,
                   int* ret_type_code,
                   int* ret_success) const final {
    TVMRetValue rv;
    APIAttrGetter getter;
    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
    getter.skey = key;
    getter.ret = &rv;

    if (getter.skey == "type_key") {
      ret_val->v_str = (*tnode)->type_key();
      *ret_type_code = kStr;
      *ret_success = 1;
      return;
    }

    if (!(*tnode)->is_type<DictAttrsNode>()) {
      (*tnode)->VisitAttrs(&getter);
      *ret_success = getter.found_ref_object || rv.type_code() != kNull;
    } else {
      // specially handle dict attr
      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
      auto it = dnode->dict.find(key);
      if (it != dnode->dict.end()) {
        *ret_success = 1;
        rv = (*it).second;
      } else {
        *ret_success = 0;
      }
    }

    if (!*ret_success) return;

    if (rv.type_code() == kStr || rv.type_code() == kTVMType) {
      // Both strings and types are handed back as a C string. The buffer is
      // the thread-local ret_str, so the pointer is only good until ret_str
      // is overwritten again on this thread.
      TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get();
      e->ret_str = rv.operator std::string();
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    } else {
      rv.MoveToCHost(ret_val, ret_type_code);
    }
  }

  /*!
   * \brief List the attribute names of the node behind \p handle.
   *
   * \param handle The node handle.
   * \param out_size Receives the number of names.
   * \param out_array Receives a pointer to the array of C strings. The array
   *        and the strings live in the thread-local entry, which this
   *        function clears and refills on every call.
   */
  void NodeListAttrNames(NodeHandle handle,
                         int *out_size,
                         const char*** out_array) const final {
    TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
    ret->ret_vec_str.clear();
    TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
    APIAttrDir dir;
    dir.names = &(ret->ret_vec_str);

    if (!(*tnode)->is_type<DictAttrsNode>()) {
      (*tnode)->VisitAttrs(&dir);
    } else {
      // specially handle dict attr
      DictAttrsNode* dnode = static_cast<DictAttrsNode*>(tnode->get());
      for (const auto& kv : dnode->dict) {
        ret->ret_vec_str.push_back(kv.first);
      }
    }

    // Collect the pointers only after ret_vec_str is complete, so that no
    // later push_back can reallocate the strings they refer to.
    ret->ret_vec_charp.clear();
    ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
    for (const std::string& name : ret->ret_vec_str) {
      ret->ret_vec_charp.push_back(name.c_str());
    }
    *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
    *out_size = static_cast<int>(ret->ret_vec_str.size());
  }
};

// Expose the address of a function-local static DSLAPIImpl as an opaque
// pointer through the global function "dsl_api.singleton".
TVM_REGISTER_GLOBAL("dsl_api.singleton")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    static DSLAPIImpl impl;
    void* ptr = &impl;
    *rv = ptr;
  });

}  // namespace runtime
}  // namespace tvm