Commit 9e01367d by Cody Hao Yu Committed by Tianqi Chen

Fix dependency problem of reducer condition (#712) (#721)

* Make duplicated function name checker working

* Fix dependency checking problem for reducer condition (#712); add test

* Fix dependency checking problem for reducer condition (#712); add test

* Specify R to be computed inlined
parent aa55b1a9
...@@ -138,6 +138,7 @@ xcuserdata/ ...@@ -138,6 +138,7 @@ xcuserdata/
*.xcscmblueprint *.xcscmblueprint
.DS_Store .DS_Store
tags tags
cscope*
# vim temporary files # vim temporary files
*.swp *.swp
......
...@@ -134,6 +134,7 @@ DEFINE_BINOP_VISIT_(Or) ...@@ -134,6 +134,7 @@ DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) { void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this); VisitRDom(op->axis, this);
VisitArray(op->source, this); VisitArray(op->source, this);
this->Visit(op->condition);
} }
void IRVisitor::Visit_(const Cast* op) { void IRVisitor::Visit_(const Cast* op) {
......
...@@ -7,8 +7,9 @@ def test_reduce_prims(): ...@@ -7,8 +7,9 @@ def test_reduce_prims():
n = tvm.var('n') n = tvm.var('n')
m = tvm.var('m') m = tvm.var('m')
A = tvm.placeholder((n, m), name='A') A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m)) k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(i>1)), name='B') B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
# schedule # schedule
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
# create iter var and assign them tags. # create iter var and assign them tags.
...@@ -16,6 +17,7 @@ def test_reduce_prims(): ...@@ -16,6 +17,7 @@ def test_reduce_prims():
xo, xi = s[B].split(B.op.axis[0], factor=num_thread) xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x")) s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x")) s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[R].compute_inline()
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment