Commit adf4bfef by tqchen

Enable attribute key in LetStmt

parent 38f03f1f
Subproject commit 24a7c0357a6a8db5db782d320aad7f706ebe8507
Subproject commit 7f1d811972bccc26f651ea2289d88bcadea9fe9f
export CXX=g++
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
......
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import function as _func
from . import make as _make
class Stmt(NodeBase):
def __repr__(self):
return _func.format_str(self)
pass
@register_node
class LetStmt(Stmt):
......
......@@ -56,6 +56,23 @@ TVM_REGISTER_API(_make_Allocate)
args.at(4));
});
TVM_REGISTER_API(_make_LetStmt)
.set_body([](const ArgStack& args, RetValue *ret) {
if (args.size() == 3) {
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2));
} else {
CHECK_EQ(args.size(), 5);
*ret = LetStmt::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
}
});
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API(_make_## Node) \
......@@ -109,7 +126,6 @@ REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
......
......@@ -97,6 +97,9 @@ class APIVariantValue {
inline operator T() const {
if (type_id == kNull) return T();
CHECK_EQ(type_id, kNodeHandle);
// use dynamic RTTI for safety
CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get()))
<< "wrong type specified";
return T(sptr);
}
inline operator Expr() const {
......
......@@ -18,6 +18,15 @@ def test_ir():
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)
def test_let():
x = tvm.Var('x')
y = tvm.Var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1), y, "stride")
assert stmt.attr_of_node == y
print(stmt)
def test_basic():
a = tvm.Var('a')
b = tvm.Var('b')
......@@ -28,10 +37,10 @@ def test_array():
a = tvm.convert([1,2,3])
def test_stmt():
x = tvm.make.Evaluate(0)
tvm.make.For(tvm.Var('i'), 0, 1,
tvm.stmt.For.Serial, 0,
tvm.make.Evaluate(0))
x)
if __name__ == "__main__":
......@@ -40,3 +49,4 @@ if __name__ == "__main__":
test_ir()
test_basic()
test_stmt()
test_let()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment