Unverified Commit ea063888 by Zhi Committed by GitHub

[BUGFIX][IR] Fix String SEqual (#5275)

* fix String SEqual

* retrigger ci
parent d430d528
...@@ -43,8 +43,8 @@ struct StringObjTrait { ...@@ -43,8 +43,8 @@ struct StringObjTrait {
SEqualReducer equal) { SEqualReducer equal) {
if (lhs == rhs) return true; if (lhs == rhs) return true;
if (lhs->size != rhs->size) return false; if (lhs->size != rhs->size) return false;
if (lhs->data != rhs->data) return true; if (lhs->data == rhs->data) return true;
return std::memcmp(lhs->data, rhs->data, lhs->size) != 0; return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
} }
}; };
......
...@@ -356,7 +356,7 @@ def test_function_attr(): ...@@ -356,7 +356,7 @@ def test_function_attr():
p00 = relay.subtract(z00, w01) p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02) q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00) func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.with_attr("FuncName", tvm.tir.StringImm("a")) func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a"))
x1 = relay.var('x1', shape=(10, 10)) x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10))
...@@ -366,7 +366,7 @@ def test_function_attr(): ...@@ -366,7 +366,7 @@ def test_function_attr():
p10 = relay.subtract(z10, w11) p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12) q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10) func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b")) func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b"))
assert not consistent_equal(func0, func1) assert not consistent_equal(func0, func1)
...@@ -698,7 +698,7 @@ def test_fn_attribute(): ...@@ -698,7 +698,7 @@ def test_fn_attribute():
d = relay.var('d', shape=(10, 10)) d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d) add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1) add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test")) add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not consistent_equal(add_1_fn, add_fn) assert not consistent_equal(add_1_fn, add_fn)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment