Commit bca7914a by Haichen Shen, committed by Yao Wang

[Relay][Fix] Fix alter op layout when calling a global var (#4454)

* [Relay][Fix] Fix alter op layout when calling a global var

* add test case
parent 0531a3e4
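
Editor's context, not part of the original commit message: before this change, the CallInfer helper in the AlterOpLayout pass unconditionally downcast call->op to Op, so a CallNode whose target is a GlobalVar (or anything else that is not an OpNode) tripped the Downcast check and aborted the pass. A minimal sketch of the triggering pattern, distilled from the test case added below; it uses the relay.Module API of this TVM version (later renamed tvm.IRModule):

from tvm import relay

# Build a module whose "main" calls another function through a GlobalVar,
# so the CallNode's target is a GlobalVarNode rather than an OpNode.
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
y = relay.nn.relu(relay.nn.conv2d(x, weight, channels=64,
                                  kernel_size=(3, 3), padding=(1, 1)))

mod = relay.Module()
foo = relay.GlobalVar("foo")
mod[foo] = relay.Function([x, weight], y)
mod["main"] = relay.Function([x, weight], foo(x, weight))

# Before the fix this crashed inside CallInfer; with the guard added
# below, the non-Op callee is skipped and the pass completes.
mod = relay.transform.AlterOpLayout()(mod)
print(mod)
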
...
@@ -161,6 +161,9 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
     const Array<Layout>& old_in_layouts,
     const Array<Array<IndexExpr> > &old_in_shapes) {
   static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
+  if (!call->op.as<OpNode>()) {
+    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
+  }
   Op op = Downcast<Op>(call->op);
   if (finfer_layout.count(op)) {
...
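
Editor's note on the guard above: when the call target is not an OpNode, CallInfer now returns two null Array<Layout> handles together with false. As far as the surrounding pass logic is concerned, that means no layout information could be inferred for this call, so the call to the GlobalVar is rebuilt as-is instead of being wrapped in layout_transform ops, while functions reachable through it are still rewritten when the pass visits them.
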
...
@@ -931,6 +931,47 @@ def test_alter_layout_nhwc_nchw_arm():
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+
+def test_alter_op_with_global_var():
+    """Test alter op layout when calling a global var"""
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        mod = relay.Module()
+        foo = relay.GlobalVar('foo')
+        mod[foo] = relay.Function([x, weight], y)
+        mod["main"] = relay.Function([x, weight], foo(x, weight))
+        return mod
+
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        weight = relay.multiply(weight, relay.const(2.0, "float32"))
+        return relay.nn.conv2d(data, weight, **attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, relay.multiply(weight, relay.const(2.0, "float32")),
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        mod = relay.Module()
+        foo = relay.GlobalVar('foo')
+        mod[foo] = relay.Function([x, weight], y)
+        mod["main"] = relay.Function([x, weight], foo(x, weight))
+        return mod
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = before()
+        a = transform.AlterOpLayout()(a)
+        b = transform.InferType()(expected())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_alter_op()
...
@@ -949,3 +990,4 @@ if __name__ == "__main__":
     test_alter_layout_pool()
     test_alter_layout_sum()
     test_alter_layout_nhwc_nchw_arm()
+    test_alter_op_with_global_var()
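
Editor's usage note: to exercise just the new test from a TVM checkout, something like the following should work (the file path is an inference from the test names, which match tests/python/relay/test_pass_alter_op_layout.py):

python -m pytest tests/python/relay/test_pass_alter_op_layout.py -k test_alter_op_with_global_var
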