Unverified Commit 95de08ba by Zhi Committed by GitHub

Fix alpha_equal bug (#4897)

parent e7be8bf4
...@@ -92,7 +92,7 @@ class AlphaEqualHandler: ...@@ -92,7 +92,7 @@ class AlphaEqualHandler:
auto compute = [&]() { auto compute = [&]() {
if (&lhs == &rhs) return true; if (&lhs == &rhs) return true;
if (auto lhsd = lhs.as<DictAttrsNode>()) { if (auto lhsd = lhs.as<DictAttrsNode>()) {
auto rhsd = lhs.as<DictAttrsNode>(); auto rhsd = rhs.as<DictAttrsNode>();
if (!rhsd) return false; if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false; if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) { for (const auto& k : lhsd->dict) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
""" test ir""" """ test ir"""
import pytest
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.tir.expr import * from tvm.tir.expr import *
...@@ -174,6 +175,7 @@ def test_function(): ...@@ -174,6 +175,7 @@ def test_function():
str(fn) str(fn)
check_json_roundtrip(fn) check_json_roundtrip(fn)
@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs(): def test_function_attrs():
param_names = ['a', 'b', 'c', 'd'] param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names]) params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
......
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import analysis from tvm.relay import analysis
from tvm.relay.testing import run_opt_pass
def alpha_equal(x, y): def alpha_equal(x, y):
""" """
...@@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal(): ...@@ -313,7 +314,7 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_multi_node_subgraph(): def test_function_attr():
x0 = relay.var('x0', shape=(10, 10)) x0 = relay.var('x0', shape=(10, 10))
w00 = relay.var('w00', shape=(10, 10)) w00 = relay.var('w00', shape=(10, 10))
w01 = relay.var('w01', shape=(10, 10)) w01 = relay.var('w01', shape=(10, 10))
...@@ -608,6 +609,7 @@ def test_graph_equal(): ...@@ -608,6 +609,7 @@ def test_graph_equal():
z3 = relay.add(relay.add(x, x), relay.add(x, x)) z3 = relay.add(relay.add(x, x), relay.add(x, x))
assert alpha_equal(z0, z1) assert alpha_equal(z0, z1)
assert alpha_equal(z0, z1)
# z3's dataflow format is different from z0 # z3's dataflow format is different from z0
# z0 is computed from a common y0 node # z0 is computed from a common y0 node
...@@ -649,6 +651,26 @@ def test_tuple_match(): ...@@ -649,6 +651,26 @@ def test_tuple_match():
assert analysis.structural_hash(x) == analysis.structural_hash(y) assert analysis.structural_hash(x) == analysis.structural_hash(y)
def test_fn_attribute():
# create function that performs add
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))
add = relay.add(a, b)
add_fn = relay.Function([a, b], add)
add_fn = run_opt_pass(add_fn, relay.transform.InferType())
# create function that performs add with test attribute
c = relay.var('c', shape=(10, 10))
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.set_attribute("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
assert not relay.analysis.alpha_equal(add_fn, add_1_fn)
if __name__ == "__main__": if __name__ == "__main__":
test_tensor_type_alpha_equal() test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal() test_incomplete_type_alpha_equal()
...@@ -672,3 +694,4 @@ if __name__ == "__main__": ...@@ -672,3 +694,4 @@ if __name__ == "__main__":
test_var_alpha_equal() test_var_alpha_equal()
test_graph_equal() test_graph_equal()
test_hash_unequal() test_hash_unequal()
test_fn_attribute()
...@@ -35,6 +35,7 @@ def test_fuse_simple(): ...@@ -35,6 +35,7 @@ def test_fuse_simple():
z = relay.exp(y) z = relay.exp(y)
w = relay.squeeze(z) w = relay.squeeze(z)
f1 = relay.Function([x], w) f1 = relay.Function([x], w)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
...@@ -76,6 +77,8 @@ def test_conv2d_fuse(): ...@@ -76,6 +77,8 @@ def test_conv2d_fuse():
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
y = relay.add(x, relay.const(1, "float32")) y = relay.add(x, relay.const(1, "float32"))
f0 = relay.Function([x], y) f0 = relay.Function([x], y)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# segment 1 # segment 1
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -86,6 +89,8 @@ def test_conv2d_fuse(): ...@@ -86,6 +89,8 @@ def test_conv2d_fuse():
y1 = relay.add(relay.const(1, "float32"), y) y1 = relay.add(relay.const(1, "float32"), y)
y = relay.add(y, y1) y = relay.add(y, y1)
f1 = relay.Function([x, w], y) f1 = relay.Function([x, w], y)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# segment 2 # segment 2
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -94,6 +99,8 @@ def test_conv2d_fuse(): ...@@ -94,6 +99,8 @@ def test_conv2d_fuse():
padding=(1,1), padding=(1,1),
channels=16) channels=16)
f2 = relay.Function([x, w], z2) f2 = relay.Function([x, w], z2)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# segment 3 # segment 3
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
w = relay.var("p1") w = relay.var("p1")
...@@ -104,6 +111,8 @@ def test_conv2d_fuse(): ...@@ -104,6 +111,8 @@ def test_conv2d_fuse():
channels=16) channels=16)
z3 = relay.add(z3, offset) z3 = relay.add(z3, offset)
f3 = relay.Function([x, w, offset], z3) f3 = relay.Function([x, w, offset], z3)
f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
# compose # compose
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -135,6 +144,7 @@ def test_concatenate(): ...@@ -135,6 +144,7 @@ def test_concatenate():
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled) f0 = relay.Function([x], pooled)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
p1 = relay.var("p1", shape=dshape) p1 = relay.var("p1", shape=dshape)
...@@ -142,6 +152,7 @@ def test_concatenate(): ...@@ -142,6 +152,7 @@ def test_concatenate():
concat = relay.concatenate((upsampled, p1), axis=1) concat = relay.concatenate((upsampled, p1), axis=1)
out = relay.add(concat, relay.const(1, "float32")) out = relay.add(concat, relay.const(1, "float32"))
f1 = relay.Function([p0, p1], out) f1 = relay.Function([p0, p1], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -172,10 +183,12 @@ def test_tuple_root(): ...@@ -172,10 +183,12 @@ def test_tuple_root():
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
f0 = relay.Function([x], pooled) f0 = relay.Function([x], pooled)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2)) p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2]//2, dshape[3]//2))
upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW") upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
f1 = relay.Function([p0], upsampled) f1 = relay.Function([p0], upsampled)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -205,10 +218,12 @@ def test_stop_fusion(): ...@@ -205,10 +218,12 @@ def test_stop_fusion():
x = relay.var("p0", shape=dshape) x = relay.var("p0", shape=dshape)
y = relay.add(x, relay.const(1, "float32")) y = relay.add(x, relay.const(1, "float32"))
f1 = relay.Function([x], y) f1 = relay.Function([x], y)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("p01", shape=dshape) x = relay.var("p01", shape=dshape)
y = relay.exp(x) y = relay.exp(x)
f2 = relay.Function([x], y) f2 = relay.Function([x], y)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
...@@ -242,6 +257,7 @@ def test_fuse_myia_regression(): ...@@ -242,6 +257,7 @@ def test_fuse_myia_regression():
p2 = relay.var('p2', shape=dshape, dtype=dtype) p2 = relay.var('p2', shape=dshape, dtype=dtype)
fused_gt = relay.Function([p1, p2], fused_gt = relay.Function([p1, p2],
relay.op.greater(p1, p2)) relay.op.greater(p1, p2))
fused_gt = fused_gt.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
with sb.if_scope(fused_gt(x, y)): with sb.if_scope(fused_gt(x, y)):
sb.ret(relay.Function([], x)) sb.ret(relay.Function([], x))
with sb.else_scope(): with sb.else_scope():
...@@ -271,11 +287,13 @@ def test_fuse_tuple_get_elemwise(): ...@@ -271,11 +287,13 @@ def test_fuse_tuple_get_elemwise():
p1 = relay.var("p1", shape=(3 * dim, dim)) p1 = relay.var("p1", shape=(3 * dim, dim))
matmul = relay.nn.dense(p0, p1) matmul = relay.nn.dense(p0, p1)
f0 = relay.Function([p0, p1], matmul) f0 = relay.Function([p0, p1], matmul)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=(1, 3 * dim)) p01 = relay.var("p01", shape=(1, 3 * dim))
splitted = relay.split(p01, indices_or_sections=3, axis=1) splitted = relay.split(p01, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2]) out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
f1 = relay.Function([p01], out) f1 = relay.Function([p01], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
X = relay.var("X", shape=(1, dim)) X = relay.var("X", shape=(1, dim))
W = relay.var("W", shape=(3 * dim, dim)) W = relay.var("W", shape=(3 * dim, dim))
...@@ -306,11 +324,13 @@ def test_tuple_get_root(): ...@@ -306,11 +324,13 @@ def test_tuple_get_root():
splitted = relay.split(p0, indices_or_sections=3, axis=1) splitted = relay.split(p0, indices_or_sections=3, axis=1)
out = splitted[0] out = splitted[0]
f0 = relay.Function([p0], out) f0 = relay.Function([p0], out)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=(1, dim)) p01 = relay.var("p01", shape=(1, dim))
p1 = relay.var("p1", shape=(dim, dim)) p1 = relay.var("p1", shape=(dim, dim))
out = relay.nn.dense(p01, p1) out = relay.nn.dense(p01, p1)
f1 = relay.Function([p01, p1], out) f1 = relay.Function([p01, p1], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
X = relay.var("X", shape=(1, 3 * dim)) X = relay.var("X", shape=(1, 3 * dim))
W = relay.var("W", shape=(dim, dim)) W = relay.var("W", shape=(dim, dim))
...@@ -346,8 +366,9 @@ def test_tuple_intermediate(): ...@@ -346,8 +366,9 @@ def test_tuple_intermediate():
def expected(p0): def expected(p0):
f0 = before(p0) f0 = before(p0)
f1 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
dshape = (1, 16, 64, 64) dshape = (1, 16, 64, 64)
...@@ -388,15 +409,18 @@ def test_tuple_consecutive(): ...@@ -388,15 +409,18 @@ def test_tuple_consecutive():
p0 = relay.var("p0", shape=dshape) p0 = relay.var("p0", shape=dshape)
concat = gen_consecutive_tuple(p0) concat = gen_consecutive_tuple(p0)
f0 = relay.Function([p0], concat) f0 = relay.Function([p0], concat)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3])) p01 = relay.var("p01", shape=(1, dshape[1]*9, dshape[2], dshape[3]))
pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) pooled = relay.nn.max_pool2d(p01, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
out = relay.add(pooled, relay.const(1, "float32")) out = relay.add(pooled, relay.const(1, "float32"))
f1 = relay.Function([p01], out) f1 = relay.Function([p01], out)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2)) p02 = relay.var("p02", shape=(1, dshape[1]*9, dshape[2]//2, dshape[3]//2))
out = relay.add(p02, relay.const(1, "float32")) out = relay.add(p02, relay.const(1, "float32"))
f2 = relay.Function([p02], out) f2 = relay.Function([p02], out)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
y = relay.Call(f0, [x]) y = relay.Call(f0, [x])
...@@ -438,30 +462,36 @@ def test_inception_like(): ...@@ -438,30 +462,36 @@ def test_inception_like():
p0 = relay.var("p0", shape=dshape) p0 = relay.var("p0", shape=dshape)
c = conv(p0) c = conv(p0)
f0 = relay.Function(relay.analysis.free_vars(c), c) f0 = relay.Function(relay.analysis.free_vars(c), c)
f0 = f0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p01 = relay.var("p01", shape=dshape) p01 = relay.var("p01", shape=dshape)
c = conv(p01) c = conv(p01)
f1 = relay.Function(relay.analysis.free_vars(c), c) f1 = relay.Function(relay.analysis.free_vars(c), c)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p02 = relay.var("p02", shape=dshape) p02 = relay.var("p02", shape=dshape)
p12 = relay.var("p12", shape=dshape) p12 = relay.var("p12", shape=dshape)
concat1 = relay.concatenate((p02, p12), axis=1) concat1 = relay.concatenate((p02, p12), axis=1)
f_concat1 = relay.Function([p02, p12], concat1) f_concat1 = relay.Function([p02, p12], concat1)
f_concat1 = f_concat1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3]) dshape2 = (dshape[0], dshape[1]*2, dshape[2], dshape[3])
p03 = relay.var("p03", shape=dshape2) p03 = relay.var("p03", shape=dshape2)
c = conv(p03) c = conv(p03)
f2 = relay.Function(relay.analysis.free_vars(c), c) f2 = relay.Function(relay.analysis.free_vars(c), c)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p04 = relay.var("p04", shape=dshape2) p04 = relay.var("p04", shape=dshape2)
c = conv(p04) c = conv(p04)
f3 = relay.Function(relay.analysis.free_vars(c), c) f3 = relay.Function(relay.analysis.free_vars(c), c)
f3 = f3.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
p05 = relay.var("p05", shape=dshape) p05 = relay.var("p05", shape=dshape)
p15 = relay.var("p15", shape=dshape) p15 = relay.var("p15", shape=dshape)
concat2 = relay.concatenate((p05, p15), axis=1) concat2 = relay.concatenate((p05, p15), axis=1)
f_concat2 = relay.Function([p05, p15], concat2) f_concat2 = relay.Function([p05, p15], concat2)
f_concat2 = f_concat2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=dshape) x = relay.var("x", shape=dshape)
c1 = relay.Call(f0, [x, relay.var("w1")]) c1 = relay.Call(f0, [x, relay.var("w1")])
...@@ -499,6 +529,7 @@ def test_fuse_parallel_injective(): ...@@ -499,6 +529,7 @@ def test_fuse_parallel_injective():
u = relay.transpose(y, axes=[0, 1]) u = relay.transpose(y, axes=[0, 1])
w = relay.left_shift(z, u) w = relay.left_shift(z, u)
f1 = relay.Function([x], w) f1 = relay.Function([x], w)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
return relay.Function([x], y) return relay.Function([x], y)
...@@ -529,6 +560,7 @@ def test_immutable(): ...@@ -529,6 +560,7 @@ def test_immutable():
z = relay.exp(y) z = relay.exp(y)
w = relay.squeeze(z) w = relay.squeeze(z)
f1 = relay.Function([x], w) f1 = relay.Function([x], w)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x]) y = relay.Call(f1, [x])
mod = tvm.IRModule() mod = tvm.IRModule()
...@@ -570,6 +602,7 @@ def test_fuse_max(): ...@@ -570,6 +602,7 @@ def test_fuse_max():
for i in range(max_fused_ops): for i in range(max_fused_ops):
y = relay.exp(y) y = relay.exp(y)
f1 = relay.Function([x], y) f1 = relay.Function([x], y)
f1 = f1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
x = relay.var("x", shape=(10, 20)) x = relay.var("x", shape=(10, 20))
z = relay.Call(f1, [x]) z = relay.Call(f1, [x])
xx = relay.var("pp", shape=(10, 20)) xx = relay.var("pp", shape=(10, 20))
...@@ -577,6 +610,7 @@ def test_fuse_max(): ...@@ -577,6 +610,7 @@ def test_fuse_max():
for i in range(n-max_fused_ops): for i in range(n-max_fused_ops):
yy = relay.exp(yy) yy = relay.exp(yy)
f2 = relay.Function([xx], yy) f2 = relay.Function([xx], yy)
f2 = f2.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
zz = relay.Call(f2, [z]) zz = relay.Call(f2, [z])
return relay.Function([x], zz) return relay.Function([x], zz)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment