/*!
 *  Copyright (c) 2016 by Contributors
 *  Implementation of C API
 * \file c_api.cc
 */
#include <tvm/c_api.h>

#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "./c_api_common.h"
#include "./c_api_registry.h"

/*!
 * \brief Per-thread scratch space used to hold data handed back through the C API.
 *
 *  Pointers returned to the caller (error message, name tables, string results)
 *  point into the members of this entry, so they stay valid only until a later
 *  API call on the same thread overwrites the corresponding member.
 */
struct TVMAPIThreadLocalEntry {
  /*! \brief hold last error */
  std::string last_error;
  /*! \brief result holder for returning strings */
  std::vector<std::string> ret_vec_str;
  /*! \brief result holder for returning string pointers */
  std::vector<const char *> ret_vec_charp;
  /*! \brief argument stack */
  std::vector<tvm::APIVariantValue> arg_stack;
  /*! \brief return value */
  tvm::APIVariantValue ret_value;
  /*!
   * \brief Clear the calling stack and drop the node reference held by ret_value.
   *  Only arg_stack and ret_value.sptr are touched; the other fields are kept.
   */
  inline void Clear() {
    arg_stack.clear();
    ret_value.sptr.reset();
  }
  /*!
   * \brief Copy ret_value into the C-level output slots.
   * \param ret_val Receives the value payload.
   * \param ret_typeid Receives the type id of ret_value.
   */
  inline void SetReturn(ArgVariant* ret_val, int* ret_typeid);
};

using namespace tvm;

/*! \brief Thread local store that can be used to hold return values. */
using TVMAPIThreadLocalStore = dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry>;

/*! \brief Heap-allocated shared pointer that backs a NodeHandle. */
using TVMAPINode = std::shared_ptr<Node>;

/*!
 * \brief Attribute visitor that stores the attribute named `skey` into `*ret`.
 *  Attributes whose key does not match are ignored.
 */
struct APIAttrGetter : public AttrVisitor {
  /*! \brief name of the attribute to fetch */
  std::string skey;
  /*! \brief destination of the fetched value */
  APIVariantValue* ret;

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    // Only validate the attribute that was actually requested; an unrelated
    // large uint64_t field must not make every lookup on the node fail.
    if (skey != key) return;
    CHECK_LE(value[0],
             static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
        << "cannot return too big constant";
    *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, Type* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, NodeRef* value) final {
    if (skey == key) *ret = value[0];
  }
};

/*! \brief Attribute visitor that appends every visited attribute name to `*names`. */
struct APIAttrDir : public AttrVisitor {
  /*! \brief destination list of attribute names */
  std::vector<std::string>* names;

  void Visit(const char* key, double* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, int64_t* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, uint64_t* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, bool* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, int* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, Type* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, std::string* value) final {
    names->push_back(key);
  }
  void Visit(const char* key, NodeRef* value) final {
    names->push_back(key);
  }
};

/*!
 * \brief Return the error message stored for the calling thread.
 * \return Pointer into the thread-local entry's last_error string.
 */
const char *TVMGetLastError() {
  return TVMAPIThreadLocalStore::Get()->last_error.c_str();
}

/*!
 * \brief Store an error message for the calling thread.
 * \param msg The message; a null pointer is stored as an empty message.
 */
void TVMAPISetLastError(const char* msg) {
  // Assigning a null const char* to std::string is undefined behaviour.
  TVMAPIThreadLocalStore::Get()->last_error = (msg != nullptr) ? msg : "";
}

/*!
 * \brief List the names of all registered API functions.
 * \param out_size Receives the number of names.
 * \param out_array Receives a pointer to the table of name pointers, which
 *  lives in the thread-local entry.
 * \return Status code produced by the API_BEGIN/API_END guard.
 */
int TVMListFunctionNames(int *out_size,
                         const char*** out_array) {
  API_BEGIN();
  TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
  ret->ret_vec_str = dmlc::Registry<APIFunctionReg>::ListAllNames();
  ret->ret_vec_charp.clear();
  ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
  for (const std::string& name : ret->ret_vec_str) {
    ret->ret_vec_charp.push_back(name.c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
  API_END();
}

/*!
 * \brief Look up a registered API function by name.
 * \param fname Name of the function.
 * \param out Receives the handle; a CHECK fails when the name is not registered.
 * \return Status code produced by the API_BEGIN/API_END guard.
 */
int TVMGetFunctionHandle(const char* fname,
                         FunctionHandle* out) {
  API_BEGIN();
  const APIFunctionReg* reg =
      dmlc::Registry<APIFunctionReg>::Find(fname);
  CHECK(reg != nullptr) << "cannot find function " << fname;
  *out = (FunctionHandle)reg;
  API_END();
}

/*!
 * \brief Fetch the registered metadata of a function.
 *
 *  The three argument tables are stored back to back in the thread-local
 *  ret_vec_charp vector: names first, then type infos, then descriptions.
 *
 * \param handle Handle of the function.
 * \param real_name Receives the registered name.
 * \param description Receives the description string.
 * \param num_doc_args Receives the number of documented arguments.
 * \param arg_names Receives the table of argument names.
 * \param arg_type_infos Receives the table of argument type strings.
 * \param arg_descriptions Receives the table of argument descriptions.
 * \param return_type When non-null, set to nullptr.
 * \return Status code produced by the API_BEGIN/API_END guard.
 */
int TVMGetFunctionInfo(FunctionHandle handle,
                       const char **real_name,
                       const char **description,
                       int *num_doc_args,
                       const char ***arg_names,
                       const char ***arg_type_infos,
                       const char ***arg_descriptions,
                       const char **return_type) {
  const auto *op = static_cast<const APIFunctionReg *>(handle);
  TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();

  API_BEGIN();
  const size_t num_args = op->arguments.size();
  *real_name = op->name.c_str();
  *description = op->description.c_str();
  *num_doc_args = static_cast<int>(num_args);
  if (return_type) *return_type = nullptr;

  std::vector<const char *>& table = ret->ret_vec_charp;
  table.clear();
  // Reserve once so the three sections are built without reallocation.
  table.reserve(num_args * 3);
  for (const auto& arg : op->arguments) {
    table.push_back(arg.name.c_str());
  }
  for (const auto& arg : op->arguments) {
    table.push_back(arg.type_info_str.c_str());
  }
  for (const auto& arg : op->arguments) {
    table.push_back(arg.description.c_str());
  }
  *arg_names = dmlc::BeginPtr(table);
  *arg_type_infos = dmlc::BeginPtr(table) + num_args;
  *arg_descriptions = dmlc::BeginPtr(table) + (num_args * 2);
  API_END();
}

/*!
 * \brief Push one argument onto the calling thread's argument stack.
 * \param arg The argument payload.
 * \param type_id Type id of the payload. kStr copies the string, kNodeHandle
 *  copies the shared pointer behind the handle, anything else copies the
 *  union as-is.
 * \return Status code produced by the API guard; the error path clears the
 *  argument stack via Clear().
 */
int TVMPushStack(ArgVariant arg,
                 int type_id) {
  TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
  API_BEGIN();
  ret->arg_stack.resize(ret->arg_stack.size() + 1);
  APIVariantValue& v = ret->arg_stack.back();

  v.type_id = static_cast<ArgVariantID>(type_id);
  if (type_id == kStr) {
    // Reject null instead of constructing a string from a null pointer.
    CHECK(arg.v_str != nullptr) << "cannot push a null string argument";
    v.str = arg.v_str;
  } else if (type_id == kNodeHandle) {
    // Reject null instead of dereferencing a null handle.
    CHECK(arg.v_handle != nullptr) << "cannot push a null node handle";
    v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
  } else {
    v.v_union = arg;
  }
  API_END_HANDLE_ERROR(ret->Clear());
}

/*!
 * \brief Call a function with the arguments currently on the argument stack.
 * \param handle Handle of the function to call.
 * \param ret_val Receives the return payload.
 * \param ret_typeid Receives the return type id.
 * \return Status code produced by the API guard; the error path calls Clear().
 */
int TVMFunctionCall(FunctionHandle handle,
                    ArgVariant* ret_val,
                    int* ret_typeid) {
  TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
  API_BEGIN();
  const auto *op = static_cast<const APIFunctionReg *>(handle);
  // Start from a null return (as TVMNodeGetAttr does) so a body that sets no
  // return value does not report the type left over from an earlier call.
  ret->ret_value.type_id = kNull;
  op->body(ret->arg_stack, &(ret->ret_value));
  ret->SetReturn(ret_val, ret_typeid);
  ret->arg_stack.clear();
  API_END_HANDLE_ERROR(ret->Clear());
}

/*!
 * \brief Release a node handle.
 * \param handle The handle; it is deleted as a TVMAPINode.
 * \return Status code produced by the API_BEGIN/API_END guard.
 */
int TVMNodeFree(NodeHandle handle) {
  API_BEGIN();
  delete static_cast<TVMAPINode*>(handle);
  API_END();
}

/*!
 * \brief Read one attribute of a node.
 * \param handle Handle of the node.
 * \param key Attribute name; "type_key" returns the node's type key as kStr.
 * \param ret_val Receives the attribute payload.
 * \param ret_typeid Receives the attribute type id, or kNull when no visited
 *  attribute matches the key.
 * \return Status code produced by the API guard; the error path calls Clear().
 */
int TVMNodeGetAttr(NodeHandle handle,
                   const char* key,
                   ArgVariant* ret_val,
                   int* ret_typeid) {
  TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
  API_BEGIN();
  ret->ret_value.type_id = kNull;
  APIAttrGetter getter;
  getter.skey = key;
  getter.ret = &(ret->ret_value);
  TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
  if (getter.skey == "type_key") {
    ret_val->v_str = (*tnode)->type_key();
    *ret_typeid = kStr;
  } else {
    (*tnode)->VisitAttrs(&getter);
    if (ret->ret_value.type_id != kNull) {
      ret->SetReturn(ret_val, ret_typeid);
    } else {
      // The visitor never matched the key: report a null result.
      *ret_typeid = kNull;
    }
  }
  API_END_HANDLE_ERROR(ret->Clear());
}

/*!
 * \brief List the attribute names of a node.
 * \param handle Handle of the node.
 * \param out_size Receives the number of names.
 * \param out_array Receives a pointer to the table of name pointers, which
 *  lives in the thread-local entry.
 * \return Status code produced by the API_BEGIN/API_END guard.
 */
int TVMNodeListAttrNames(NodeHandle handle,
                         int *out_size,
                         const char*** out_array) {
  TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
  API_BEGIN();
  ret->ret_vec_str.clear();
  TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
  APIAttrDir dir;
  dir.names = &(ret->ret_vec_str);
  (*tnode)->VisitAttrs(&dir);
  ret->ret_vec_charp.clear();
  ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
  for (const std::string& name : ret->ret_vec_str) {
    ret->ret_vec_charp.push_back(name.c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
  API_END();
}

inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val,
                                              int* ret_typeid) {
  APIVariantValue& rv = ret_value;
  *ret_typeid = rv.type_id;
  if (rv.type_id == kNodeHandle) {
    if (rv.sptr.get() != nullptr) {
      // The caller takes ownership of the new handle (released by TVMNodeFree).
      ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
    } else {
      ret_val->v_handle = nullptr;
    }
  } else if (rv.type_id == kStr) {
    // String payloads are kept in `str` (see TVMPushStack), not in the union,
    // so hand out a pointer to that buffer. Clear() leaves `str` untouched;
    // the pointer stays valid until ret_value.str is next modified.
    ret_val->v_str = rv.str.c_str();
  } else {
    *ret_val = rv.v_union;
  }
}