Unverified Commit c9a2f3da by Tianqi Chen Committed by GitHub

[RELAY] Pass infra cleanup (#3336)

parent d6c4aba8
...@@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode { ...@@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
v->Visit("required", &required); v->Visit("required", &required);
} }
TVM_DLL static PassInfo make(int opt_level, std::string name, TVM_DLL static PassInfo make(int opt_level,
std::string name,
tvm::Array<tvm::Expr> required); tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo"; static constexpr const char* _type_key = "relay.PassInfo";
...@@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference(); ...@@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
* type information filled in, as well as it's checked type field * type information filled in, as well as it's checked type field
* populated with the result type. * populated with the result type.
* *
* \return The pass. * \return The pass.
*/ */
TVM_DLL Pass InferType(); TVM_DLL Pass InferType();
......
...@@ -14,13 +14,9 @@ ...@@ -14,13 +14,9 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
# pylint: disable=invalid-name # pylint: disable=invalid-name
""" """
This file contains the pass manager for Relay which exposes different Relay pass transformation infrastructure.
granularity of interfaces for users to implement and use passes more
conveniently.
""" """
import types import types
...@@ -39,19 +35,19 @@ class PassInfo(RelayNode): ...@@ -39,19 +35,19 @@ class PassInfo(RelayNode):
Parameters Parameters
---------- ----------
name : str
The pass name.
opt_level : int opt_level : int
The optimization level of this pass. The optimization level of this pass.
name : str
The pass name.
required : List[str] required : List[str]
The list of passes that are required by a certain pass. The list of passes that are required by a certain pass.
""" """
def __init__(self, name, opt_level, required=None): def __init__(self, opt_level, name, required=None):
self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level, self.__init_handle_by_constructor__(
required) _transform.PassInfo, opt_level, name, required)
@register_relay_node @register_relay_node
...@@ -194,7 +190,7 @@ class ModulePass(Pass): ...@@ -194,7 +190,7 @@ class ModulePass(Pass):
`module_pass`, because the design of the `module_pass` API is flexible `module_pass`, because the design of the `module_pass` API is flexible
enough to handle the creation of a module pass in different manners. In enough to handle the creation of a module pass in different manners. In
addition, all members of a module pass can be accessed from the base class. addition, all members of a module pass can be accessed from the base class.
The same rule applies to FunctionPass and Sequential as well. The same rule applies to FunctionPass as well.
""" """
...@@ -250,153 +246,6 @@ class Sequential(Pass): ...@@ -250,153 +246,6 @@ class Sequential(Pass):
passes, opt_level, name, required) passes, opt_level, name, required)
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.CreateModulePass(
pass_func, opt_level, name if name else pass_func.__name__,
required)
if pass_func:
return create_module_pass(pass_func)
return create_module_pass
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.CreateFunctionPass(
pass_func, opt_level, name if name else pass_func.__name__,
required)
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
def InferType(): def InferType():
"""Infer the type of an expr. """Infer the type of an expr.
...@@ -593,3 +442,150 @@ def PartialEvaluate(): ...@@ -593,3 +442,150 @@ def PartialEvaluate():
The registered pass that performs partial evaluation on an expression. The registered pass that performs partial evaluation on an expression.
""" """
return _transform.PartialEvaluate() return _transform.PartialEvaluate()
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeModulePass(pass_func, info)
if pass_func:
return create_module_pass(pass_func)
return create_module_pass
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
fname = name if name else pass_func.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeFunctionPass(pass_func, info)
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
...@@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._transform.CreateModulePass") TVM_REGISTER_API("relay._transform.MakeModulePass")
.set_body_typed(CreateModulePass); .set_body_typed(ModulePassNode::make);
TVM_REGISTER_API("relay._transform.RunPass") TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._transform.CreateFunctionPass") TVM_REGISTER_API("relay._transform.MakeFunctionPass")
.set_body_typed(CreateFunctionPass); .set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node, .set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
......
...@@ -259,6 +259,12 @@ def test_function_pass(): ...@@ -259,6 +259,12 @@ def test_function_pass():
test_pass_run() test_pass_run()
def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
assert info.name == "xyz"
def test_sequential_pass(): def test_sequential_pass():
shape = (10, ) shape = (10, )
dtype = 'float32' dtype = 'float32'
...@@ -449,3 +455,4 @@ if __name__ == "__main__": ...@@ -449,3 +455,4 @@ if __name__ == "__main__":
test_function_pass() test_function_pass()
test_sequential_pass() test_sequential_pass()
test_sequential_with_scoping() test_sequential_with_scoping()
test_pass_info()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment