Commit 9595a9c1 by tqchen

Expose array to python

parent de2be97e
......@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .function import *
from ._ctypes._api import register_node
from . import expr
from . import domain
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node("RangeNode")
class Range(NodeBase):
pass
@register_node("ArrayNode")
class Array(NodeBase):
def __getitem__(self, i):
return _function_internal._ArrayGetItem(self, i)
def __len__(self):
return _function_internal._ArraySize(self)
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from .function import binary_op
......@@ -40,6 +41,26 @@ class Expr(NodeBase):
class Var(Expr):
pass
@register_node("IntNode")
class IntExpr(Expr):
pass
@register_node("FloatNode")
class FloatExpr(Expr):
pass
@register_node("UnaryOpNode")
class UnaryOpExpr(Expr):
pass
@register_node("BinaryOpNode")
class BinaryOpExpr(Expr):
pass
@register_node("ReduceNode")
class ReduceExpr(Expr):
pass
@register_node("TensorReadNode")
class TensorReadExpr(Expr):
pass
......@@ -24,6 +24,9 @@ def _symbol(value):
"""Convert a value to expression."""
if isinstance(value, _Number):
return constant(value)
elif isinstance(value, list):
value = [_symbol(x) for x in value]
return _function_internal._Array(*value)
else:
return value
......
......@@ -61,6 +61,41 @@ TVM_REGISTER_API(Range)
.add_argument("begin", "Expr", "beginning of the range.")
.add_argument("end", "Expr", "end of the range");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
TVM_REGISTER_API(_TensorInput)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Tensor(
......
......@@ -57,7 +57,7 @@ struct APIVariantValue {
return *this;
}
template<typename T,
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type >
typename = typename std::enable_if<std::is_base_of<NodeRef, T>::value>::type>
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
......
......@@ -9,5 +9,15 @@ def test_basic():
assert c.dtype == tvm.int32
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.Var('a')
x = tvm.function._symbol([1,2,a])
print type(x)
print len(x)
print x[4]
if __name__ == "__main__":
test_basic()
test_array()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment