Commit 1a7fb9f9 by tqchen

fix c api

parent 35277c2f
...@@ -24,9 +24,9 @@ class BinaryOp; ...@@ -24,9 +24,9 @@ class BinaryOp;
/*! \brief list of all supported data types */ /*! \brief list of all supported data types */
enum DataType { enum DataType {
kUnknown, kUnknown = 0,
kInt32, kInt32 = 1,
kFloat32 kFloat32 = 2
}; };
/*! /*!
......
...@@ -27,6 +27,7 @@ class VarNode : public ExprNode { ...@@ -27,6 +27,7 @@ class VarNode : public ExprNode {
} }
void VisitAttrs(AttrVisitor* visitor) override { void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("name", &name); visitor->Visit("name", &name);
visitor->Visit("dtype", &dtype_);
} }
}; };
......
...@@ -25,7 +25,26 @@ kDouble = 2 ...@@ -25,7 +25,26 @@ kDouble = 2
kStr = 3 kStr = 3
kNodeHandle = 4 kNodeHandle = 4
RET_SWITCH = None
def _type_key(handle):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return py_str(ret_val.v_str)
NODE_TYPE = {
}
RET_SWITCH = {
kNull: lambda x: None,
kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle)
}
class NodeBase(object): class NodeBase(object):
"""Symbol is symbolic graph.""" """Symbol is symbolic graph."""
...@@ -50,27 +69,8 @@ class NodeBase(object): ...@@ -50,27 +69,8 @@ class NodeBase(object):
check_call(_LIB.TVMNodeGetAttr( check_call(_LIB.TVMNodeGetAttr(
self.handle, c_str(name), self.handle, c_str(name),
ctypes.byref(ret_val), ctypes.byref(ret_typeid))) ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
ret = RET_SWITCH[ret_typeid.value](ret_val) return RET_SWITCH[ret_typeid.value](ret_val)
def _type_key(handle):
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
check_call(_LIB.TVMNodeGetAttr(
handle, c_str("type_key"),
ctypes.byref(ret_val), ctypes.byref(ret_typeid)))
return py_str(ret_val.v_str)
NODE_TYPE = {
}
RET_SWITCH = {
kNull: lambda x: None,
kLong: lambda x: x.v_long.value,
kDouble: lambda x: x.v_double.value,
kStr: lambda x: py_str(x.v_str),
kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle)
}
def _push_arg(arg): def _push_arg(arg):
a = ArgVariant() a = ArgVariant()
......
from ._ctypes._api import _init_function_module from ._ctypes._api import _init_function_module
import _function_internal import _function_internal
int32 = 1
float32 = 2
def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : int
The data type
"""
return _function_internal._Var(name, dtype)
_init_function_module("tvm.cpp") _init_function_module("tvm.cpp")
...@@ -16,9 +16,10 @@ namespace tvm { ...@@ -16,9 +16,10 @@ namespace tvm {
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
TVM_REGISTER_API(Var) TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
*ret = Var(args.at(0), static_cast<DataType>(static_cast<int>(args.at(1)))); *ret = Var(args.at(0),
static_cast<DataType>(static_cast<int>(args.at(1))));
}) })
.add_argument("name", "str", "name of the var") .add_argument("name", "str", "name of the var")
.add_argument("dtype", "int", "data type of var"); .add_argument("dtype", "int", "data type of var");
......
from tvm import cpp as tvm from tvm import cpp as tvm
def test_basic(): def test_basic():
a = tvm.Var('a', 0) a = tvm.Var('a')
b = tvm.Var('b', 0) b = tvm.Var('b')
z = tvm.max(a, b) z = tvm.max(a, b)
assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name) assert tvm.format_str(z) == 'max(%s, %s)' % (a.name, b.name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment