Commit b00aabc5 by masahi, committed by Tianqi Chen

Add missing check when deciding whether a conv op and an injective op are in the same group (#1622)

parent 1c66012a
@@ -146,6 +146,7 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
      bool parent_out_ewise = false;
      bool parent_injective = false;
      for (const auto& e : inode.inputs) {
        if (fuse_vec[e.node_id] != FuseRule::kFuseToMaster) continue;
        TOpPattern pt = pattern_vec[e.node_id];
        if (pt == kOutEWiseFusable) {
          parent_out_ewise = true;
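For context: the new line skips inputs that are not marked FuseRule::kFuseToMaster, so only parents that actually fuse into the master node can set the parent_out_ewise / parent_injective flags. Below is a minimal sketch of the graph pattern this affects, condensed from the regression test added further down; the "llvm" target and the concrete shapes are illustrative assumptions, not part of the diff.

import nnvm
import nnvm.compiler
import nnvm.symbol as sym

data = sym.Variable(name="data")
concat = sym.concatenate(data, data, axis=1)        # injective op
conv = sym.conv2d(data=concat, kernel_size=(1, 1),  # conv consumes the concat output
                  channels=6, use_bias=False, name="conv")
net = sym.elemwise_add(concat, conv)                # concat output is also used here

# Illustrative target/shape; the new test below checks the same node count.
graph, lib, _ = nnvm.compiler.build(net, "llvm", {"data": (1, 3, 8, 8)})
# With the fix, concat stays in its own group while elemwise_add fuses into conv:
# data, conv weight, conv op, concat -> 4 nodes in the final graph.
assert graph.index.num_nodes == 4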
@@ -110,6 +110,39 @@ def test_injective_conv2d():
        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


def test_concatenate_conv2d():
    ch = 3
    size = 8
    data = sym.Variable(name="data")
    concat = sym.concatenate(data, data, axis=1)
    conv = sym.conv2d(data=concat, kernel_size=(1,1), channels=ch*2, use_bias=False, name="conv")
    net = sym.elemwise_add(concat, conv)

    dtype = "float32"
    dshape = (1, ch, size, size)
    kshape = (ch*2, ch*2, 1, 1)
    oshape = (1, ch*2, size, size)
    shape_dict = {"data": dshape}

    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(net, target, shape_dict)
        # data, conv weight, conv op, concat
        assert graph.index.num_nodes == 4

        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
        kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
        m = graph_runtime.create(graph, lib, ctx)
        m.run(data=data, conv_weight=kernel)
        # get output
        out = m.get_output(0, tvm.nd.empty(oshape, dtype))

        concat = np.concatenate((data.asnumpy(), data.asnumpy()), axis=1)
        conv = topi.testing.conv2d_nchw_python(
            concat, kernel.asnumpy(), (1,1), 'SAME')
        ref = concat + conv
        np.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5)


def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
    with nnvm.compiler.build_config(opt_level=opt_level):
        graph, lib, params = nnvm.compiler.build(sym, target, shape={"data": data.shape}, params=params)
@@ -157,3 +190,4 @@ if __name__ == "__main__":
    test_conv_ewise_injective()
    test_fuse_conv2d_elu()
    test_injective_conv2d()
    test_concatenate_conv2d()