Commit 8577c81b by Zhi Committed by Haichen Shen

[tvm][codegen] Make buffer auto broadcast independent to the order of input args (#3956)

* [tvm][codegen] Make buffer auto broadcast independent to the order of the input arg

* fix indent
parent ab1853c2
...@@ -582,7 +582,7 @@ def decl_buffer(shape, ...@@ -582,7 +582,7 @@ def decl_buffer(shape,
buffer_type: str, optional, {"", "auto_broadcast"} buffer_type: str, optional, {"", "auto_broadcast"}
auto_broadcast buffer allows one to implement broadcast computation auto_broadcast buffer allows one to implement broadcast computation
without considering whether dimension size equals to one. without considering whether dimension size equals to one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
Returns Returns
------- -------
...@@ -601,8 +601,8 @@ def decl_buffer(shape, ...@@ -601,8 +601,8 @@ def decl_buffer(shape,
A = tvm.placeholder((m0, m1, m2), name='A') A = tvm.placeholder((m0, m1, m2), name='A')
B = tvm.placeholder((n0, n1, n2), name='B') B = tvm.placeholder((n0, n1, n2), name='B')
C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C') C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="broadcast") Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="broadcast") Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb}) fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
......
...@@ -102,6 +102,10 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -102,6 +102,10 @@ LoweredFunc MakeAPI(Stmt body,
seq_init.emplace_back( seq_init.emplace_back(
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
} }
// Save the input variables and buffers that will be bound later.
std::vector<std::pair<Var, Var> > var_defs;
std::vector<std::pair<Buffer, Var> > buf_defs;
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) { for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i); Var v_arg = f_arg_decl(i);
if (i < num_packed_args) { if (i < num_packed_args) {
...@@ -139,15 +143,29 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -139,15 +143,29 @@ LoweredFunc MakeAPI(Stmt body,
} }
// add checks for functions. // add checks for functions.
if (api_args[i].as<Variable>()) { if (api_args[i].as<Variable>()) {
binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true); var_defs.emplace_back(std::make_pair(Var(api_args[i].node_), v_arg));
} else { } else {
// Buffer checks // Buffer checks
CHECK(api_args[i].as<BufferNode>()) CHECK(api_args[i].as<BufferNode>())
<< "api_args can only be Buffer or Var"; << "api_args can only be Buffer or Var";
Buffer buf(api_args[i].node_); buf_defs.emplace_back(std::make_pair(Buffer(api_args[i].node_), v_arg));
binder.BindDLTensor( }
buf, device_type, device_id, v_arg, v_arg->name_hint);
} }
// Arg definitions are defined before buffer binding to avoid the use before
// def errors.
//
// For example, for auto broadcasting, checks are required to guarantee that
// either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let bining yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
for (const auto& arg : var_defs) {
binder.Bind(arg.first, arg.second, arg.second->name_hint, true);
}
for (const auto& buf_arg : buf_defs) {
binder.BindDLTensor(buf_arg.first, device_type, device_id,
buf_arg.second, buf_arg.second->name_hint);
} }
NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>(); NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
......
...@@ -29,6 +29,7 @@ def test_buffer(): ...@@ -29,6 +29,7 @@ def test_buffer():
assert Ab.dtype == tvm.float32 assert Ab.dtype == tvm.float32
assert tuple(Ab.shape) == (m, n) assert tuple(Ab.shape) == (m, n)
def test_buffer_access_ptr(): def test_buffer_access_ptr():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -40,6 +41,7 @@ def test_buffer_access_ptr(): ...@@ -40,6 +41,7 @@ def test_buffer_access_ptr():
aptr = Ab.access_ptr("w") aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE assert aptr.args[4].value == Buffer.WRITE
def test_buffer_access_ptr_offset(): def test_buffer_access_ptr_offset():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -58,6 +60,7 @@ def test_buffer_access_ptr_offset(): ...@@ -58,6 +60,7 @@ def test_buffer_access_ptr_offset():
assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v)) assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v))
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE assert aptr.args[4].value == Buffer.READ | Buffer.WRITE
def test_buffer_access_ptr_extent(): def test_buffer_access_ptr_extent():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -70,6 +73,7 @@ def test_buffer_access_ptr_extent(): ...@@ -70,6 +73,7 @@ def test_buffer_access_ptr_extent():
aptr = Ab.access_ptr("rw", offset=100) aptr = Ab.access_ptr("rw", offset=100)
assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100) assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100)
def test_buffer_vload(): def test_buffer_vload():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -78,6 +82,7 @@ def test_buffer_vload(): ...@@ -78,6 +82,7 @@ def test_buffer_vload():
offset = tvm.ir_pass.Simplify(load.index) offset = tvm.ir_pass.Simplify(load.index)
assert tvm.ir_pass.Equal(offset, n * 2 + 103) assert tvm.ir_pass.Equal(offset, n * 2 + 103)
def test_buffer_index_merge_mult_mod(): def test_buffer_index_merge_mult_mod():
m = tvm.var('m') m = tvm.var('m')
n = tvm.var('n') n = tvm.var('n')
...@@ -109,6 +114,7 @@ def test_buffer_index_merge_mult_mod(): ...@@ -109,6 +114,7 @@ def test_buffer_index_merge_mult_mod():
index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1)))) index_direct = A.vload((0, ((k0 % (k1 / s)) / n) * n + ((k0 % (k1 / n)) % n + (k0 % k1))))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
def test_buffer_broadcast(): def test_buffer_broadcast():
m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2") m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2") n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
...@@ -137,6 +143,48 @@ def test_buffer_broadcast(): ...@@ -137,6 +143,48 @@ def test_buffer_broadcast():
check() check()
def test_bbuffer_roadcast_expr():
n0, m0, x = tvm.var('n0'), tvm.var('m0'), tvm.var('x')
n1, m1 = tvm.var('n1'), tvm.var('m1')
o0, o1 = tvm.var('o0'), tvm.var('o1')
A = tvm.placeholder((m0, n0), name='A')
B = tvm.placeholder((m1, n1), name='B')
C = tvm.compute((o0, o1/x), lambda i, j: A[i, j] + B[i, j], name='C')
Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
Cc = tvm.decl_buffer(C.shape, C.dtype, name="Cc", buffer_type="auto_broadcast")
s = tvm.create_schedule(C.op)
def check_stride():
if not tvm.module.enabled("llvm"):
return
fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
binds={A:Ab, B:Bb, C:Cc})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
fadd(a, b, c, 4, 1)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
def check_no_stride():
if not tvm.module.enabled("llvm"):
return
fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add',
binds={A: Ab, B: Bb, C: Cc})
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
fadd(a, b, c, 4, 1)
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
check_stride()
check_no_stride()
if __name__ == "__main__": if __name__ == "__main__":
test_buffer() test_buffer()
test_buffer_access_ptr() test_buffer_access_ptr()
...@@ -145,3 +193,4 @@ if __name__ == "__main__": ...@@ -145,3 +193,4 @@ if __name__ == "__main__":
test_buffer_vload() test_buffer_vload()
test_buffer_index_merge_mult_mod() test_buffer_index_merge_mult_mod()
test_buffer_broadcast() test_buffer_broadcast()
test_buffer_broadcast_expr()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment