Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c9a2f3da
Unverified
Commit
c9a2f3da
authored
Jun 11, 2019
by
Tianqi Chen
Committed by
GitHub
Jun 11, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY] Pass infra cleanup (#3336)
parent
d6c4aba8
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
169 additions
and
165 deletions
+169
-165
include/tvm/relay/transform.h
+3
-2
python/tvm/relay/transform.py
+155
-159
src/relay/pass/pass_manager.cc
+4
-4
tests/python/relay/test_pass_manager.py
+7
-0
No files found.
include/tvm/relay/transform.h
View file @
c9a2f3da
...
@@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
...
@@ -202,7 +202,8 @@ class PassInfoNode : public RelayNode {
v
->
Visit
(
"required"
,
&
required
);
v
->
Visit
(
"required"
,
&
required
);
}
}
TVM_DLL
static
PassInfo
make
(
int
opt_level
,
std
::
string
name
,
TVM_DLL
static
PassInfo
make
(
int
opt_level
,
std
::
string
name
,
tvm
::
Array
<
tvm
::
Expr
>
required
);
tvm
::
Array
<
tvm
::
Expr
>
required
);
static
constexpr
const
char
*
_type_key
=
"relay.PassInfo"
;
static
constexpr
const
char
*
_type_key
=
"relay.PassInfo"
;
...
@@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
...
@@ -467,7 +468,7 @@ TVM_DLL Pass SimplifyInference();
* type information filled in, as well as it's checked type field
* type information filled in, as well as it's checked type field
* populated with the result type.
* populated with the result type.
*
*
* \return The pass.
* \return The pass.
*/
*/
TVM_DLL
Pass
InferType
();
TVM_DLL
Pass
InferType
();
...
...
python/tvm/relay/transform.py
View file @
c9a2f3da
...
@@ -14,13 +14,9 @@
...
@@ -14,13 +14,9 @@
# KIND, either express or implied. See the License for the
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# specific language governing permissions and limitations
# under the License.
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
# pylint: disable=invalid-name
# pylint: disable=invalid-name
"""
"""
This file contains the pass manager for Relay which exposes different
Relay pass transformation infrastructure.
granularity of interfaces for users to implement and use passes more
conveniently.
"""
"""
import
types
import
types
...
@@ -39,19 +35,19 @@ class PassInfo(RelayNode):
...
@@ -39,19 +35,19 @@ class PassInfo(RelayNode):
Parameters
Parameters
----------
----------
name : str
The pass name.
opt_level : int
opt_level : int
The optimization level of this pass.
The optimization level of this pass.
name : str
The pass name.
required : List[str]
required : List[str]
The list of passes that are required by a certain pass.
The list of passes that are required by a certain pass.
"""
"""
def
__init__
(
self
,
name
,
opt_level
,
required
=
None
):
def
__init__
(
self
,
opt_level
,
name
,
required
=
None
):
self
.
__init_handle_by_constructor__
(
_transform
.
PassInfo
,
name
,
opt_level
,
self
.
__init_handle_by_constructor__
(
required
)
_transform
.
PassInfo
,
opt_level
,
name
,
required
)
@register_relay_node
@register_relay_node
...
@@ -194,7 +190,7 @@ class ModulePass(Pass):
...
@@ -194,7 +190,7 @@ class ModulePass(Pass):
`module_pass`, because the design of the `module_pass` API is flexible
`module_pass`, because the design of the `module_pass` API is flexible
enough to handle the creation of a module pass in different manners. In
enough to handle the creation of a module pass in different manners. In
addition, all members of a module pass can be accessed from the base class.
addition, all members of a module pass can be accessed from the base class.
The same rule applies to FunctionPass a
nd Sequential a
s well.
The same rule applies to FunctionPass as well.
"""
"""
...
@@ -250,153 +246,6 @@ class Sequential(Pass):
...
@@ -250,153 +246,6 @@ class Sequential(Pass):
passes
,
opt_level
,
name
,
required
)
passes
,
opt_level
,
name
,
required
)
def
module_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
if
opt_level
is
None
:
raise
ValueError
(
"Please provide opt_level for the module pass."
)
required
=
required
if
required
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"Required is expected to be the type of "
+
"list/tuple."
)
def
create_module_pass
(
pass_func
):
"""Internal function that creates a module pass"""
if
not
isinstance
(
pass_func
,
(
types
.
FunctionType
,
types
.
LambdaType
)):
raise
TypeError
(
"pass_func must be a callable for Module pass"
)
return
_transform
.
CreateModulePass
(
pass_func
,
opt_level
,
name
if
name
else
pass_func
.
__name__
,
required
)
if
pass_func
:
return
create_module_pass
(
pass_func
)
return
create_module_pass
def
function_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if
opt_level
is
None
:
raise
ValueError
(
"Please provide opt_level for the funtion pass."
)
required
=
required
if
required
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"Required is expected to be the type of "
+
"list/tuple."
)
def
create_function_pass
(
pass_func
):
"""Internal function that creates a function pass"""
if
not
isinstance
(
pass_func
,
(
types
.
FunctionType
,
types
.
LambdaType
)):
raise
TypeError
(
"pass_func must be a callable for Module pass"
)
return
_transform
.
CreateFunctionPass
(
pass_func
,
opt_level
,
name
if
name
else
pass_func
.
__name__
,
required
)
if
pass_func
:
return
create_function_pass
(
pass_func
)
return
create_function_pass
def
InferType
():
def
InferType
():
"""Infer the type of an expr.
"""Infer the type of an expr.
...
@@ -593,3 +442,150 @@ def PartialEvaluate():
...
@@ -593,3 +442,150 @@ def PartialEvaluate():
The registered pass that performs partial evaluation on an expression.
The registered pass that performs partial evaluation on an expression.
"""
"""
return
_transform
.
PartialEvaluate
()
return
_transform
.
PartialEvaluate
()
def
module_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
if
opt_level
is
None
:
raise
ValueError
(
"Please provide opt_level for the module pass."
)
required
=
required
if
required
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"Required is expected to be the type of "
+
"list/tuple."
)
def
create_module_pass
(
pass_func
):
"""Internal function that creates a module pass"""
if
not
isinstance
(
pass_func
,
(
types
.
FunctionType
,
types
.
LambdaType
)):
raise
TypeError
(
"pass_func must be a callable for Module pass"
)
fname
=
name
if
name
else
pass_func
.
__name__
info
=
PassInfo
(
opt_level
,
fname
,
required
)
return
_transform
.
MakeModulePass
(
pass_func
,
info
)
if
pass_func
:
return
create_module_pass
(
pass_func
)
return
create_module_pass
def
function_pass
(
pass_func
=
None
,
opt_level
=
None
,
name
=
None
,
required
=
None
):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if
opt_level
is
None
:
raise
ValueError
(
"Please provide opt_level for the funtion pass."
)
required
=
required
if
required
else
[]
if
not
isinstance
(
required
,
(
list
,
tuple
)):
raise
TypeError
(
"Required is expected to be the type of "
+
"list/tuple."
)
def
create_function_pass
(
pass_func
):
"""Internal function that creates a function pass"""
if
not
isinstance
(
pass_func
,
(
types
.
FunctionType
,
types
.
LambdaType
)):
raise
TypeError
(
"pass_func must be a callable for Module pass"
)
fname
=
name
if
name
else
pass_func
.
__name__
info
=
PassInfo
(
opt_level
,
fname
,
required
)
return
_transform
.
MakeFunctionPass
(
pass_func
,
info
)
if
pass_func
:
return
create_function_pass
(
pass_func
)
return
create_function_pass
src/relay/pass/pass_manager.cc
View file @
c9a2f3da
...
@@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -465,8 +465,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE
(
ModulePassNode
);
TVM_REGISTER_NODE_TYPE
(
ModulePassNode
);
TVM_REGISTER_API
(
"relay._transform.
Creat
eModulePass"
)
TVM_REGISTER_API
(
"relay._transform.
Mak
eModulePass"
)
.
set_body_typed
(
CreateModulePass
);
.
set_body_typed
(
ModulePassNode
::
make
);
TVM_REGISTER_API
(
"relay._transform.RunPass"
)
TVM_REGISTER_API
(
"relay._transform.RunPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
@@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...
@@ -485,8 +485,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE
(
FunctionPassNode
);
TVM_REGISTER_NODE_TYPE
(
FunctionPassNode
);
TVM_REGISTER_API
(
"relay._transform.
Creat
eFunctionPass"
)
TVM_REGISTER_API
(
"relay._transform.
Mak
eFunctionPass"
)
.
set_body_typed
(
CreateFunctionPass
);
.
set_body_typed
(
FunctionPassNode
::
make
);
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
FunctionPassNode
>
([](
const
FunctionPassNode
*
node
,
.
set_dispatch
<
FunctionPassNode
>
([](
const
FunctionPassNode
*
node
,
...
...
tests/python/relay/test_pass_manager.py
View file @
c9a2f3da
...
@@ -259,6 +259,12 @@ def test_function_pass():
...
@@ -259,6 +259,12 @@ def test_function_pass():
test_pass_run
()
test_pass_run
()
def
test_pass_info
():
info
=
relay
.
transform
.
PassInfo
(
opt_level
=
1
,
name
=
"xyz"
)
assert
info
.
opt_level
==
1
assert
info
.
name
==
"xyz"
def
test_sequential_pass
():
def
test_sequential_pass
():
shape
=
(
10
,
)
shape
=
(
10
,
)
dtype
=
'float32'
dtype
=
'float32'
...
@@ -449,3 +455,4 @@ if __name__ == "__main__":
...
@@ -449,3 +455,4 @@ if __name__ == "__main__":
test_function_pass
()
test_function_pass
()
test_sequential_pass
()
test_sequential_pass
()
test_sequential_with_scoping
()
test_sequential_with_scoping
()
test_pass_info
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment