Commit 43cc89bf by Haichen Shen Committed by Zhi

[Bugfix] Fix the issue that function pass modifies original module (#3712)

* fix

* fix interpreter
parent 3b287c4d
...@@ -269,7 +269,6 @@ class Interpreter(Executor): ...@@ -269,7 +269,6 @@ class Interpreter(Executor):
self.mod = mod self.mod = mod
self.ctx = ctx self.ctx = ctx
self.target = target self.target = target
self._intrp = _backend.CreateInterpreter(mod, ctx, target)
def optimize(self): def optimize(self):
"""Optimize functions in a module. """Optimize functions in a module.
...@@ -313,5 +312,6 @@ class Interpreter(Executor): ...@@ -313,5 +312,6 @@ class Interpreter(Executor):
mod = self.optimize() mod = self.optimize()
opt_expr = Call(mod["main"], relay_args) opt_expr = Call(mod["main"], relay_args)
return self._intrp(opt_expr) _intrp = _backend.CreateInterpreter(mod, self.ctx, self.target)
return _intrp(opt_expr)
return _interp_wrapper return _interp_wrapper
...@@ -314,11 +314,10 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -314,11 +314,10 @@ Module FunctionPassNode::operator()(const Module& mod,
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
Module updated_mod = mod;
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions);
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
auto original = mod->functions; for (const auto& it : updated_mod->functions) {
for (const auto& it : original) {
auto updated_func = SkipFunction(it.second) auto updated_func = SkipFunction(it.second)
? it.second ? it.second
: pass_func(it.second, updated_mod, pass_ctx); : pass_func(it.second, updated_mod, pass_ctx);
......
...@@ -512,6 +512,35 @@ def test_fuse_parallel_injective(): ...@@ -512,6 +512,35 @@ def test_fuse_parallel_injective():
assert relay.analysis.alpha_equal(zz, after) assert relay.analysis.alpha_equal(zz, after)
def test_immutable():
"""Verify the fusion pass won't change original module."""
def before():
x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
w = relay.squeeze(z)
mod = relay.module.Module()
mod["main"] = relay.Function([x], w)
return mod
def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
w = relay.squeeze(z)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
mod = relay.module.Module()
mod["main"] = relay.Function([x], y)
return mod
mod = before()
new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
assert relay.analysis.alpha_equal(mod, before())
assert relay.analysis.alpha_equal(new_mod, expected())
if __name__ == "__main__": if __name__ == "__main__":
test_fuse_simple() test_fuse_simple()
test_conv2d_fuse() test_conv2d_fuse()
...@@ -525,3 +554,4 @@ if __name__ == "__main__": ...@@ -525,3 +554,4 @@ if __name__ == "__main__":
test_tuple_consecutive() test_tuple_consecutive()
test_inception_like() test_inception_like()
test_fuse_parallel_injective() test_fuse_parallel_injective()
test_immutable()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment