Commit 9d5cba20 by Zhi, committed by Haichen Shen

[tvm][any] broadcast with values other than one (#3967)

* [tvm][any] broadcast with values other than 1

* Add test for incompatible runtime values

* Remove hybrid script compact buffer binding

* retrigger ci
parent 15ae9780
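At its core, the change declares buffers with the `auto_broadcast` buffer type whenever an argument shape contains a symbolic (`Any`) dimension and the statement is not already bound to a compact buffer, so a dynamic dimension can broadcast against runtime extents other than 1. A minimal sketch of declaring such a buffer directly (variable names here are illustrative, not from the patch):

```python
import tvm

# A tensor whose first dimension is symbolic, akin to relay.Any().
n = tvm.var("n")
x = tvm.placeholder((n, 4), name="x", dtype="float32")

# An auto_broadcast buffer keeps strides symbolic, so the generated code
# does not assume the dynamic dimension equals 1.
xb = tvm.decl_buffer(x.shape, x.dtype, name="xb", buffer_type="auto_broadcast")
```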
@@ -206,6 +206,14 @@ Stmt StorageFlatten(Stmt stmt,
                    Map<Tensor, Buffer> extern_buffer,
                    int cache_line_size,
                    bool create_bound_attribute = false);
+/*!
+ * \brief Verify if there is any argument bound to compact buffer.
+ *
+ * \param stmt The stmt to be verified.
+ * \return true if there is any buffer_bind_scope attribute found,
+ *         otherwise, false.
+ */
+bool VerifyCompactBuffer(Stmt stmt);
 /*!
  * \brief Remove No Op from the Stmt.
...
@@ -264,7 +264,7 @@ def build_config(**kwargs):
     return config

-def get_binds(args, binds=None):
+def get_binds(args, compact=False, binds=None):
     """Internal function to get binds and arg_list given arguments.

     Parameters
@@ -272,6 +272,9 @@ def get_binds(args, binds=None):
     args : list of Buffer or Tensor or Var
         The argument lists to the function.

+    compact : bool
+        If the statement has already bound to a compact buffer.
+
     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         Dictionary that maps the Tensor to Buffer which specified the data layout
         requirement of the function. By default, a new compact buffer is created
@@ -290,12 +293,15 @@ def get_binds(args, binds=None):
     arg_list = []
     for x in args:
         if isinstance(x, tensor.Tensor):
+            any_dim = any(isinstance(i, expr.Var) for i in x.shape)
+            buffer_type = "auto_broadcast" if any_dim and not compact else ""
             if x not in binds:
                 buf = api.decl_buffer(x.shape,
                                       dtype=x.dtype,
                                       name=x.name,
                                       data_alignment=cfg.data_alignment,
-                                      offset_factor=cfg.offset_factor)
+                                      offset_factor=cfg.offset_factor,
+                                      buffer_type=buffer_type)
                 binds[x] = buf
                 arg_list.append(buf)
             else:
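The `any_dim` test above can be reproduced standalone; a shape element is dynamic exactly when it is a `tvm.expr.Var` rather than a constant. A small sketch (the helper name is hypothetical):

```python
import tvm
from tvm import expr

n = tvm.var("n")
A = tvm.placeholder((n, 4), name="A")  # symbolic first dimension
B = tvm.placeholder((3, 4), name="B")  # fully static shape

def has_any_dim(t):
    # Mirrors the check in get_binds: a Var in the shape marks a dynamic dim.
    return any(isinstance(i, expr.Var) for i in t.shape)

assert has_any_dim(A) and not has_any_dim(B)
```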
@@ -361,7 +367,6 @@ def lower(sch,
         The result function, if with_api_wrapper=False
         Then the Stmt before make api is returned.
     """
-    binds, arg_list = get_binds(args, binds)
     cfg = current_build_config()
     add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
     if cfg.dump_pass_ir:
@@ -377,11 +382,16 @@ def lower(sch,
     for f in lower_phase0:
         stmt = f(stmt)
+
+    compact = ir_pass.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, binds)
+
     # Phase 1
     stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
     stmt = ir_pass.CanonicalSimplify(stmt)
     for f in lower_phase1:
         stmt = f(stmt)
+
     # Phase 2
     if not simple_mode:
         stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
@@ -400,6 +410,7 @@ def lower(sch,
                                   cfg.unroll_explicit)
     for f in lower_phase2:
         stmt = f(stmt)
+
     # Phase 3
     stmt = ir_pass.Simplify(stmt)
     stmt = ir_pass.LowerStorageAccessInfo(stmt)
@@ -413,6 +424,7 @@ def lower(sch,
         stmt = ir_pass.InstrumentBoundCheckers(stmt)
     if simple_mode:
         return stmt
+
     return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
...
@@ -159,5 +159,6 @@ REGISTER_PASS(VerifyMemory);
 REGISTER_PASS(VerifyGPUCode);
 REGISTER_PASS(DecorateDeviceScope);
 REGISTER_PASS(InstrumentBoundCheckers);
+REGISTER_PASS(VerifyCompactBuffer);
 }  // namespace ir
 }  // namespace tvm
@@ -331,8 +331,19 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
                                  Type dtype,
                                  std::string name,
                                  int data_alignment,
-                                 int offset_factor) {
+                                 int offset_factor,
+                                 bool compact) {
   auto data = Var(name, Handle());
+  bool has_any = false;
+  if (!compact) {
+    for (const auto& it : shape) {
+      if (it.as<Variable>()) {
+        has_any = true;
+        break;
+      }
+    }
+  }
+  BufferType buffer_type = has_any ? kAutoBroadcast : kDefault;

   Expr elem_offset;
   if (offset_factor != 0) {
@@ -342,10 +353,11 @@ Buffer BufferWithOffsetAlignment(Array<Expr> shape,
   }

   return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-                          data_alignment, offset_factor, kDefault);
+                          data_alignment, offset_factor, buffer_type);
 }

 void GetBinds(const Array<Tensor>& args,
+              bool compact,
               const std::unordered_map<Tensor, Buffer>& binds,
               Map<Tensor, Buffer>* out_binds,
               Array<NodeRef>* out_arg_list,
@@ -355,7 +367,7 @@ void GetBinds(const Array<Tensor>& args,
   for (const auto &x : args) {
     if (out_binds->find(x) == out_binds->end()) {
       auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name,
-                                           config->data_alignment, config->offset_factor);
+                                           config->data_alignment, config->offset_factor, compact);
       out_binds->Set(x, buf);
       out_arg_list->push_back(buf);
     } else {
@@ -380,9 +392,6 @@ Stmt BuildStmt(Schedule sch,
               bool loop_partition,
               Array<NodeRef> *out_arg_list,
               const BuildConfig& config) {
-  Map<Tensor, Buffer> out_binds;
-  GetBinds(args, binds, &out_binds, out_arg_list, config);
-
   sch = sch.normalize();

   // Phase 0
@@ -390,6 +399,10 @@ Stmt BuildStmt(Schedule sch,
   auto stmt = schedule::ScheduleOps(sch, bounds, false);
   stmt = ir::InjectPrefetch(stmt);

+  bool compact = ir::VerifyCompactBuffer(stmt);
+  Map<Tensor, Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
+
   // Phase 1
   stmt = ir::StorageFlatten(stmt, out_binds, 64,
                             config->instrument_bound_checkers);
...
@@ -180,31 +180,6 @@ Stmt HybridOpNode::BuildProvide(
     bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
   Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
-  auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
-    Array<NodeRef> bind_spec;
-    Array<Expr> tuple;
-    bind_spec.push_back(buffer);
-    bind_spec.push_back(tensor);
-    for (size_t k = 0; k < buffer->shape.size(); ++k) {
-      tuple.push_back(make_const(buffer->shape[k].type(), 0));
-      tuple.push_back(buffer->shape[k]);
-    }
-    ret = AttrStmt::make(
-        bind_spec, attr::buffer_bind_scope,
-        Call::make(Handle(), intrinsic::tvm_tuple, tuple, Call::Intrinsic), ret);
-  };
-  for (int i = static_cast<int>(outputs.size()) - 1; i >= 0; --i) {
-    Buffer buffer = decl_buffer(
-        outputs[i]->shape,
-        outputs[i]->dtype);
-    f_push_bind(buffer, stage->op.output(i));
-  }
-  auto curr_inputs = InputTensors();
-  for (int i = static_cast<int>(curr_inputs.size()) - 1; i >= 0; --i) {
-    Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype);
-    f_push_bind(buffer, curr_inputs[i]);
-  }
   std::unordered_map<Tensor, Tensor> rmap;
   for (int i = 0; i < this->num_outputs(); ++i) {
     rmap[outputs[i]] = stage->op.output(i);
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file verify_compact_buffer.cc
* \brief Verify if there was any compact buffer bound to a statement.
*/
#include <tvm/buffer.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/tensor.h>

#include <unordered_map>

namespace tvm {
namespace ir {

class VerifyBuffer : public IRVisitor {
 public:
  bool Verify(const Stmt& stmt) {
    this->Visit(stmt);
    return is_compact_;
  }

  void Visit_(const AttrStmt* op) final {
    IRVisitor::Visit_(op);
    if (op->attr_key == attr::buffer_bind_scope) {
      is_compact_ = true;
    }
  }

 private:
  bool is_compact_{false};
};

bool VerifyCompactBuffer(Stmt stmt) {
  VerifyBuffer verifier;
  return verifier.Verify(stmt);
}

}  // namespace ir
}  // namespace tvm
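Because the pass is registered via `REGISTER_PASS` above, it is callable from Python as `tvm.ir_pass.VerifyCompactBuffer`, which is how `lower` uses it. A small sketch of invoking it on a lowered statement (this toy schedule contains no `buffer_bind_scope`, so the result should be false):

```python
import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)

# simple_mode=True returns the Stmt rather than a built module.
stmt = tvm.lower(s, [A, B], simple_mode=True)

# No buffer_bind_scope attribute occurs here, so this is not compact-bound.
assert not tvm.ir_pass.VerifyCompactBuffer(stmt)
```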
@@ -37,7 +37,7 @@ def test_reduce_prims():
     s[R].compute_inline()

     # one line to build the function.
-    def check_device(device, host="stackvm"):
+    def check_device(device, host="llvm"):
         ctx = tvm.context(device, 0)
         if not tvm.module.enabled(host):
             return
...
@@ -47,17 +47,35 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
     tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))

 def test_any_broadcast():
+    # Test broadcast with 1s
     verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add)
     verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add)
-    # The following currently fail because topi compute treats Any as 1
-    # will requires auto_broadcast buffer to solve the problem
-    # TODO(@zhiics): Fix this
-    # verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
-    # verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+
+    # Test broadcast with values other than 1
+    verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
+    verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+
+def test_any_broadcast_fail():
+    # Test broadcast with incompatible values at runtime
+    def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
+        try:
+            verify_any_broadcast(
+                x_shape, y_shape, x_np_shape, y_np_shape, op, np_op)
+        except tvm._ffi.base.TVMError:
+            pass
+        else:
+            assert False
+
+    check_fail((relay.Any(),), (3, 2), (1,), (4, 2), relay.add, np.add)
+    check_fail((relay.Any(), 2), (3, 2), (4, 2), (4, 2), relay.add, np.add)
+    check_fail((relay.Any(), 2), (3, relay.Any()), (1, 2), (4, 1), relay.add, np.add)
+    check_fail((relay.Any(), 2), (3, 3), (1, 3), (3, 3), relay.add, np.add)
+    check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add)

 def test_any_concat():
     x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
@@ -285,6 +303,7 @@ def test_recursive_concat_with_wrong_annotation():
 if __name__ == "__main__":
     test_any_broadcast()
+    test_any_broadcast_fail()
     test_any_concat()
     test_any_reshape()
     test_any_take()
...
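The body of `verify_any_broadcast` is elided from this hunk; for orientation, a hypothetical minimal version consistent with its visible tail and signature might look like the following (the VM executor is assumed, since the graph runtime cannot handle `Any` shapes):

```python
import numpy as np
import tvm
from tvm import relay

def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
    # Hypothetical reconstruction; the real helper's body is not shown above.
    dtype = "float32"
    x = relay.var("x", shape=x_shape, dtype=dtype)
    y = relay.var("y", shape=y_shape, dtype=dtype)
    mod = relay.Module()
    mod["main"] = relay.Function([x, y], op(x, y))
    x_np = np.random.uniform(size=x_np_shape).astype(dtype)
    y_np = np.random.uniform(size=y_np_shape).astype(dtype)
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    result = ex.evaluate()(x_np, y_np)
    tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))
```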
@@ -42,13 +42,14 @@ def test_popcount():
     check_correct_assembly('uint32', 2, 2)
     check_correct_assembly('uint64', 2, 3)

 def test_vmlal_s16():
     target = 'llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon'

     def check_correct_assembly(N):
         K = tvm.var("K")
         A = tvm.placeholder((K, N), dtype="int8", name='A')
-        B = tvm.placeholder((K, N), dtype="int8", name='A')
+        B = tvm.placeholder((K, N), dtype="int8", name='B')
         k = tvm.reduce_axis((0, K))
         C = tvm.compute((N, ), lambda n: tvm.sum(
             A[k, n].astype("int32") * B[k, n].astype("int32"), axis=[k]), name='C')
@@ -60,14 +61,15 @@ def test_vmlal_s16():
         assembly = f.get_source('asm')
         matches = re.findall("vmlal.s16", assembly)
         assert (len(matches) == N // 4)
-    check_correct_assembly(4)
     check_correct_assembly(8)
     check_correct_assembly(16)
+    check_correct_assembly(32)
+    check_correct_assembly(64)

     def check_broadcast_correct_assembly(N):
         K = tvm.var("K")
         A = tvm.placeholder((K, N), dtype="int8", name='A')
-        B = tvm.placeholder((K,), dtype="int8", name='A')
+        B = tvm.placeholder((K,), dtype="int8", name='B')
         k = tvm.reduce_axis((0, K))
         C = tvm.compute((N, ), lambda n: tvm.sum(
             A[k, n].astype("int32") * B[k].astype("int32"),
@@ -85,6 +87,7 @@ def test_vmlal_s16():
     check_broadcast_correct_assembly(32)
     check_broadcast_correct_assembly(64)

 if __name__ == "__main__":
     test_popcount()
     test_vmlal_s16()
...
@@ -188,8 +188,21 @@ def test_buffer_broadcast_expr():
         fadd(a, b, c, 4, 1)
         tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

+    def check_auto_bind():
+        if not tvm.module.enabled("llvm"):
+            return
+        # Let build bind buffers
+        fadd = tvm.build(s, [A, B, C, o1, x], target='llvm', name='bcast_add')
+        ctx = tvm.cpu(0)
+        a = tvm.nd.array(np.random.uniform(size=(1, 4)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(2, 4)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((2, 4), dtype=C.dtype), ctx)
+        fadd(a, b, c, 4, 1)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
+
     check_stride()
     check_no_stride()
+    check_auto_bind()

 if __name__ == "__main__":
...
@@ -74,9 +74,11 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
     } else if (var1) {
       bh.common_shape.push_front(shape2[s2_size - i]);
       bh.vars2.push_front(bh.all_vars[0]);
+      bh.vars1.push_front(bh.all_vars[0]);
     } else if (var2) {
       bh.common_shape.push_front(shape1[s1_size - i]);
       bh.vars1.push_front(bh.all_vars[0]);
+      bh.vars2.push_front(bh.all_vars[0]);
     } else {
       CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
                    << " and " << shape2[s2_size - i] << " in: "
@@ -98,16 +100,18 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
 }

 inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
-    const tvm::Array<tvm::Var>& ovars, const tvm::Tensor& T,
-    const std::deque<tvm::Var>& my_vars, const std::deque<tvm::Var>& all_vars) {
+    const tvm::Array<tvm::Var>& ovars,
+    const tvm::Tensor& T,
+    const std::deque<tvm::Var>& my_vars,
+    const std::deque<tvm::Var>& all_vars) {
   tvm::Array<tvm::Expr> ivars;
   CHECK_EQ(ovars.size(), all_vars.size());
-  // N^2, could use a map but NBD..
+  // N^2, could use a map but NBD.
   size_t expected_dims = T->shape.size();
   for (size_t i = 0; i < ovars.size(); ++i) {
     bool found = false;
     for (size_t j = 0; j < my_vars.size(); ++j) {
       if (all_vars[i].same_as(my_vars[j])) {
         ivars.push_back(ovars[i]);
         found = true;
         break;
@@ -123,13 +127,12 @@ inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
   return ivars;
 }

 template <typename FBinaryExpr>
 inline tvm::Tensor WithBroadcast(FBinaryExpr op,
                                  const tvm::Tensor& A,
                                  const tvm::Tensor& B,
-                                 std::string name = "tensor",
-                                 std::string tag = "") {
+                                 const std::string& name = "tensor",
+                                 const std::string& tag = "") {
   auto bh = BroadcastShape(A->shape, B->shape);
   auto l = [&](tvm::Array<tvm::Var> ovars) {
     return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)),
...
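The `BroadcastShape` change above makes a variable dimension index both inputs with the same output variable instead of pinning the variable side to extent 1. At the topi level, this is what lets a compute like the following work when `n` matches the other operand's extent at runtime (a sketch; shapes here are illustrative, and incompatible extents fail at run time as the tests above check):

```python
import tvm
import topi

n = tvm.var("n")
A = tvm.placeholder((n, 4), name="A")
B = tvm.placeholder((3, 4), name="B")

# Both operands now share the output index along the variable dimension,
# so n may be 3 at runtime rather than being treated as 1.
C = topi.add(A, B)
```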