Commit 062bb853 by tqchen

Add in Array, fix most of IR

parent 622cee7a
Subproject commit 872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9 Subproject commit 9070ac3697931ef5aeb8c373c23b2e8a2fec4627
...@@ -6,3 +6,4 @@ from ._ctypes._api import register_node ...@@ -6,3 +6,4 @@ from ._ctypes._api import register_node
from . import expr from . import expr
from . import stmt from . import stmt
from . import make from . import make
from . import domain
...@@ -5,7 +5,7 @@ from __future__ import absolute_import as _abs ...@@ -5,7 +5,7 @@ from __future__ import absolute_import as _abs
import ctypes import ctypes
import sys import sys
from numbers import Number as Number from numbers import Number, Integral
from .._base import _LIB from .._base import _LIB
from .._base import c_str, py_str, string_types from .._base import c_str, py_str, string_types
...@@ -93,6 +93,27 @@ class NodeBase(object): ...@@ -93,6 +93,27 @@ class NodeBase(object):
names.append(py_str(plist[i])) names.append(py_str(plist[i]))
return names return names
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _function_internal._const(value, dtype)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, Number):
return const(value)
elif isinstance(value, list):
value = [convert(x) for x in value]
return _function_internal._Array(*value)
else:
if not isinstance(value, NodeBase):
raise ValueError("don't know how to handle type %s" % type(value))
def _push_arg(arg): def _push_arg(arg):
a = ArgVariant() a = ArgVariant()
...@@ -147,9 +168,16 @@ def _make_function(handle, name): ...@@ -147,9 +168,16 @@ def _make_function(handle, name):
doc_str = doc_str % (desc, param_str) doc_str = doc_str % (desc, param_str)
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)] arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
def func(*args, **kwargs): def func(*args):
"""TVM function""" """TVM function"""
for arg in args: cargs = []
for x in args:
if isinstance(x, list):
cargs.append(convert(x))
else:
cargs.append(x)
for arg in cargs:
_push_arg(arg) _push_arg(arg)
ret_val = ArgVariant() ret_val = ArgVariant()
ret_typeid = ctypes.c_int() ret_typeid = ctypes.c_int()
......
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node
class Array(NodeBase):
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out ot range")
return _function_internal._ArrayGetItem(self, i)
def __len__(self):
return _function_internal._ArraySize(self)
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
...@@ -52,6 +52,10 @@ class CmpExpr(Expr): ...@@ -52,6 +52,10 @@ class CmpExpr(Expr):
class LogicalExpr(Expr): class LogicalExpr(Expr):
pass pass
@register_node("Variable")
class Var(Expr):
pass
@register_node @register_node
class FloatImm(ConstExpr): class FloatImm(ConstExpr):
pass pass
......
...@@ -8,6 +8,7 @@ int32 = "int32" ...@@ -8,6 +8,7 @@ int32 = "int32"
float32 = "float32" float32 = "float32"
def const(value, dtype=None): def const(value, dtype=None):
"""construct a constant"""
if dtype is None: if dtype is None:
if isinstance(value, _Integral): if isinstance(value, _Integral):
dtype = 'int32' dtype = 'int32'
...@@ -16,12 +17,26 @@ def const(value, dtype=None): ...@@ -16,12 +17,26 @@ def const(value, dtype=None):
return _function_internal._const(value, dtype) return _function_internal._const(value, dtype)
def _symbol(value): def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : int
The data type
"""
return _function_internal._Var(name, dtype)
def convert(value):
"""Convert a value to expression.""" """Convert a value to expression."""
if isinstance(value, _Number): if isinstance(value, _Number):
return const(value) return const(value)
elif isinstance(value, list): elif isinstance(value, list):
value = [_symbol(x) for x in value] value = [convert(x) for x in value]
return _function_internal._Array(*value) return _function_internal._Array(*value)
else: else:
return value return value
......
...@@ -21,6 +21,10 @@ class ProducerConsumer(Stmt): ...@@ -21,6 +21,10 @@ class ProducerConsumer(Stmt):
@register_node @register_node
class For(Stmt): class For(Stmt):
Serial = 0
Parallel = 1
Vectorized = 2
Unrolled = 3
pass pass
@register_node @register_node
......
...@@ -40,9 +40,46 @@ TVM_REGISTER_API(format_str) ...@@ -40,9 +40,46 @@ TVM_REGISTER_API(format_str)
os << args.at(0).operator Expr(); os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) { } else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt(); os << args.at(0).operator Stmt();
} else {
LOG(FATAL) << "don't know how to print input NodeBaseType";
} }
*ret = os.str(); *ret = os.str();
}) })
.add_argument("expr", "Node", "expression to be printed"); .add_argument("expr", "Node", "expression to be printed");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
} // namespace tvm } // namespace tvm
...@@ -14,6 +14,30 @@ using namespace Halide::Internal; ...@@ -14,6 +14,30 @@ using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Variable::make(args.at(1), args.at(0));
});
TVM_REGISTER_API(_make_For)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = For::make(args.at(0),
args.at(1),
args.at(2),
static_cast<ForType>(args.at(3).operator int()),
static_cast<Halide::DeviceAPI>(args.at(4).operator int()),
args.at(5));
});
TVM_REGISTER_API(_make_Allocate)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Allocate::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
});
// make from two arguments // make from two arguments
#define REGISTER_MAKE1(Node) \ #define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \ TVM_REGISTER_API(_make_## Node) \
...@@ -67,13 +91,12 @@ REGISTER_MAKE3(Select); ...@@ -67,13 +91,12 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp); REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast); REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let); REGISTER_MAKE3(Let);
// TODO(tqchen) Call;
REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt); REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(ProducerConsumer);
// TODO(tqchen) For;
REGISTER_MAKE3(Store); REGISTER_MAKE3(Store);
// TODO(tqchen) Provide; REGISTER_MAKE3(Provide);
// TODO(tqchen) Allocate;
REGISTER_MAKE1(Free); REGISTER_MAKE1(Free);
// TODO(tqchen) Realize; // TODO(tqchen) Realize;
REGISTER_MAKE2(Block); REGISTER_MAKE2(Block);
......
...@@ -96,8 +96,10 @@ struct APIVariantValue { ...@@ -96,8 +96,10 @@ struct APIVariantValue {
} }
inline operator Expr() const { inline operator Expr() const {
if (type_id == kNull) return Expr(); if (type_id == kNull) return Expr();
if (type_id == kLong) return Expr(operator int64_t()); if (type_id == kLong) return Expr(operator int());
if (type_id == kDouble) return Expr(operator double()); if (type_id == kDouble) {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_id, kNodeHandle); CHECK_EQ(type_id, kNodeHandle);
return Expr(sptr); return Expr(sptr);
} }
......
...@@ -19,7 +19,26 @@ def test_ir(): ...@@ -19,7 +19,26 @@ def test_ir():
assert isinstance(stmt, tvm.stmt.Evaluate) assert isinstance(stmt, tvm.stmt.Evaluate)
print tvm.format_str(stmt) print tvm.format_str(stmt)
def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.convert([1,2,3])
def test_stmt():
print tvm.make.Provide('a', [1,2,3], [1,2,3])
print tvm.make.For('a', 0, 1,
tvm.stmt.For.Serial, 0,
tvm.make.Evaluate(0))
if __name__ == "__main__": if __name__ == "__main__":
test_const() test_const()
test_make() test_make()
test_ir() test_ir()
test_basic()
test_stmt()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment