Unverified Commit 07fbe5c8 by Tianqi Chen Committed by GitHub

[RELAY][PASS] Enable decorating python class as Pass (#3364)

parent 133bb250
Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
......@@ -101,6 +101,7 @@ const = expr.const
bind = expr.bind
module_pass = transform.module_pass
function_pass = transform.function_pass
alpha_equal = ir_pass.alpha_equal
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
......
......@@ -189,6 +189,29 @@ def test_module_pass():
test_pass_run()
def test_function_class_pass():
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
"""Simple test function to replace one argument to another."""
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
return self.new_func
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
fpass = TestReplaceFunc(f1)
assert fpass.info.opt_level == 1
assert fpass.info.name == "TestReplaceFunc"
mod = relay.Module.from_expr(f2)
mod = fpass(mod)
# wrap in expr
mod2 = relay.Module.from_expr(f1)
assert relay.alpha_equal(mod["main"], mod2["main"])
def test_function_pass():
shape = (10, )
dtype = 'float32'
......@@ -259,6 +282,30 @@ def test_function_pass():
test_pass_run()
def test_module_class_pass():
@relay.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
self.new_mod = new_mod
self.replace = replace
def transform_module(self, mod, ctx):
if self.replace:
return self.new_mod
return mod
x = relay.var("x", shape=(10, 20))
m1 = relay.Module.from_expr(relay.Function([x], x))
m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
fpass = TestPipeline(m2, replace=True)
assert fpass.info.name == "TestPipeline"
mod3 = fpass(m1)
assert mod3.same_as(m2)
mod4 = TestPipeline(m2, replace=False)(m1)
assert mod4.same_as(m1)
def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
......@@ -451,6 +498,8 @@ def test_sequential_with_scoping():
if __name__ == "__main__":
test_function_class_pass()
test_module_class_pass()
test_module_pass()
test_function_pass()
test_sequential_pass()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment