Commit 9d5cba20 by Zhi, committed by Haichen Shen

[tvm][any] broadcast with values other than one (#3967)

* [tvm][any] broadcast with values other than 1

* Add test for incompatible runtime values

* Remove hybrid script compact buffer binding

* retrigger ci
parent 15ae9780
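
For context: before this change, broadcast lowering effectively treated a dimension typed as relay.Any() as 1; afterwards it may bind to any runtime value. A minimal sketch of the kind of program this enables (mirroring the new tests in test_any.py below; build and execution steps omitted):

from tvm import relay

# x's leading dimension is symbolic; y is fully static.
x = relay.var("x", shape=(relay.Any(), 2), dtype="float32")
y = relay.var("y", shape=(3, 2), dtype="float32")
fn = relay.Function([x, y], relay.add(x, y))
# At runtime x may now arrive with shape (3, 2), not just (1, 2),
# and still broadcast correctly against y.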
@@ -206,6 +206,14 @@ Stmt StorageFlatten(Stmt stmt,
                    Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size,
                    bool create_bound_attribute = false);

+/*!
+ * \brief Verify if there is any argument bound to a compact buffer.
+ *
+ * \param stmt The stmt to be verified.
+ * \return true if any buffer_bind_scope attribute is found,
+ *         false otherwise.
+ */
+bool VerifyCompactBuffer(Stmt stmt);

/*!
 * \brief Remove No Op from the Stmt.
......
@@ -264,7 +264,7 @@ def build_config(**kwargs):
    return config

-def get_binds(args, binds=None):
+def get_binds(args, compact=False, binds=None):
    """Internal function to get binds and arg_list given arguments.

    Parameters
@@ -272,6 +272,9 @@ def get_binds(args, binds=None):
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

+    compact : bool
+        Whether the statement has already been bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to the Buffer that specifies the data layout
        requirement of the function. By default, a new compact buffer is created
@@ -290,12 +293,15 @@ def get_binds(args, binds=None):
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
+            any_dim = any(isinstance(i, expr.Var) for i in x.shape)
+            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                buf = api.decl_buffer(x.shape,
                                      dtype=x.dtype,
                                      name=x.name,
                                      data_alignment=cfg.data_alignment,
-                                     offset_factor=cfg.offset_factor)
+                                     offset_factor=cfg.offset_factor,
+                                     buffer_type=buffer_type)
                binds[x] = buf
                arg_list.append(buf)
            else:
@@ -361,7 +367,6 @@ def lower(sch,
       The result function; if with_api_wrapper=False,
       the Stmt before MakeAPI is returned.
    """
-    binds, arg_list = get_binds(args, binds)
    cfg = current_build_config()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
@@ -377,11 +382,16 @@
    for f in lower_phase0:
        stmt = f(stmt)

+    compact = ir_pass.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
@@ -400,6 +410,7 @@
                               cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.LowerStorageAccessInfo(stmt)
@@ -413,6 +424,7 @@
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
......
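
In effect, once lower() has run the Phase-0 passes it asks VerifyCompactBuffer whether anything in the statement already pins arguments to compact buffers; only if not does get_binds mark symbolically-shaped arguments as "auto_broadcast", whose runtime stride for a dimension of extent 1 is set to 0 so the same code can read a broadcast operand. A hedged sketch of the equivalent explicit declaration (names are illustrative):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n, 2), name="A")
# What get_binds now does implicitly for a symbolic shape when no
# buffer_bind_scope was found in the statement (compact=False):
Ab = tvm.decl_buffer(A.shape, A.dtype, name="A",
                     buffer_type="auto_broadcast")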
@@ -159,5 +159,6 @@ REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
+REGISTER_PASS(VerifyCompactBuffer);
} // namespace ir
} // namespace tvm
@@ -331,8 +331,19 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
                                 Type dtype,
                                 std::string name,
                                 int data_alignment,
-                                int offset_factor) {
+                                int offset_factor,
+                                bool compact) {
  auto data = Var(name, Handle());
+  bool has_any = false;
+  if (!compact) {
+    for (const auto& it : shape) {
+      if (it.as<Variable>()) {
+        has_any = true;
+        break;
+      }
+    }
+  }
+  BufferType buffer_type = has_any ? kAutoBroadcast : kDefault;

  Expr elem_offset;
  if (offset_factor != 0) {
@@ -342,10 +353,11 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
  }
  return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-                         data_alignment, offset_factor, kDefault);
+                         data_alignment, offset_factor, buffer_type);
}
void GetBinds(const Array<Tensor>& args,
+             bool compact,
              const std::unordered_map<Tensor, Buffer>& binds,
              Map<Tensor, Buffer>* out_binds,
              Array<NodeRef>* out_arg_list,
@@ -355,7 +367,7 @@ void GetBinds(const Array<Tensor>& args,
  for (const auto &x : args) {
    if (out_binds->find(x) == out_binds->end()) {
      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name,
-                                          config->data_alignment, config->offset_factor);
+                                          config->data_alignment, config->offset_factor, compact);
      out_binds->Set(x, buf);
      out_arg_list->push_back(buf);
    } else {
@@ -380,9 +392,6 @@ Stmt BuildStmt(Schedule sch,
               bool loop_partition,
               Array<NodeRef> *out_arg_list,
               const BuildConfig& config) {
-  Map<Tensor, Buffer> out_binds;
-  GetBinds(args, binds, &out_binds, out_arg_list, config);

  sch = sch.normalize();

  // Phase 0
@@ -390,6 +399,10 @@ Stmt BuildStmt(Schedule sch,
  auto stmt = schedule::ScheduleOps(sch, bounds, false);
  stmt = ir::InjectPrefetch(stmt);

+  bool compact = ir::VerifyCompactBuffer(stmt);
+  Map<Tensor, Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, out_arg_list, config);

  // Phase 1
  stmt = ir::StorageFlatten(stmt, out_binds, 64,
                            config->instrument_bound_checkers);
......
@@ -180,31 +180,6 @@ Stmt HybridOpNode::BuildProvide(
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
-  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
-    Array<NodeRef> bind_spec;
-    Array<Expr> tuple;
-    bind_spec.push_back(buffer);
-    bind_spec.push_back(tensor);
-    for (size_t k = 0; k < buffer->shape.size(); ++k) {
-      tuple.push_back(make_const(buffer->shape[k].type(), 0));
-      tuple.push_back(buffer->shape[k]);
-    }
-    ret = AttrStmt::make(
-        bind_spec, attr::buffer_bind_scope,
-        Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
-  };
-  for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
-    Buffer buffer = decl_buffer(
-        outputs[i]->shape,
-        outputs[i]->dtype);
-    f_push_bind(buffer, stage->op.output(i));
-  }
-  auto curr_inputs = InputTensors();
-  for (int i = static_cast<int>(curr_inputs.size()) - 1; i >= 0; --i) {
-    Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype);
-    f_push_bind(buffer, curr_inputs[i]);
-  }
  std::unordered_map<Tensor, Tensor> rmap;
  for (int i = 0; i < this->num_outputs(); ++i) {
    rmap[outputs[i]] = stage->op.output(i);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file verify_compact_buffer.cc
 * \brief Verify whether any compact buffer is bound to a statement.
*/
#include <tvm/buffer.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/tensor.h>
#include <unordered_map>
namespace tvm {
namespace ir {
class VerifyBuffer : public IRVisitor {
 public:
  bool Verify(const Stmt& stmt) {
    this->Visit(stmt);
    return is_compact_;
  }

  void Visit_(const AttrStmt* op) final {
    IRVisitor::Visit_(op);
    if (op->attr_key == attr::buffer_bind_scope) {
      is_compact_ = true;
    }
  }

 private:
  bool is_compact_{false};
};

bool VerifyCompactBuffer(Stmt stmt) {
  VerifyBuffer verifier;
  return verifier.Verify(stmt);
}
} // namespace ir
} // namespace tvm
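
Since the pass is registered in api_pass.cc above, it is callable from Python as tvm.ir_pass.VerifyCompactBuffer. A minimal sketch of exercising it at the same stage lower()/BuildStmt do, i.e. on the statement before StorageFlatten consumes the buffer_bind_scope attributes (the packed-function name "my_copy" is hypothetical; an extern op is used because ExternOpNode binds its arguments to explicit buffers):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
# Extern ops wrap their body in buffer_bind_scope attributes.
C = tvm.extern((n,), [A],
               lambda ins, outs: tvm.call_packed("my_copy", ins[0], outs[0]),
               name="C")
s = tvm.create_schedule(C.op).normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
assert tvm.ir_pass.VerifyCompactBuffer(stmt)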
@@ -37,7 +37,7 @@ def test_reduce_prims():
    s[R].compute_inline()

    # one line to build the function.
-    def check_device(device, host="stackvm"):
+    def check_device(device, host="llvm"):
        ctx = tvm.context(device, 0)
        if not tvm.module.enabled(host):
            return
......
@@ -47,17 +47,35 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
    tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
def test_any_broadcast():
+    # Test broadcast with 1s
    verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
    verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add)
    verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add)
    verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add)
    verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add)
-    # The following currently fail because topi compute treats Any as 1
-    # will requires auto_broadcast buffer to solve the problem
-    # TODO(@zhiics): Fix this
-    # verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
-    # verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+    # Test broadcast with values other than 1
+    verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
+    verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+def test_any_broadcast_fail():
+    # Test broadcast with incompatible values at runtime
+    def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
+        try:
+            verify_any_broadcast(
+                x_shape, y_shape, x_np_shape, y_np_shape, op, np_op)
+        except tvm._ffi.base.TVMError:
+            pass
+        else:
+            assert False
+
+    check_fail((relay.Any(),), (3, 2), (1,), (4, 2), relay.add, np.add)
+    check_fail((relay.Any(), 2), (3, 2), (4, 2), (4, 2), relay.add, np.add)
+    check_fail((relay.Any(), 2), (3, relay.Any()), (1, 2), (4, 1), relay.add, np.add)
+    check_fail((relay.Any(), 2), (3, 3), (1, 3), (3, 3), relay.add, np.add)
+    check_fail((relay.Any(),), (3, 2), (2,), (4, 2), relay.add, np.add)
def test_any_concat():
    x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
@@ -285,6 +303,7 @@ def test_recursive_concat_with_wrong_annotation():
if __name__ == "__main__":
    test_any_broadcast()
+    test_any_broadcast_fail()
    test_any_concat()
    test_any_reshape()
    test_any_take()
......
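
As an aside, the try/except/else pattern in check_fail above is the suite's pre-pytest idiom; an equivalent formulation (purely illustrative, reusing the file's verify_any_broadcast helper, and not what the test file actually uses) would be:

import pytest
import tvm

def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
    # Expect the runtime shape check to reject the mismatch.
    with pytest.raises(tvm._ffi.base.TVMError):
        verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op)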
@@ -42,13 +42,14 @@ def test_popcount():
    check_correct_assembly('uint32', 2, 2)
    check_correct_assembly('uint64', 2, 3)
def test_vmlal_s16():
    target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

    def check_correct_assembly(N):
        K = tvm.var("K")
        A = tvm.placeholder((K, N), dtype="int8", name='A')
-        B = tvm.placeholder((K, N), dtype="int8", name='A')
+        B = tvm.placeholder((K, N), dtype="int8", name='B')
        k = tvm.reduce_axis((0, K))
        C = tvm.compute((N, ), lambda n: tvm.sum(
            A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')
@@ -60,14 +61,15 @@ def test_vmlal_s16():
        assembly = f.get_source('asm')
        matches = re.findall("vmlal.s16", assembly)
        assert (len(matches) == N // 4)

    check_correct_assembly(4)
    check_correct_assembly(8)
    check_correct_assembly(16)
    check_correct_assembly(32)
    check_correct_assembly(64)

    def check_broadcast_correct_assembly(N):
        K = tvm.var("K")
        A = tvm.placeholder((K, N), dtype="int8", name='A')
-        B = tvm.placeholder((K,), dtype="int8", name='A')
+        B = tvm.placeholder((K,), dtype="int8", name='B')
        k = tvm.reduce_axis((0, K))
        C = tvm.compute((N, ), lambda n: tvm.sum(
            A[k, n].astype("int32") * B[k].astype("int32"),
@@ -85,6 +87,7 @@ def test_vmlal_s16():
    check_broadcast_correct_assembly(32)
    check_broadcast_correct_assembly(64)

if __name__ == "__main__":
    test_popcount()
    test_vmlal_s16()
@@ -188,8 +188,21 @@ def test_buffer_broadcast_expr():
        fadd(a, b, c, 4, 1)
        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

+    def check_auto_bind():
+        if not tvm.module.enabled("llvm"):
+            return
+        # Let build bind buffers
+        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add')
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
+        fadd(a, b, c, 4, 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

    check_stride()
    check_no_stride()
+    check_auto_bind()

if __name__ == "__main__":
......
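
check_auto_bind above depends on definitions elided from this hunk (A, B, C, o1, x). As a self-contained illustration of the same mechanism, with illustrative names rather than the test's actual ones: every shape below is symbolic and no explicit binds are passed, so under this commit's behavior build() should declare auto_broadcast buffers itself and let a's extent-1 leading dimension broadcast at runtime. A hedged, untested sketch:

import numpy as np
import tvm

m0, n0 = tvm.var("m0"), tvm.var("n0")
m1, n1 = tvm.var("m1"), tvm.var("n1")
A = tvm.placeholder((m0, n0), name="A")
B = tvm.placeholder((m1, n1), name="B")
C = tvm.compute((m1, n1), lambda i, j: A[i, j] + B[i, j], name="C")
s = tvm.create_schedule(C.op)

fadd = tvm.build(s, [A, B, C], target="llvm", name="bcast_add")
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((2, 4), dtype="float32"), ctx)
fadd(a, b, c)  # a's first dimension (extent 1) is read with stride 0
tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())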
@@ -74,9 +74,11 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
    } else if (var1) {
      bh.common_shape.push_front(shape2[s2_size - i]);
      bh.vars2.push_front(bh.all_vars[0]);
+      bh.vars1.push_front(bh.all_vars[0]);
    } else if (var2) {
      bh.common_shape.push_front(shape1[s1_size - i]);
      bh.vars1.push_front(bh.all_vars[0]);
+      bh.vars2.push_front(bh.all_vars[0]);
    } else {
      CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
                   << " and " << shape2[s2_size - i] << " in: "
@@ -98,11 +100,13 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
}

inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
-    const tvm::Array<tvm::Var>& ovars, const tvm::Tensor& T,
-    const std::deque<tvm::Var>& my_vars, const std::deque<tvm::Var>& all_vars) {
+    const tvm::Array<tvm::Var>& ovars,
+    const tvm::Tensor& T,
+    const std::deque<tvm::Var>& my_vars,
+    const std::deque<tvm::Var>& all_vars) {
  tvm::Array<tvm::Expr> ivars;
  CHECK_EQ(ovars.size(), all_vars.size());
-  // N^2, could use a map but NBD..
+  // N^2, could use a map but NBD.
  size_t expected_dims = T->shape.size();
  for (size_t i = 0; i < ovars.size(); ++i) {
    bool found = false;
@@ -123,13 +127,12 @@ inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
  return ivars;
}
template <typename FBinaryExpr>
inline tvm::Tensor WithBroadcast(FBinaryExpr op,
                                 const tvm::Tensor& A,
                                 const tvm::Tensor& B,
-                                 std::string name = "tensor",
-                                 std::string tag = "") {
+                                 const std::string& name = "tensor",
+                                 const std::string& tag = "") {
  auto bh = BroadcastShape(A->shape, B->shape);
  auto l = [&](tvm::Array<tvm::Var> ovars) {
    return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
......
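
The two added push_front calls in BroadcastShape mean that when one side's dimension is a Var, both inputs are now indexed by the shared output variable, instead of the Var side being treated as extent 1 (the behavior the removed test comment above complained about); the 1-vs-N decision is deferred to the auto_broadcast buffer at runtime. A hedged Python-level sketch of what this enables:

import tvm
import topi

n = tvm.var("n")
A = tvm.placeholder((n, 2), name="A")
B = tvm.placeholder((3, 2), name="B")
# (n, 2) and (3, 2) unify to (3, 2); A is indexed by the common
# output variable, so n may resolve to 1 or 3 at runtime.
C = topi.add(A, B)
print(C.shape)  # [3, 2]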