Commit 5fe5ceee by Zhi Committed by Tianqi Chen

Check function attr for alpha equal (#4479)

parent 03a59bc9
......@@ -267,6 +267,8 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(func->ret_type));
hash = Combine(hash, ExprHash(func->body));
hash = Combine(hash, AttrHash(func->attrs));
return hash;
}
......
......@@ -313,6 +313,29 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_multi_node_subgraph():
x0 = relay.var('x0', shape=(10, 10))
w00 = relay.var('w00', shape=(10, 10))
w01 = relay.var('w01', shape=(10, 10))
w02 = relay.var('w02', shape=(10, 10))
z00 = relay.add(x0, w00)
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))
x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
w11 = relay.var('w11', shape=(10, 10))
w12 = relay.var('w12', shape=(10, 10))
z10 = relay.add(x1, w10)
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
assert not alpha_equal(func0, func1)
def test_function_alpha_equal():
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
......@@ -639,6 +662,7 @@ if __name__ == "__main__":
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal()
test_function_alpha_equal()
test_function_attr()
test_call_alpha_equal()
test_let_alpha_equal()
test_if_alpha_equal()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment