Unverified Commit 95de08ba by Zhi Committed by GitHub

Fix alpha_equal bug (#4897)

parent e7be8bf4
...@@ -92,7 +92,7 @@ class AlphaEqualHandler: ...@@ -92,7 +92,7 @@ class AlphaEqualHandler:
auto compute = [&]() { auto compute = [&]() {
if (&lhs == &rhs) return true; if (&lhs == &rhs) return true;
if (auto lhsd = lhs.as<DictAttrsNode>()) { if (auto lhsd = lhs.as<DictAttrsNode>()) {
auto rhsd = lhs.as<DictAttrsNode>(); auto rhsd = rhs.as<DictAttrsNode>();
if (!rhsd) return false; if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false; if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) { for (const auto& k : lhsd->dict) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" test ir""" """ test ir"""
import pytest
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.tir.expr import * from tvm.tir.expr import *
...@@ -174,6 +175,7 @@ def test_function(): ...@@ -174,6 +175,7 @@ def test_function():
str(fn) str(fn)
check_json_roundtrip(fn) check_json_roundtrip(fn)
@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs(): def test_function_attrs():
param_names = ['a', 'b', 'c', 'd'] param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names]) params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
......
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import analysis from tvm.relay import analysis
from tvm.relay.testing import run_opt_pass
def alpha_equal(x, y): def alpha_equal(x, y):
""" """
...@@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal(): ...@@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_multi_node_subgraph(): def test_function_attr():
x0 = relay.var('x0', shape=(10, 10)) x0 = relay.var('x0', shape=(10, 10))
w00 = relay.var('w00', shape=(10, 10)) w00 = relay.var('w00', shape=(10, 10))
w01 = relay.var('w01', shape=(10, 10)) w01 = relay.var('w01', shape=(10, 10))
...@@ -608,6 +609,7 @@ def test_graph_equal(): ...@@ -608,6 +609,7 @@ def test_graph_equal():
z3 = relay.add(relay.add(x, x), relay.add(x, x)) z3 = relay.add(relay.add(x, x), relay.add(x, x))
assert alpha_equal(z0, z1) assert alpha_equal(z0, z1)
assert alpha_equal(z0, z1)
# z3's dataflow format is different from z0 # z3's dataflow format is different from z0
# z0 is computed from a common y0 node # z0 is computed from a common y0 node
...@@ -649,6 +651,26 @@ def test_tuple_match(): ...@@ -649,6 +651,26 @@ def test_tuple_match():
assert analysis.structural_hash(x) == analysis.structural_hash(y) assert analysis.structural_hash(x) == analysis.structural_hash(y)
def test_fn_attribute():
# create function that performs add
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
add = relay.add(a, b)
add_fn = relay.Function([a, b], add)
add_fn = run_opt_pass(add_fn, relay.transform.InferType())
# create function that performs add with test attribute
c = relay.var('c', shape=(10, 10))
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.set_attribute("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
assert not relay.analysis.alpha_equal(add_fn, add_1_fn)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
...@@ -672,3 +694,4 @@ if __name__ == "__main__": ...@@ -672,3 +694,4 @@ if __name__ == "__main__":
test_var_alpha_equal() test_var_alpha_equal()
test_graph_equal() test_graph_equal()
test_hash_unequal() test_hash_unequal()
test_fn_attribute()
...@@ -35,6 +35,7 @@ def test_fuse_simple(): ...@@ -35,6 +35,7 @@ def test_fuse_simple():
z = relay.exp(y) z = relay.exp(y)
w = relay.squeeze(z) w = relay.squeeze(z)
f1 = relay.Function([x], w) f1 = relay.Function([x], w)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
...@@ -76,6 +77,8 @@ def test_conv2d_fuse(): ...@@ -76,6 +77,8 @@ def test_conv2d_fuse():
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
y = relay.add(x, relay.const(1, "float32")) y = relay.add(x, relay.const(1, "float32"))
f0 = relay.Function([x], y) f0 = relay.Function([x], y)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# segment 1 # segment 1
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -86,6 +89,8 @@ def test_conv2d_fuse(): ...@@ -86,6 +89,8 @@ def test_conv2d_fuse():
y1 = relay.add(relay.const(1, "float32"), y) y1 = relay.add(relay.const(1, "float32"), y)
y = relay.add(y, y1) y = relay.add(y, y1)
f1 = relay.Function([x, w], y) f1 = relay.Function([x, w], y)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# segment 2 # segment 2
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -94,6 +99,8 @@ def test_conv2d_fuse(): ...@@ -94,6 +99,8 @@ def test_conv2d_fuse():
padding=(1,1), padding=(1,1),
channels=16) channels=16)
f2 = relay.Function([x, w], z2) f2 = relay.Function([x, w], z2)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# segment 3 # segment 3
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -104,6 +111,8 @@ def test_conv2d_fuse(): ...@@ -104,6 +111,8 @@ def test_conv2d_fuse():
channels=16) channels=16)
z3 = relay.add(z3, offset) z3 = relay.add(z3, offset)
f3 = relay.Function([x, w, offset], z3) f3 = relay.Function([x, w, offset], z3)
f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# compose # compose
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -135,6 +144,7 @@ def test_concatenate(): ...@@ -135,6 +144,7 @@ def test_concatenate():
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled) f0 = relay.Function([x], pooled)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=dshape) p1 = relay.var("p1", shape=dshape)
...@@ -142,6 +152,7 @@ def test_concatenate(): ...@@ -142,6 +152,7 @@ def test_concatenate():
concat = relay.concatenate((upsampled, p1), axis=1) concat = relay.concatenate((upsampled, p1), axis=1)
out = relay.add(concat, relay.const(1, "float32")) out = relay.add(concat, relay.const(1, "float32"))
f1 = relay.Function([p0, p1], out) f1 = relay.Function([p0, p1], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -172,10 +183,12 @@ def test_tuple_root(): ...@@ -172,10 +183,12 @@ def test_tuple_root():
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled) f0 = relay.Function([x], pooled)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW") upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
f1 = relay.Function([p0], upsampled) f1 = relay.Function([p0], upsampled)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -205,10 +218,12 @@ def test_stop_fusion(): ...@@ -205,10 +218,12 @@ def test_stop_fusion():
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
y = relay.add(x, relay.const(1, "float32")) y = relay.add(x, relay.const(1, "float32"))
f1 = relay.Function([x], y) f1 = relay.Function([x], y)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("p01", shape=dshape) x = relay.var("p01", shape=dshape)
y = relay.exp(x) y = relay.exp(x)
f2 = relay.Function([x], y) f2 = relay.Function([x], y)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
...@@ -242,6 +257,7 @@ def test_fuse_myia_regression(): ...@@ -242,6 +257,7 @@ def test_fuse_myia_regression():
p2 = relay.var('p2', shape=dshape, dtype=dtype) p2 = relay.var('p2', shape=dshape, dtype=dtype)
fused_gt = relay.Function([p1, p2], fused_gt = relay.Function([p1, p2],
relay.op.greater(p1, p2)) relay.op.greater(p1, p2))
fused_gt = fused_gt.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
with sb.if_scope(fused_gt(x, y)): with sb.if_scope(fused_gt(x, y)):
sb.ret(relay.Function([], x)) sb.ret(relay.Function([], x))
with sb.else_scope(): with sb.else_scope():
...@@ -271,11 +287,13 @@ def test_fuse_tuple_get_elemwise(): ...@@ -271,11 +287,13 @@ def test_fuse_tuple_get_elemwise():
p1 = relay.var("p1", shape=(3 * dim, dim)) p1 = relay.var("p1", shape=(3 * dim, dim))
matmul = relay.nn.dense(p0, p1) matmul = relay.nn.dense(p0, p1)
f0 = relay.Function([p0, p1], matmul) f0 = relay.Function([p0, p1], matmul)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=(1, 3 * dim)) p01 = relay.var("p01", shape=(1, 3 * dim))
splitted = relay.split(p01, indices_or_sections=3, axis=1) splitted = relay.split(p01, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
f1 = relay.Function([p01], out) f1 = relay.Function([p01], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
X = relay.var("X", shape=(1, dim)) X = relay.var("X", shape=(1, dim))
W = relay.var("W", shape=(3 * dim, dim)) W = relay.var("W", shape=(3 * dim, dim))
...@@ -306,11 +324,13 @@ def test_tuple_get_root(): ...@@ -306,11 +324,13 @@ def test_tuple_get_root():
splitted = relay.split(p0, indices_or_sections=3, axis=1) splitted = relay.split(p0, indices_or_sections=3, axis=1)
out = splitted[0] out = splitted[0]
f0 = relay.Function([p0], out) f0 = relay.Function([p0], out)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=(1, dim)) p01 = relay.var("p01", shape=(1, dim))
p1 = relay.var("p1", shape=(dim, dim)) p1 = relay.var("p1", shape=(dim, dim))
out = relay.nn.dense(p01, p1) out = relay.nn.dense(p01, p1)
f1 = relay.Function([p01, p1], out) f1 = relay.Function([p01, p1], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
X = relay.var("X", shape=(1, 3 * dim)) X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim)) W = relay.var("W", shape=(dim, dim))
...@@ -346,8 +366,9 @@ def test_tuple_intermediate(): ...@@ -346,8 +366,9 @@ def test_tuple_intermediate():
def expected(p0): def expected(p0):
f0 = before(p0) f0 = before(p0)
f1 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
dshape = (1, 16, 64, 64) dshape = (1, 16, 64, 64)
...@@ -388,15 +409,18 @@ def test_tuple_consecutive(): ...@@ -388,15 +409,18 @@ def test_tuple_consecutive():
p0 = relay.var("p0", shape=dshape) p0 = relay.var("p0", shape=dshape)
concat = gen_consecutive_tuple(p0) concat = gen_consecutive_tuple(p0)
f0 = relay.Function([p0], concat) f0 = relay.Function([p0], concat)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3])) p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
out = relay.add(pooled, relay.const(1, "float32")) out = relay.add(pooled, relay.const(1, "float32"))
f1 = relay.Function([p01], out) f1 = relay.Function([p01], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2)) p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
out = relay.add(p02, relay.const(1, "float32")) out = relay.add(p02, relay.const(1, "float32"))
f2 = relay.Function([p02], out) f2 = relay.Function([p02], out)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -438,30 +462,36 @@ def test_inception_like(): ...@@ -438,30 +462,36 @@ def test_inception_like():
p0 = relay.var("p0", shape=dshape) p0 = relay.var("p0", shape=dshape)
c = conv(p0) c = conv(p0)
f0 = relay.Function(relay.analysis.free_vars(c), c) f0 = relay.Function(relay.analysis.free_vars(c), c)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=dshape) p01 = relay.var("p01", shape=dshape)
c = conv(p01) c = conv(p01)
f1 = relay.Function(relay.analysis.free_vars(c), c) f1 = relay.Function(relay.analysis.free_vars(c), c)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p02 = relay.var("p02", shape=dshape) p02 = relay.var("p02", shape=dshape)
p12 = relay.var("p12", shape=dshape) p12 = relay.var("p12", shape=dshape)
concat1 = relay.concatenate((p02, p12), axis=1) concat1 = relay.concatenate((p02, p12), axis=1)
f_concat1 = relay.Function([p02, p12], concat1) f_concat1 = relay.Function([p02, p12], concat1)
f_concat1 = f_concat1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3]) dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
p03 = relay.var("p03", shape=dshape2) p03 = relay.var("p03", shape=dshape2)
c = conv(p03) c = conv(p03)
f2 = relay.Function(relay.analysis.free_vars(c), c) f2 = relay.Function(relay.analysis.free_vars(c), c)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p04 = relay.var("p04", shape=dshape2) p04 = relay.var("p04", shape=dshape2)
c = conv(p04) c = conv(p04)
f3 = relay.Function(relay.analysis.free_vars(c), c) f3 = relay.Function(relay.analysis.free_vars(c), c)
f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p05 = relay.var("p05", shape=dshape) p05 = relay.var("p05", shape=dshape)
p15 = relay.var("p15", shape=dshape) p15 = relay.var("p15", shape=dshape)
concat2 = relay.concatenate((p05, p15), axis=1) concat2 = relay.concatenate((p05, p15), axis=1)
f_concat2 = relay.Function([p05, p15], concat2) f_concat2 = relay.Function([p05, p15], concat2)
f_concat2 = f_concat2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
c1 = relay.Call(f0, [x, relay.var("w1")]) c1 = relay.Call(f0, [x, relay.var("w1")])
...@@ -499,6 +529,7 @@ def test_fuse_parallel_injective(): ...@@ -499,6 +529,7 @@ def test_fuse_parallel_injective():
u = relay.transpose(y, axes=[0, 1]) u = relay.transpose(y, axes=[0, 1])
w = relay.left_shift(z, u) w = relay.left_shift(z, u)
f1 = relay.Function([x], w) f1 = relay.Function([x], w)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
...@@ -529,6 +560,7 @@ def test_immutable(): ...@@ -529,6 +560,7 @@ def test_immutable():
z = relay.exp(y) z = relay.exp(y)
w = relay.squeeze(z) w = relay.squeeze(z)
f1 = relay.Function([x], w) f1 = relay.Function([x], w)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
mod = tvm.IRModule() mod = tvm.IRModule()
...@@ -570,6 +602,7 @@ def test_fuse_max(): ...@@ -570,6 +602,7 @@ def test_fuse_max():
for i in range(max_fused_ops): for i in range(max_fused_ops):
y = relay.exp(y) y = relay.exp(y)
f1 = relay.Function([x], y) f1 = relay.Function([x], y)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
z = relay.Call(f1, [x]) z = relay.Call(f1, [x])
xx = relay.var("pp", shape=(10, 20)) xx = relay.var("pp", shape=(10, 20))
...@@ -577,6 +610,7 @@ def test_fuse_max(): ...@@ -577,6 +610,7 @@ def test_fuse_max():
for i in range(n-max_fused_ops): for i in range(n-max_fused_ops):
yy = relay.exp(yy) yy = relay.exp(yy)
f2 = relay.Function([xx], yy) f2 = relay.Function([xx], yy)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
zz = relay.Call(f2, [z]) zz = relay.Call(f2, [z])
return relay.Function([x], zz) return relay.Function([x], zz)
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Unit tests for merge composite.""" """Unit tests for merge composite."""
from tvm import expr
from tvm import relay from tvm import relay
from tvm import tir
from tvm.relay.testing import run_opt_pass from tvm.relay.testing import run_opt_pass
""" """
...@@ -144,6 +144,8 @@ def test_simple_merge(): ...@@ -144,6 +144,8 @@ def test_simple_merge():
add_node = relay.add(in_1, in_2) add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node) relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node) add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
# merged function # merged function
r = relay.Call(add_relu, [a, b]) r = relay.Call(add_relu, [a, b])
...@@ -208,11 +210,27 @@ def test_branch_merge(): ...@@ -208,11 +210,27 @@ def test_branch_merge():
sub_node = relay.subtract(in_1, in_2) sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node) mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node) add_sub_mul = relay.Function([in_1, in_2], mul_node)
add_sub_mul = add_sub_mul.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_sub_mul = add_sub_mul.set_attribute("Composite",
tir.StringImm("add_sub_mul"))
# add_sub_mul1 function
in_3 = relay.var('in_3', shape=(10, 10))
in_4 = relay.var('in_4', shape=(10, 10))
add_node_1 = relay.add(in_3, in_4)
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
add_sub_mul_1 = add_sub_mul_1.set_attribute("Primitive",
tir.IntImm("int32", 1))
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
tir.StringImm("add_sub_mul"))
# merged function # merged function
add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b])
add_sub_mul_2 = relay.Call(add_sub_mul, [c, add_sub_mul_1]) m_add_sub_mul_2 = relay.Call(add_sub_mul_1, [c, m_add_sub_mul_1])
r = relay.nn.relu(add_sub_mul_2) r = relay.nn.relu(m_add_sub_mul_2)
return relay.Function([a, b, c], r) return relay.Function([a, b, c], r)
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
...@@ -291,6 +309,9 @@ def test_multiple_patterns(): ...@@ -291,6 +309,9 @@ def test_multiple_patterns():
bias_node = relay.nn.bias_add(conv_node, in_3) bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node) r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
tir.StringImm("conv2d_bias_relu"))
# add_relu function # add_relu function
in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) in_4 = relay.var('in_4', shape=(1, 256, 28, 28))
...@@ -298,6 +319,8 @@ def test_multiple_patterns(): ...@@ -298,6 +319,8 @@ def test_multiple_patterns():
add_node = relay.add(in_4, in_5) add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node) r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r) add_relu = relay.Function([in_4, in_5], r)
add_relu = add_relu.set_attribute("Primitive", tir.IntImm("int32", 1))
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
# merged function # merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
...@@ -357,7 +380,7 @@ def test_merge_order(): ...@@ -357,7 +380,7 @@ def test_merge_order():
out = relay.nn.relu(out) out = relay.nn.relu(out)
return relay.Function([input_1, input_2], out) return relay.Function([input_1, input_2], out)
def after_A_priority(): def after_A_priority(composite_name):
input_1 = relay.var('input_1', shape=(10, 10)) input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10)) input_2 = relay.var('input_2', shape=(10, 10))
x = relay.var('x') x = relay.var('x')
...@@ -366,38 +389,12 @@ def test_merge_order(): ...@@ -366,38 +389,12 @@ def test_merge_order():
out = relay.abs(out) out = relay.abs(out)
out = relay.nn.relu(out) out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out) merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1)) merged_func = merged_func.set_attribute('Primitive', tir.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite', expr.StringImm('A')) merged_func = merged_func.set_attribute('Composite',
tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2]) ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret) return relay.Function([input_1, input_2], ret)
def after_B_priority():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
x = relay.var('x')
y = relay.var('y')
out = relay.add(x, y)
out = relay.abs(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite', expr.StringImm('B'))
merged_call = relay.Call(merged_func, [input_1, input_2])
ret = relay.nn.relu(merged_call)
return relay.Function([input_1, input_2], ret)
def after_C_priority():
input_1 = relay.var('input_1', shape=(10, 10))
input_2 = relay.var('input_2', shape=(10, 10))
add = relay.add(input_1, input_2)
x = relay.var('x')
out = relay.abs(x)
out = relay.nn.relu(out)
merged_func = relay.Function([x], out)
merged_func = merged_func.set_attribute('Primitive', expr.IntImm('int32', 1))
merged_func = merged_func.set_attribute('Composite', expr.StringImm('C'))
ret = relay.Call(merged_func, [add])
return relay.Function([input_1, input_2], ret)
# check A highest priority # check A highest priority
pattern_table = [ pattern_table = [
("A", pattern_A()), ("A", pattern_A()),
...@@ -406,7 +403,7 @@ def test_merge_order(): ...@@ -406,7 +403,7 @@ def test_merge_order():
] ]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result) assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected) assert relay.analysis.alpha_equal(result, expected)
# check B highest priority # check B highest priority
...@@ -417,7 +414,7 @@ def test_merge_order(): ...@@ -417,7 +414,7 @@ def test_merge_order():
] ]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result) assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected) assert relay.analysis.alpha_equal(result, expected)
# check C highest priority # check C highest priority
...@@ -428,7 +425,7 @@ def test_merge_order(): ...@@ -428,7 +425,7 @@ def test_merge_order():
] ]
result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result) assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority(), relay.transform.InferType()) expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType())
assert relay.analysis.alpha_equal(result, expected) assert relay.analysis.alpha_equal(result, expected)
...@@ -459,11 +456,15 @@ def test_parallel_merge(): ...@@ -459,11 +456,15 @@ def test_parallel_merge():
y = relay.var('y') y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1) func_1 = relay.Function([x, y], branch_1)
func_1 = func_1.set_attribute('Primitive', tir.IntImm('int32', 1))
func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2]) call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1') x1 = relay.var('x1')
y1 = relay.var('y1') y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2) func_2 = relay.Function([x1, y1], branch_2)
func_2 = func_2.set_attribute('Primitive', tir.IntImm('int32', 1))
func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2]) call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2) out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out) return relay.Function([input_1, input_2], out)
...@@ -542,16 +543,16 @@ def test_multiple_input_subgraphs(): ...@@ -542,16 +543,16 @@ def test_multiple_input_subgraphs():
add_relu_1 = relay.add(x, y) add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.set_attribute('Primitive', expr.IntImm('int32', 1)) add_relu_1 = add_relu_1.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_1 = add_relu_1.set_attribute('Composite', expr.StringImm('add_relu')) add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1') x1 = relay.var('x1')
y1 = relay.var('y1') y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1) add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.set_attribute('Primitive', expr.IntImm('int32', 1)) add_relu_2 = add_relu_2.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu_2 = add_relu_2.set_attribute('Composite', expr.StringImm('add_relu')) add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2') x2 = relay.var('x2')
y2 = relay.var('y2') y2 = relay.var('y2')
...@@ -559,8 +560,8 @@ def test_multiple_input_subgraphs(): ...@@ -559,8 +560,8 @@ def test_multiple_input_subgraphs():
sub = relay.subtract(x2, y2) sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul) add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.set_attribute('Primitive', expr.IntImm('int32', 1)) add_sub_mul = add_sub_mul.set_attribute('Primitive', tir.IntImm('int32', 1))
add_sub_mul = add_sub_mul.set_attribute('Composite', expr.StringImm('add_sub_mul')) add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call) return relay.Function(inputs, add_sub_mul_call)
...@@ -573,8 +574,8 @@ def test_multiple_input_subgraphs(): ...@@ -573,8 +574,8 @@ def test_multiple_input_subgraphs():
add_relu = relay.add(x, y) add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu) add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu) add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.set_attribute('Primitive', expr.IntImm('int32', 1)) add_relu = add_relu.set_attribute('Primitive', tir.IntImm('int32', 1))
add_relu = add_relu.set_attribute('Composite', expr.StringImm('add_relu')) add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call) add_relu_calls.append(add_relu_call)
...@@ -606,4 +607,4 @@ if __name__ == "__main__": ...@@ -606,4 +607,4 @@ if __name__ == "__main__":
test_multiple_patterns() test_multiple_patterns()
test_merge_order() test_merge_order()
test_parallel_merge() test_parallel_merge()
test_multiple_input_subgraphs() test_multiple_input_subgraphs()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment