Commit 43cc89bf by Haichen Shen Committed by Zhi

[Bugfix] Fix the issue that function pass modifies original module (#3712)

* fix

* fix interpreter
parent 3b287c4d
......@@ -269,7 +269,6 @@ class Interpreter(Executor):
self.mod = mod
self.ctx = ctx
self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target)
def optimize(self):
"""Optimize functions in a module.
......@@ -313,5 +312,6 @@ class Interpreter(Executor):
mod = self.optimize()
opt_expr = Call(mod["main"], relay_args)
return self._intrp(opt_expr)
_intrp = _backend.CreateInterpreter(mod, self.ctx, self.target)
return _intrp(opt_expr)
return _interp_wrapper
......@@ -314,11 +314,10 @@ Module FunctionPassNode::operator()(const Module& mod,
<< " with opt level: "
<< pass_info->opt_level;
Module updated_mod = mod;
// Execute the pass function and return a new module.
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions);
std::vector<std::pair<GlobalVar, Function> > updates;
auto original = mod->functions;
for (const auto& it : original) {
for (const auto& it : updated_mod->functions) {
auto updated_func = SkipFunction(it.second)
? it.second
: pass_func(it.second, updated_mod, pass_ctx);
......
......@@ -512,6 +512,35 @@ def test_fuse_parallel_injective():
assert relay.analysis.alpha_equal(zz, after)
def test_immutable():
"""Verify the fusion pass won't change original module."""
def before():
x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
w = relay.squeeze(z)
mod = relay.module.Module()
mod["main"] = relay.Function([x], w)
return mod
def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
w = relay.squeeze(z)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
return mod
mod = before()
new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
assert relay.analysis.alpha_equal(mod, before())
assert relay.analysis.alpha_equal(new_mod, expected())
if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
......@@ -525,3 +554,4 @@ if __name__ == "__main__":
test_tuple_consecutive()
test_inception_like()
test_fuse_parallel_injective()
test_immutable()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment