Unverified Commit 07fbe5c8 by Tianqi Chen Committed by GitHub

[RELAY][PASS] Enable decorating python class as Pass (#3364)

parent 133bb250
Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
...@@ -101,6 +101,7 @@ const = expr.const ...@@ -101,6 +101,7 @@ const = expr.const
bind = expr.bind bind = expr.bind
module_pass = transform.module_pass module_pass = transform.module_pass
function_pass = transform.function_pass function_pass = transform.function_pass
alpha_equal = ir_pass.alpha_equal
# ExprFunctor # ExprFunctor
ExprFunctor = expr_functor.ExprFunctor ExprFunctor = expr_functor.ExprFunctor
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
Relay pass transformation infrastructure. Relay pass transformation infrastructure.
""" """
import types import types
import inspect
import functools
from tvm._ffi.runtime_ctypes import TVMContext from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform from . import _transform
...@@ -444,16 +446,47 @@ def PartialEvaluate(): ...@@ -444,16 +446,47 @@ def PartialEvaluate():
return _transform.PartialEvaluate() return _transform.PartialEvaluate()
def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(mod, ctx):
return inst.transform_module(mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeModulePass, _pass_func, pass_info)
self._inst = inst
def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)
functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
PyModulePass.__name__ = pass_cls.__name__
PyModulePass.__doc__ = pass_cls.__doc__
PyModulePass.__module__ = pass_cls.__module__
return PyModulePass
def module_pass(pass_func=None, opt_level=None, name=None, required=None): def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func """Decorate a module pass.
is provided. Otherwise, it returns the created module level pass using the
given optimization function. This function returns a callback when pass_func is provided.
Otherwise, it serves a decorator function.
pass_func can also be a class type with a method transform_module.
This function will create a decorated ModulePass using transform_module
as the pass function.
Parameters Parameters
---------- ----------
pass_func : Optional[Callable[(Module/Function, PassContext) -> pass_func : Optional[Callable[(Module, PassContext) ->Module]]
Module/Function]] The transformation function or class.
The implemented optimization pass.
opt_level : int opt_level : int
The optimization level of this module pass. The optimization level of this module pass.
...@@ -468,14 +501,39 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -468,14 +501,39 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
Returns Returns
------- -------
create_module_pass : Union[Callable, ModulePass] create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when A decorator will be returned if pass_func is not provided,
pass_func is not passed in. Otherwise, a ModulePass object will be otherwise return the decorated result.
directly created. The returned decorator has two behaviors depending on the input:
A new ModulePass will be returned when we decorate a pass function.
A new ModulePass class will be returned when we decorate a class type.
Examples Examples
-------- --------
The following code creates a module level pass and adds an abs function to The following code block decorates a module pass class.
the module.
.. code-block:: python
@relay.transform.module_pass
class CustomPipeline:
def __init__(self, enable_fold):
self.enable_fold = enable_fold
self.cse = relay.transform.EliminateCommonSubexpr()
self.const_fold = relay.transform.FoldConstant()
def transform_module(self, mod, ctx):
mod = self.cse(mod, ctx)
if self.enable_fold:
mod = self.const_fold(mod, ctx)
return mod
# create an instance of customized pipeline
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, transform.ModulePass)
# run the pipeline.
output_module = pipeline(input_module)
The following code creates a module pass by decorating
a user defined transform function.
.. code-block:: python .. code-block:: python
...@@ -497,7 +555,6 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -497,7 +555,6 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
updated_mod = module_pass(m) updated_mod = module_pass(m)
# Now a function abs should be added to the module m. # Now a function abs should be added to the module m.
""" """
if opt_level is None: if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.") raise ValueError("Please provide opt_level for the module pass.")
...@@ -506,30 +563,59 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -506,30 +563,59 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
raise TypeError("Required is expected to be the type of " + raise TypeError("Required is expected to be the type of " +
"list/tuple.") "list/tuple.")
def create_module_pass(pass_func): def create_module_pass(pass_arg):
"""Internal function that creates a module pass""" """Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): fname = name if name else pass_arg.__name__
raise TypeError("pass_func must be a callable for Module pass")
fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required) info = PassInfo(opt_level, fname, required)
return _transform.MakeModulePass(pass_func, info) if inspect.isclass(pass_arg):
return _wrap_class_module_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeModulePass(pass_arg, info)
if pass_func: if pass_func:
return create_module_pass(pass_func) return create_module_pass(pass_func)
return create_module_pass return create_module_pass
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyFunctionPass(FunctionPass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeFunctionPass, _pass_func, pass_info)
self._inst = inst
def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)
functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
PyFunctionPass.__name__ = pass_cls.__name__
PyFunctionPass.__doc__ = pass_cls.__doc__
PyFunctionPass.__module__ = pass_cls.__module__
return PyFunctionPass
def function_pass(pass_func=None, opt_level=None, name=None, required=None): def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func """Decorate a function pass.
This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the is provided. Otherwise, it returns the created function pass using the
given optimization function. given optimization function.
Parameters Parameters
---------- ----------
pass_func : Optional[Callable[(Module/Function, PassContext) -> pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
Module/Function]] The transformation function or class.
The implemented optimization pass.
opt_level : int opt_level : int
The optimization level of this module pass. The optimization level of this module pass.
...@@ -544,20 +630,48 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -544,20 +630,48 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
Returns Returns
------- -------
create_function_pass : Union[Callable, FunctionPass] create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be A decorator will be returned if pass_func is not provided,
created. otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new FunctionPass will be returned when we decorate a pass function.
A new FunctionPass class will be returned when we decorate a class type.
Examples Examples
-------- --------
The following code creates a function level pass that performs constant The following code block decorates a function pass class.
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# just for demo purposes
# transform func to new_func
return self.new_func
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
The following code creates a function pass by decorating
a user defined transform function.
.. code-block:: python .. code-block:: python
@relay.transform.function_pass(opt_level=2) @relay.transform.function_pass(opt_level=2)
def transform(func, ctx): def transform(func, mod, ctx):
return ir_pass.fold_constant(func) # my transformations here.
return func
function_pass = transform function_pass = transform
assert isinstance(function_pass, transform.FunctionPass) assert isinstance(function_pass, transform.FunctionPass)
...@@ -577,14 +691,15 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -577,14 +691,15 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
raise TypeError("Required is expected to be the type of " + raise TypeError("Required is expected to be the type of " +
"list/tuple.") "list/tuple.")
def create_function_pass(pass_func): def create_function_pass(pass_arg):
"""Internal function that creates a function pass""" """Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): fname = name if name else pass_arg.__name__
raise TypeError("pass_func must be a callable for Module pass")
fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required) info = PassInfo(opt_level, fname, required)
return _transform.MakeFunctionPass(pass_func, info) if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeFunctionPass(pass_arg, info)
if pass_func: if pass_func:
return create_function_pass(pass_func) return create_function_pass(pass_func)
......
...@@ -189,6 +189,29 @@ def test_module_pass(): ...@@ -189,6 +189,29 @@ def test_module_pass():
test_pass_run() test_pass_run()
def test_function_class_pass():
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
"""Simple test function to replace one argument to another."""
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
return self.new_func
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
fpass = TestReplaceFunc(f1)
assert fpass.info.opt_level == 1
assert fpass.info.name == "TestReplaceFunc"
mod = relay.Module.from_expr(f2)
mod = fpass(mod)
# wrap in expr
mod2 = relay.Module.from_expr(f1)
assert relay.alpha_equal(mod["main"], mod2["main"])
def test_function_pass(): def test_function_pass():
shape = (10, ) shape = (10, )
dtype = 'float32' dtype = 'float32'
...@@ -259,6 +282,30 @@ def test_function_pass(): ...@@ -259,6 +282,30 @@ def test_function_pass():
test_pass_run() test_pass_run()
def test_module_class_pass():
@relay.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
self.new_mod = new_mod
self.replace = replace
def transform_module(self, mod, ctx):
if self.replace:
return self.new_mod
return mod
x = relay.var("x", shape=(10, 20))
m1 = relay.Module.from_expr(relay.Function([x], x))
m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
fpass = TestPipeline(m2, replace=True)
assert fpass.info.name == "TestPipeline"
mod3 = fpass(m1)
assert mod3.same_as(m2)
mod4 = TestPipeline(m2, replace=False)(m1)
assert mod4.same_as(m1)
def test_pass_info(): def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz") info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1 assert info.opt_level == 1
...@@ -451,6 +498,8 @@ def test_sequential_with_scoping(): ...@@ -451,6 +498,8 @@ def test_sequential_with_scoping():
if __name__ == "__main__": if __name__ == "__main__":
test_function_class_pass()
test_module_class_pass()
test_module_pass() test_module_pass()
test_function_pass() test_function_pass()
test_sequential_pass() test_sequential_pass()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment