Commit 1a7fb9f9 by tqchen

fix c api

parent 35277c2f
......@@ -24,9 +24,9 @@ class BinaryOp;
/*! \brief list of all supported data types */
enum DataType {
kUnknown,
kInt32,
kFloat32
kUnknown = 0,
kInt32 = 1,
kFloat32 = 2
};
/*!
......
......@@ -27,6 +27,7 @@ class VarNode : public ExprNode {
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype_);
}
};
......
......@@ -25,7 +25,26 @@ kDouble = 2
kStr = 3
kNodeHandle = 4
RET_SWITCH = None
def _type_key(handle):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return py_str(ret_val.v_str)
NODE_TYPE = {
}
RET_SWITCH = {
kNull: lambda x: None,
kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle)
}
class NodeBase(object):
"""Symbol is symbolic graph."""
......@@ -50,27 +69,8 @@ class NodeBase(object):
check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
ret = RET_SWITCH[ret_typeid.value](ret_val)
def _type_key(handle):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return py_str(ret_val.v_str)
NODE_TYPE = {
}
return RET_SWITCH[ret_typeid.value](ret_val)
RET_SWITCH = {
kNull: lambda x: None,
kLong: lambda x: x.v_long.value,
kDouble: lambda x: x.v_double.value,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle)
}
def _push_arg(arg):
a = ArgVariant()
......
from ._ctypes._api import _init_function_module
import _function_internal
int32 = 1
float32 = 2
def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : int
The data type
"""
return _function_internal._Var(name, dtype)
_init_function_module("tvm.cpp")
......@@ -16,9 +16,10 @@ namespace tvm {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(Var)
TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Var(args.at(0), static_cast<DataType>(static_cast<int>(args.at(1))));
*ret = Var(args.at(0),
static_cast<DataType>(static_cast<int>(args.at(1))));
})
.add_argument("name", "str", "name of the var")
.add_argument("dtype", "int", "data type of var");
......
from tvm import cpp as tvm
def test_basic():
a = tvm.Var('a', 0)
b = tvm.Var('b', 0)
a = tvm.Var('a')
b = tvm.Var('b')
z = tvm.max(a, b)
assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment