Commit 7e6cba48 by Jessica Davies Committed by Tianqi Chen

Tighten buffer bound for TensorComputeOp by improving EvalSet on ranges (#2565)

parent ec3a4251
......@@ -573,12 +573,15 @@ IntSet EvalSet(Expr e,
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
IntSetEvaluator m(dom_map);
IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i;
if (!ei.has_upper_bound()) return IntSet::everything();
ext_set = IntervalSet::make(make_zero(ei.max.type()), ComputeExpr<Sub>(ei.max, 1));
return Combine<Add>(min_set, ext_set);
IntSet min_set = m.Eval(r->min).cover_interval();
// Simplifying first can give tighter bounds if r->min and r->extent share variables
Expr sum = ComputeExpr<Sub>(ComputeExpr<Add>(r->min, r->extent), 1);
IntSet max_set = m.Eval(Simplify(sum)).cover_interval();
const Interval& ni = min_set.as<IntervalSet>()->i;
const Interval& xi = max_set.as<IntervalSet>()->i;
if (!ni.has_lower_bound()) return IntSet::everything();
if (!xi.has_upper_bound()) return IntSet::everything();
return IntervalSet::make(ni.min, xi.max);
}
IntSet EvalSet(IntSet s,
......
......@@ -260,6 +260,36 @@ def test_gemm_bound():
assert(bounds[CC.op.axis[1]].extent.value == 8)
def test_bound_tensor_compute_op():
def intrin_test():
m1 = tvm.var("m1")
n1 = tvm.var("n1")
a = tvm.placeholder((m1, n1), name='a')
c = tvm.compute((1, n1), lambda i, j : a[0, j] + a[1, j] + a[2, j], name='c')
Ab = tvm.decl_buffer(a.shape, name="Abuf", offset_factor=1)
Cb = tvm.decl_buffer(c.shape, name="Cbuf", offset_factor=1)
def intrin_func(ins, outs):
aa = ins[0]
cc = outs[0]
def _body():
ib = tvm.ir_builder.create()
ib.emit(tvm.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r")))
return ib.get()
return _body()
with tvm.build_config(offset_factor=1):
return tvm.decl_tensor_intrin(c.op, intrin_func, binds={a : Ab, c : Cb})
test_func = intrin_test()
A = tvm.placeholder((20,20), name='A')
B = tvm.compute(A.shape, lambda i,j : A[i,j], name='B')
C = tvm.compute((10, 20), lambda i : test_func(B[i:10, 0:20]), name='C')
s = tvm.create_schedule(C.op)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert(bounds[B.op.axis[0]].extent.value == 10)
if __name__ == "__main__":
test_bound_nest_thread()
test_bound1()
......@@ -273,3 +303,4 @@ if __name__ == "__main__":
test_bound2()
test_gemm_bound()
test_bound_warp()
test_bound_tensor_compute_op()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment