/*! * Copyright (c) 2016 by Contributors * Implementation of C API * \file c_api.cc */ #include <dmlc/base.h> #include <dmlc/logging.h> #include <dmlc/thread_local.h> #include <tvm/c_api.h> #include <tvm/api_registry.h> #include <vector> #include <string> #include <exception> #include "../runtime/runtime_base.h" /*! \brief entry to to easily hold returning information */ struct TVMAPIThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector<std::string> ret_vec_str; /*! \brief result holder for returning string pointers */ std::vector<const char *> ret_vec_charp; /*! \brief result holder for retruning string */ std::string ret_str; }; using namespace tvm; /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore; using TVMAPINode = std::shared_ptr<Node>; struct APIAttrGetter : public AttrVisitor { std::string skey; TVMRetValue* ret; bool found_node_ref{false}; void Visit(const char* key, double* value) final { if (skey == key) *ret = value[0]; } void Visit(const char* key, int64_t* value) final { if (skey == key) *ret = value[0]; } void Visit(const char* key, uint64_t* value) final { CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) << "cannot return too big constant"; if (skey == key) *ret = static_cast<int64_t>(value[0]); } void Visit(const char* key, int* value) final { if (skey == key) *ret = static_cast<int64_t>(value[0]); } void Visit(const char* key, bool* value) final { if (skey == key) *ret = static_cast<int64_t>(value[0]); } void Visit(const char* key, Type* value) final { if (skey == key) *ret = value[0]; } void Visit(const char* key, std::string* value) final { if (skey == key) *ret = value[0]; } void Visit(const char* key, NodeRef* value) final { if (skey == key) { *ret = value[0]; found_node_ref = true; } } }; struct APIAttrDir : public AttrVisitor { std::vector<std::string>* names; void Visit(const char* key, double* value) final { names->push_back(key); } void Visit(const char* key, int64_t* value) final { names->push_back(key); } void Visit(const char* key, uint64_t* value) final { names->push_back(key); } void Visit(const char* key, bool* value) final { names->push_back(key); } void Visit(const char* key, int* value) final { names->push_back(key); } void Visit(const char* key, Type* value) final { names->push_back(key); } void Visit(const char* key, std::string* value) final { names->push_back(key); } void Visit(const char* key, NodeRef* value) final { names->push_back(key); } }; int TVMNodeFree(NodeHandle handle) { API_BEGIN(); delete static_cast<TVMAPINode*>(handle); API_END(); } int TVMCbArgToReturn(TVMValue* value, int code) { API_BEGIN(); tvm::runtime::TVMRetValue rv; rv = tvm::runtime::TVMArgValue(*value, code); int tcode; rv.MoveToCHost(value, &tcode); CHECK_EQ(tcode, code); API_END(); } int TVMNodeTypeKey2Index(const char* type_key, int* out_index) { API_BEGIN(); *out_index = static_cast<int>(Node::TypeKey2Index(type_key)); API_END(); } int TVMNodeGetTypeIndex(NodeHandle handle, int* out_index) { API_BEGIN(); *out_index = static_cast<int>( (*static_cast<TVMAPINode*>(handle))->type_index()); API_END(); } int TVMNodeGetAttr(NodeHandle handle, const char* key, TVMValue* ret_val, int* ret_type_code, int* ret_success) { API_BEGIN(); TVMRetValue rv; APIAttrGetter getter; getter.skey = key; getter.ret = &rv; TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); if (getter.skey == "type_key") { ret_val->v_str = (*tnode)->type_key(); *ret_type_code = kStr; *ret_success = 1; } else { (*tnode)->VisitAttrs(&getter); *ret_success = getter.found_node_ref || rv.type_code() != kNull; if (rv.type_code() == kStr || rv.type_code() == kTVMType) { TVMAPIThreadLocalEntry *e = TVMAPIThreadLocalStore::Get(); e->ret_str = rv.operator std::string(); *ret_type_code = kStr; ret_val->v_str = e->ret_str.c_str(); } else { rv.MoveToCHost(ret_val, ret_type_code); } } API_END(); } int TVMNodeListAttrNames(NodeHandle handle, int *out_size, const char*** out_array) { TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str.clear(); TVMAPINode* tnode = static_cast<TVMAPINode*>(handle); APIAttrDir dir; dir.names = &(ret->ret_vec_str); (*tnode)->VisitAttrs(&dir); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); } *out_array = dmlc::BeginPtr(ret->ret_vec_charp); *out_size = static_cast<int>(ret->ret_vec_str.size()); API_END(); }