Commit 9e01367d by Cody Hao Yu Committed by Tianqi Chen

Fix dependency problem of reducer condition (#712) (#721)

* Make duplicated function name checker working

* Fix dependency checking problem for reducer condition (#712); add test

* Fix dependency checking problem for reducer condition (#712); add test

* Specify R to be computed inlined
parent aa55b1a9
......@@ -138,6 +138,7 @@ xcuserdata/
*.xcscmblueprint
.DS_Store
tags
cscope*
# vim temporary files
*.swp
......
......@@ -134,6 +134,7 @@ DEFINE_BINOP_VISIT_(Or)
void IRVisitor::Visit_(const Reduce* op) {
VisitRDom(op->axis, this);
VisitArray(op->source, this);
this->Visit(op->condition);
}
void IRVisitor::Visit_(const Cast* op) {
......
......@@ -7,8 +7,9 @@ def test_reduce_prims():
n = tvm.var('n')
m = tvm.var('m')
A = tvm.placeholder((n, m), name='A')
R = tvm.compute((n, ), lambda i: tvm.select((i > 1), 1, 0), name='R')
k = tvm.reduce_axis((0, m))
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(i>1)), name='B')
B = tvm.compute((n,), lambda i: reducer(A[i, k], axis=k, where=(R[i]==1)), name='B')
# schedule
s = tvm.create_schedule(B.op)
# create iter var and assign them tags.
......@@ -16,6 +17,7 @@ def test_reduce_prims():
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
s[R].compute_inline()
# one line to build the function.
def check_device(device, host="stackvm"):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment