Commit 062bb853 by tqchen

Add in Array, fix most of IR

parent 622cee7a
Subproject commit 872099363b9f16a6cd4a4e8e46b9bd8dd1b861e9
Subproject commit 9070ac3697931ef5aeb8c373c23b2e8a2fec4627
......@@ -6,3 +6,4 @@ from ._ctypes._api import register_node
from . import expr
from . import stmt
from . import make
from . import domain
......@@ -5,7 +5,7 @@ from __future__ import absolute_import as _abs
import ctypes
import sys
from numbers import Number as Number
from numbers import Number, Integral
from .._base import _LIB
from .._base import c_str, py_str, string_types
......@@ -93,6 +93,27 @@ class NodeBase(object):
names.append(py_str(plist[i]))
return names
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, Integral):
dtype = 'int32'
else:
dtype = 'float32'
return _function_internal._const(value, dtype)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, Number):
return const(value)
elif isinstance(value, list):
value = [convert(x) for x in value]
return _function_internal._Array(*value)
else:
if not isinstance(value, NodeBase):
raise ValueError("don't know how to handle type %s" % type(value))
def _push_arg(arg):
a = ArgVariant()
......@@ -147,9 +168,16 @@ def _make_function(handle, name):
doc_str = doc_str % (desc, param_str)
arg_names = [py_str(arg_names[i]) for i in range(num_args.value)]
def func(*args, **kwargs):
def func(*args):
"""TVM function"""
for arg in args:
cargs = []
for x in args:
if isinstance(x, list):
cargs.append(convert(x))
else:
cargs.append(x)
for arg in cargs:
_push_arg(arg)
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
......
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import _function_internal
@register_node
class Array(NodeBase):
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out ot range")
return _function_internal._ArrayGetItem(self, i)
def __len__(self):
return _function_internal._ArraySize(self)
def __repr__(self):
return '[' + (','.join(str(x) for x in self)) + ']'
......@@ -52,6 +52,10 @@ class CmpExpr(Expr):
class LogicalExpr(Expr):
pass
@register_node("Variable")
class Var(Expr):
pass
@register_node
class FloatImm(ConstExpr):
pass
......
......@@ -8,6 +8,7 @@ int32 = "int32"
float32 = "float32"
def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
if isinstance(value, _Integral):
dtype = 'int32'
......@@ -16,12 +17,26 @@ def const(value, dtype=None):
return _function_internal._const(value, dtype)
def _symbol(value):
def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Parameters
----------
name : str
The name
dtype : int
The data type
"""
return _function_internal._Var(name, dtype)
def convert(value):
"""Convert a value to expression."""
if isinstance(value, _Number):
return const(value)
elif isinstance(value, list):
value = [_symbol(x) for x in value]
value = [convert(x) for x in value]
return _function_internal._Array(*value)
else:
return value
......
......@@ -21,6 +21,10 @@ class ProducerConsumer(Stmt):
@register_node
class For(Stmt):
Serial = 0
Parallel = 1
Vectorized = 2
Unrolled = 3
pass
@register_node
......
......@@ -40,9 +40,46 @@ TVM_REGISTER_API(format_str)
os << args.at(0).operator Expr();
} else if (dynamic_cast<const BaseStmtNode*>(sptr.get())) {
os << args.at(0).operator Stmt();
} else {
LOG(FATAL) << "don't know how to print input NodeBaseType";
}
*ret = os.str();
})
.add_argument("expr", "Node", "expression to be printed");
TVM_REGISTER_API(_Array)
.set_body([](const ArgStack& args, RetValue *ret) {
std::vector<std::shared_ptr<Node> > data;
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args.at(i).type_id == kNodeHandle);
data.push_back(args.at(i).sptr);
}
auto node = std::make_shared<ArrayNode>();
node->data = std::move(data);
ret->type_id = kNodeHandle;
ret->sptr = node;
});
TVM_REGISTER_API(_ArrayGetItem)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
int64_t i = args.at(1);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(sptr.get());
CHECK_LT(static_cast<size_t>(i), n->data.size())
<< "out of bound of array";
ret->sptr = n->data[i];
ret->type_id = kNodeHandle;
});
TVM_REGISTER_API(_ArraySize)
.set_body([](const ArgStack& args, RetValue *ret) {
CHECK(args.at(0).type_id == kNodeHandle);
auto& sptr = args.at(0).sptr;
CHECK(sptr->is_type<ArrayNode>());
*ret = static_cast<int64_t>(
static_cast<const ArrayNode*>(sptr.get())->data.size());
});
} // namespace tvm
......@@ -14,6 +14,30 @@ using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
TVM_REGISTER_API(_Var)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Variable::make(args.at(1), args.at(0));
});
TVM_REGISTER_API(_make_For)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = For::make(args.at(0),
args.at(1),
args.at(2),
static_cast<ForType>(args.at(3).operator int()),
static_cast<Halide::DeviceAPI>(args.at(4).operator int()),
args.at(5));
});
TVM_REGISTER_API(_make_Allocate)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Allocate::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
......@@ -67,13 +91,12 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
// TODO(tqchen) Call;
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
// TODO(tqchen) For;
REGISTER_MAKE3(Store);
// TODO(tqchen) Provide;
// TODO(tqchen) Allocate;
REGISTER_MAKE3(Provide);
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
......
......@@ -96,8 +96,10 @@ struct APIVariantValue {
}
inline operator Expr() const {
if (type_id == kNull) return Expr();
if (type_id == kLong) return Expr(operator int64_t());
if (type_id == kDouble) return Expr(operator double());
if (type_id == kLong) return Expr(operator int());
if (type_id == kDouble) {
return Expr(static_cast<float>(operator double()));
}
CHECK_EQ(type_id, kNodeHandle);
return Expr(sptr);
}
......
......@@ -19,7 +19,26 @@ def test_ir():
assert isinstance(stmt, tvm.stmt.Evaluate)
print tvm.format_str(stmt)
def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
c = a + b
assert tvm.format_str(c) == '(%s + %s)' % (a.name, b.name)
def test_array():
a = tvm.convert([1,2,3])
def test_stmt():
print tvm.make.Provide('a', [1,2,3], [1,2,3])
print tvm.make.For('a', 0, 1,
tvm.stmt.For.Serial, 0,
tvm.make.Evaluate(0))
if __name__ == "__main__":
test_const()
test_make()
test_ir()
test_basic()
test_stmt()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment