Commit 3e693f53 by tqchen

C++ API work with python

parent 5079987e
...@@ -5,7 +5,7 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ ...@@ -5,7 +5,7 @@ export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
# specify tensor path # specify tensor path
.PHONY: clean all test doc .PHONY: clean all test doc
all: lib/libtvm.a all: lib/libtvm.a lib/libtvm.so
SRC = $(wildcard src/*.cc src/*/*.cc) SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) ALL_DEP = $(ALL_OBJ)
...@@ -24,6 +24,10 @@ lib/libtvm.a: $(ALL_DEP) ...@@ -24,6 +24,10 @@ lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D) @mkdir -p $(@D)
ar crv $@ $(filter %.o, $?) ar crv $@ $(filter %.o, $?)
lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lint: lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src python2 dmlc-core/scripts/lint.py tvm cpp include src
......
...@@ -126,6 +126,7 @@ class NodeRef { ...@@ -126,6 +126,7 @@ class NodeRef {
protected: protected:
template<typename T, typename> template<typename T, typename>
friend class Array; friend class Array;
friend class APIVariantValue;
NodeRef() = default; NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {} explicit NodeRef(std::shared_ptr<Node>&& node) : node_(std::move(node)) {}
/*! \brief the internal node */ /*! \brief the internal node */
...@@ -136,7 +137,7 @@ class NodeRef { ...@@ -136,7 +137,7 @@ class NodeRef {
using NodeFactory = std::function<std::shared_ptr<Node> ()>; using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*! /*!
* \brief Registry entry for DataIterator factory functions. * \brief Registry entry for NodeFactory
*/ */
struct NodeFactoryReg struct NodeFactoryReg
: public dmlc::FunctionRegEntryBase<NodeFactoryReg, : public dmlc::FunctionRegEntryBase<NodeFactoryReg,
......
...@@ -23,27 +23,117 @@ ...@@ -23,27 +23,117 @@
#define TVM_DLL TVM_EXTERN_C #define TVM_DLL TVM_EXTERN_C
#endif #endif
/*! \brief handle to node creator */ /*! \brief handle to functions */
typedef void* NodeCreatorHandle; typedef void* FunctionHandle;
/*! \brief handle to node */ /*! \brief handle to node */
typedef void* NodeHandle; typedef void* NodeHandle;
TVM_DLL int TVMNodeCreatorGet(const char* node_type, /*!
NodeCreatorHandle *handle); * \brief union type for returning value of attributes
* Attribute type can be identified by id
*/
typedef union {
long v_long; // NOLINT(*)
double v_double;
const char* v_str;
NodeHandle v_handle;
} ArgVariant;
/*! \brief attribute types */
typedef enum {
kNull = 0,
kLong = 1,
kDouble = 2,
kStr = 3,
kNodeHandle = 4
} ArgVariantID;
/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
* and -1 when an error occured,
* NNGetLastError can be called to retrieve the error
*
* this function is threadsafe and can be called by different thread
* \return error info
*/
TVM_DLL const char *TVMGetLastError(void);
TVM_DLL int TVMNodeCreate(NodeCreatorHandle handle, /*!
int num_child_ref, * \brief List all the node function name
const char* child_ref_keys, * \param out_size The number of functions
NodeHandle* child_node_refs, * \param out_array The array of function names.
int num_attrs, */
const char* attr_keys, TVM_DLL int TVMListFunctionNames(int *out_size,
const char* attr_vals, const char*** out_array);
NodeHandle* handle); /*!
* \brief get function handle by name
* \param name The name of function
* \param handle The returning function handle
*/
TVM_DLL int TVMGetFunctionHandle(const char* name,
FunctionHandle *handle);
TVM_DLL int TVMNodeGetAttr(const char* key, /*!
const char** value); * \brief Get the detailed information about function.
* \param handle The operator handle.
* \param real_name The returned name of the function.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetFunctionInfo(FunctionHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);
TVM_DLL int TVMNodeGetChildNodeRef(const char* key, /*!
NodeHandle* out); * \brief Push an argument to the function calling stack.
* If push fails, the stack will be reset to empty
*
* \param arg number of attributes
* \param type_id The typeid of attributes.
*/
TVM_DLL int TVMPushStack(ArgVariant arg,
int type_id);
/*!
* \brief call a function by using arguments in the stack.
* The stack will be cleanup to empty after this call, whether the call is successful.
*
* \param handle The function handle
* \param ret_val The return value.
* \param ret_typeid the type id of return value.
*/
TVM_DLL int TVMFunctionCall(FunctionHandle handle,
ArgVariant* ret_val,
int* ret_typeid);
/*!
* \brief free the node handle
* \param handle The node handle to be freed.
*/
TVM_DLL int TVMNodeFree(NodeHandle handle);
/*!
* \brief get attributes given key
* \param handle The node handle
* \param key The attribute name
* \param out_value The attribute value
* \param out_typeid The typeif of the attribute.
*/
TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
const char* key,
ArgVariant* out_value,
int* out_typeid);
#endif // TVM_C_API_H_ #endif // TVM_C_API_H_
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifndef TVM_TENSOR_H_ #ifndef TVM_TENSOR_H_
#define TVM_TENSOR_H_ #define TVM_TENSOR_H_
#include <string>
#include "./expr.h" #include "./expr.h"
#include "./array.h" #include "./array.h"
...@@ -43,9 +44,7 @@ class TensorNode : public Node { ...@@ -43,9 +44,7 @@ class TensorNode : public Node {
class Tensor : public NodeRef { class Tensor : public NodeRef {
public: public:
Tensor(Array<Expr> shape); explicit Tensor(Array<Expr> shape);
Tensor(Array<Expr> shape, std::function<Expr (Var, Var, Var)> f3) {
}
inline size_t ndim() const; inline size_t ndim() const;
template<typename... Args> template<typename... Args>
......
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node
from . import expr
# coding: utf-8
# pylint: disable=invalid-name
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import
import sys
import os
import ctypes
import numpy as np
from . import libinfo
__all__ = ['TVMError']
#----------------------------
# library loading
#----------------------------
if sys.version_info[0] == 3:
string_types = str,
numeric_types = (float, int, np.float32, np.int32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = basestring,
numeric_types = (float, int, long, np.float32, np.int32)
py_str = lambda x: x
class TVMError(Exception):
"""Error that will be throwed by all functions"""
pass
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
# DMatrix functions
lib.TVMGetLastError.restype = ctypes.c_char_p
return lib
# version number
__version__ = libinfo.__version__
# library instance of nnvm
_LIB = _load_lib()
# type definitions
FunctionHandle = ctypes.c_void_p
NodeHandle = ctypes.c_void_p
#----------------------------
# helper function definition
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise TVMError(py_str(_LIB.TVMGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
----------
string : string type
python string
Returns
-------
str : c_char_p
A char pointer that can be passed to C API
"""
return ctypes.c_char_p(string.encode('utf-8'))
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : nn_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
import ctypes
import sys
from numbers import Number as Number
from .._base import _LIB
from .._base import c_str, py_str, string_types
from .._base import FunctionHandle, NodeHandle
from .._base import check_call, ctypes2docstring
class ArgVariant(ctypes.Union):
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
("v_handle", ctypes.c_void_p)]
kNull = 0
kLong = 1
kDouble = 2
kStr = 3
kNodeHandle = 4
RET_SWITCH = None
class NodeBase(object):
"""Symbol is symbolic graph."""
__slots__ = ["handle"]
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
self.handle = handle
def __del__(self):
check_call(_LIB.TVMNodeFree(self.handle))
def __getattr__(self, name):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val)
def _type_key(handle):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return py_str(ret_val.v_str)
NODE_TYPE = {
}
RET_SWITCH = {
kNull: lambda x: None,
kLong: lambda x: x.v_long.value,
kDouble: lambda x: x.v_double.value,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle)
}
def _push_arg(arg):
a = ArgVariant()
if arg is None:
_LIB.TVMPushStack(a, ctypes.c_int(kNull))
elif isinstance(arg, NodeBase):
a.v_handle = arg.handle
_LIB.TVMPushStack(a, ctypes.c_int(kNodeHandle))
elif isinstance(arg, int):
a.v_long = ctypes.c_long(arg)
_LIB.TVMPushStack(a, ctypes.c_int(kLong))
elif isinstance(arg, Number):
a.v_double = ctypes.c_double(arg)
_LIB.TVMPushStack(a, ctypes.c_int(kDouble))
elif isinstance(arg, string_types):
a.v_str = c_str(arg)
_LIB.TVMPushStack(a, ctypes.c_int(kStr))
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
def _make_function(handle, name):
"""Create an atomic symbol function by handle and funciton name."""
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
num_args = ctypes.c_int()
arg_names = ctypes.POINTER(ctypes.c_char_p)()
arg_types = ctypes.POINTER(ctypes.c_char_p)()
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
ret_type = ctypes.c_char_p()
check_call(_LIB.TVMGetFunctionInfo(
handle, ctypes.byref(real_name), ctypes.byref(desc),
ctypes.byref(num_args),
ctypes.byref(arg_names),
ctypes.byref(arg_types),
ctypes.byref(arg_descs),
ctypes.byref(ret_type)))
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)
func_name = name
desc = py_str(desc.value)
doc_str = ('%s\n\n' +
'%s\n' +
'name : string, optional.\n' +
' Name of the resulting symbol.\n\n' +
'Returns\n' +
'-------\n' +
'symbol: Symbol\n' +
' The result symbol.')
doc_str = doc_str % (desc, param_str)
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
def func(*args, **kwargs):
"""TVM function"""
for arg in args:
_push_arg(arg)
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMFunctionCall(
handle, ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return RET_SWITCH[ret_typeid.value](ret_val)
func.__name__ = func_name
func.__doc__ = doc_str
return func
def register_node(type_key):
"""register node type
Parameters
----------
type_key : str
The type key of the node
"""
def register(cls):
NODE_TYPE[type_key] = cls
return register
def _init_function_module(root_namespace):
"""List and add all the functions to current module."""
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.TVMListFunctionNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
op_names.append(py_str(plist[i]))
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
for name in op_names:
hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
function = _make_function(hdl, name)
if function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
"""namespace of internal function"""
from ._ctypes._api import NodeBase, register_node
class Expr(NodeBase):
pass
@register_node("VarNode")
class Var(Expr):
pass
@register_node("BinaryOpNode")
class BinaryOpNode(Expr):
pass
from ._ctypes._api import _init_function_module
import _function_internal
_init_function_module("tvm.cpp")
# coding: utf-8
"""Information about nnvm."""
from __future__ import absolute_import
import os
import platform
def find_lib_path():
"""Find dynamic library files.
Returns
-------
lib_path : list(string)
List of all found path to the libraries
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path]
if os.name == 'nt':
vs_configuration = 'Release'
if platform.architecture()[0] == '64bit':
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows/x64', vs_configuration))
else:
dll_path.append(os.path.join(curr_path, '../../build', vs_configuration))
dll_path.append(os.path.join(curr_path, '../../windows', vs_configuration))
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt':
dll_path = [os.path.join(p, 'libtvm.dll') for p in dll_path]
else:
dll_path = [os.path.join(p, 'libtvm.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path)))
return lib_path
# current version
__version__ = "0.1.0"
/*!
* Copyright (c) 2016 by Contributors
* Implementation of C API
* \file c_api.cc
*/
#include <tvm/c_api.h>
#include <tvm/op.h>
#include "./c_api_common.h"
#include "./c_api_registry.h"
/*! \brief entry to to easily hold returning information */
struct TVMAPIThreadLocalEntry {
/*! \brief hold last error */
std::string last_error;
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief argument stack */
std::vector<tvm::APIVariantValue> arg_stack;
/*! \brief return value */
tvm::APIVariantValue ret_value;
// clear calling stack
inline void Clear() {
arg_stack.clear();
ret_value.sptr.reset();
}
inline void SetReturn(ArgVariant* ret_val, int* ret_typeid);
};
using namespace tvm;
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<TVMAPIThreadLocalEntry> TVMAPIThreadLocalStore;
using TVMAPINode = std::shared_ptr<Node>;
struct APIAttrGetter : public AttrVisitor {
std::string skey;
APIVariantValue* ret;
void Visit(const char* key, double* value) override {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, int64_t* value) override {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, DataType* value) override {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, std::string* value) override {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, const UnaryOp** value) override {
if (skey == key) *ret = value[0]->FunctionName();
}
void Visit(const char* key, const BinaryOp** value) override {
if (skey == key) *ret = value[0]->FunctionName();
}
};
const char *TVMGetLastError() {
return TVMAPIThreadLocalStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
TVMAPIThreadLocalStore::Get()->last_error = msg;
}
int TVMListFunctionNames(int *out_size,
const char*** out_array) {
API_BEGIN();
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
ret->ret_vec_str = dmlc::Registry<APIFunctionReg>::ListAllNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}
int TVMGetFunctionHandle(const char* fname,
FunctionHandle* out) {
API_BEGIN();
const APIFunctionReg* reg = dmlc::Registry<APIFunctionReg>::Find(fname);
CHECK(reg != nullptr) << "cannot find function " << fname;
*out = (FunctionHandle)reg;
API_END();
}
int TVMGetFunctionInfo(FunctionHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type) {
const auto *op = static_cast<const APIFunctionReg *>(handle);
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
*real_name = op->name.c_str();
*description = op->description.c_str();
*num_doc_args = static_cast<int>(op->arguments.size());
if (return_type) *return_type = nullptr;
ret->ret_vec_charp.clear();
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].name.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].type_info_str.c_str());
}
for (size_t i = 0; i < op->arguments.size(); ++i) {
ret->ret_vec_charp.push_back(op->arguments[i].description.c_str());
}
*arg_names = dmlc::BeginPtr(ret->ret_vec_charp);
*arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + op->arguments.size();
*arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (op->arguments.size() * 2);
API_END();
}
int TVMPushStack(ArgVariant arg,
int type_id) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back();
v.type_id = static_cast<ArgVariantID>(type_id);
if (type_id == kStr) {
v = arg.v_str;
} else if (type_id == kNodeHandle) {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
} else {
v.v_union = arg;
}
API_END_HANDLE_ERROR(ret->Clear());
}
int TVMFunctionCall(FunctionHandle handle,
ArgVariant* ret_val,
int* ret_typeid) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
const auto *op = static_cast<const APIFunctionReg *>(handle);
op->body(ret->arg_stack, &(ret->ret_value));
ret->SetReturn(ret_val, ret_typeid);
ret->arg_stack.clear();
API_END_HANDLE_ERROR(ret->Clear());
}
int TVMNodeFree(NodeHandle handle) {
API_BEGIN();
delete static_cast<TVMAPINode*>(handle);
API_END();
}
int TVMNodeGetAttr(NodeHandle handle,
const char* key,
ArgVariant* ret_val,
int* ret_typeid) {
TVMAPIThreadLocalEntry *ret = TVMAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_value.type_id = kNull;
APIAttrGetter getter;
getter.skey = key;
getter.ret = &(ret->ret_value);
TVMAPINode* tnode = static_cast<TVMAPINode*>(handle);
if (getter.skey == "type_key") {
ret_val->v_str = (*tnode)->type_key();
*ret_typeid = kStr;
} else {
(*tnode)->VisitAttrs(&getter);
if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid);
} else {
const std::string& skey = getter.skey;
(*tnode)->VisitNodeRefFields([&skey, ret](const char* key, NodeRef* ref) {
if (key == skey) ret->ret_value = *ref;
});
if (ret->ret_value.type_id != kNull) {
ret->SetReturn(ret_val, ret_typeid);
} else {
*ret_typeid = kNull;
}
}
}
API_END_HANDLE_ERROR(ret->Clear());
}
inline void TVMAPIThreadLocalEntry::SetReturn(ArgVariant* ret_val,
int* ret_typeid) {
APIVariantValue& rv = ret_value;
*ret_typeid = rv.type_id;
if (rv.type_id == kNodeHandle) {
if (rv.sptr.get() != nullptr) {
ret_val->v_handle = new TVMAPINode(std::move(rv.sptr));
} else {
ret_val->v_handle = nullptr;
}
} else {
*ret_val = rv.v_union;
}
}
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_common.h
* \brief Common fields of all C APIs
*/
#ifndef TVM_C_API_C_API_COMMON_H_
#define TVM_C_API_C_API_COMMON_H_
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/c_api.h>
#include <vector>
#include <string>
#include "./c_api_registry.h"
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(dmlc::Error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR
* The finally clause contains procedure to cleanup states when an error happens.
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void TVMAPISetLastError(const char* msg);
/*!
* \brief handle exception throwed out
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int TVMAPIHandleException(const dmlc::Error &e) {
TVMAPISetLastError(e.what());
return -1;
}
#endif // TVM_C_API_C_API_COMMON_H_
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions
* \file c_api_impl.cc
*/
#include <tvm/expr.h>
#include <tvm/op.h>
#include "./c_api_registry.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::APIFunctionReg);
} // namespace dmlc
namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Var(args.at(0), static_cast<DataType>(static_cast<int>(args.at(1))));
})
.add_argument("name", "str", "name of the var")
.add_argument("dtype", "int", "data type of var");
TVM_REGISTER_API(max)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = max(args.at(0), args.at(1));
})
.add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand");
TVM_REGISTER_API(min)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = min(args.at(0), args.at(1));
})
.add_argument("lhs", "Expr", "left operand")
.add_argument("rhs", "Expr", "right operand");
TVM_REGISTER_API(format_str)
.set_body([](const ArgStack& args, RetValue *ret) {
std::ostringstream os;
os << Expr(args.at(0));
*ret = os.str();
})
.add_argument("expr", "Expr", "expression to be printed");
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file c_api_registry.h
* \brief Quick registry for C API.
*/
#ifndef TVM_C_API_C_API_REGISTRY_H_
#define TVM_C_API_C_API_REGISTRY_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/c_api.h>
#include <memory>
#include <string>
#include <vector>
namespace tvm {
/*! \brief Variant container for API calls */
struct APIVariantValue {
/*! \brief the type id */
ArgVariantID type_id{kNull};
/*! \brief shared pointer container */
std::shared_ptr<Node> sptr;
/*! \brief string container */
std::string str;
/*! \brief the variant holder */
ArgVariant v_union;
// constructor
APIVariantValue() {}
// clear value
inline void Clear() {
}
// assign op
inline APIVariantValue& operator=(double value) {
type_id = kDouble;
v_union.v_double = value;
return *this;
}
inline APIVariantValue& operator=(std::nullptr_t value) {
type_id = kNull;
return *this;
}
inline APIVariantValue& operator=(int64_t value) {
type_id = kLong;
v_union.v_long = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) {
type_id = kStr;
str = std::move(value);
v_union.v_str = str.c_str();
return *this;
}
inline APIVariantValue& operator=(const NodeRef& ref) {
type_id = kNodeHandle;
this->sptr = ref.node_;
return *this;
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
std::shared_ptr<Node> x = sptr;
return T(std::move(x));
}
inline operator double() const {
CHECK_EQ(type_id, kDouble);
return v_union.v_double;
}
inline operator int64_t() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
}
inline operator int() const {
CHECK_EQ(type_id, kLong);
return v_union.v_long;
}
inline operator std::string() const {
CHECK_EQ(type_id, kStr);
return str;
}
};
// common defintiion of API function.
using APIFunction = std::function<
void(const std::vector<APIVariantValue> &args, APIVariantValue* ret)>;
/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct APIFunctionReg
: public dmlc::FunctionRegEntryBase<APIFunctionReg,
APIFunction> {
};
#define TVM_REGISTER_API(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::APIFunctionReg, APIFunctionReg, TypeName) \
} // namespace tvm
#endif // TVM_C_API_C_API_REGISTRY_H_
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
#include <tvm/expr_node.h> #include <tvm/expr_node.h>
#include <memory> #include <memory>
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
} // namespace dmlc
namespace tvm { namespace tvm {
TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(VarNode);
......
from tvm import cpp as tvm
def test_basic():
a = tvm.Var('a', 0)
b = tvm.Var('b', 0)
z = tvm.max(a, b)
assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name)
if __name__ == "__main__":
test_basic()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment