Commit bb48a45b
Authored Jun 03, 2019 by Zhi; committed by Tianqi Chen, Jun 03, 2019
[RELAY][TRANSFORM] Migrate buildmodule to transform (#3251)
Parent: 0faf7310
Showing 24 changed files with 879 additions and 455 deletions (+879 −455).
include/tvm/relay/module.h                  +13  −13
include/tvm/relay/pass.h                    +20  −0
include/tvm/relay/transform.h               +85  −5
python/tvm/relay/build_module.py            +3   −91
python/tvm/relay/transform.py               +199 −0
src/relay/backend/build_module.cc           +119 −251
src/relay/pass/alter_op_layout.cc           +23  −4
src/relay/pass/canonicalize_ops.cc          +17  −0
src/relay/pass/combine_parallel_conv2d.cc   +17  −0
src/relay/pass/dead_code.cc                 +4   −1
src/relay/pass/device_annotation.cc         +6   −2
src/relay/pass/eliminate_common_subexpr.cc  +17  −0
src/relay/pass/fold_constant.cc             +6   −2
src/relay/pass/fold_scale_axis.cc           +40  −2
src/relay/pass/forward_rewrite.cc           +2   −2
src/relay/pass/fuse_ops.cc                  +6   −1
src/relay/pass/partial_eval.cc              +5   −4
src/relay/pass/pass_manager.cc              +94  −72
src/relay/pass/simplify_inference.cc        +17  −0
src/relay/pass/to_a_normal_form.cc          +4   −1
src/relay/pass/to_graph_normal_form.cc      +4   −1
src/relay/pass/type_infer.cc                +19  −0
tests/cpp/relay_transform_sequential.cc     +111 −0
tests/python/relay/test_pass_manager.py     +48  −3
include/tvm/relay/module.h

@@ -87,14 +87,14 @@ class ModuleNode : public RelayNode {
   * \param update Controls whether you can replace a definition in the
   * environment.
   */
-  void Add(const GlobalVar& var, const Function& func, bool update = false);
+  TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
  /*!
   * \brief Add a type-level definition to the global environment.
   * \param var The var of the global type definition.
   * \param type The type definition.
   */
-  void AddDef(const GlobalTypeVar& var, const TypeData& type);
+  TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type);
  /*!
   * \brief Add a function to the global environment.
@@ -103,69 +103,69 @@ class ModuleNode : public RelayNode {
   *
   * It does not do type inference as Add does.
   */
-  void AddUnchecked(const GlobalVar& var, const Function& func);
+  TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
  /*!
   * \brief Update a function in the global environment.
   * \param var The name of the global function to update.
   * \param func The new function.
   */
-  void Update(const GlobalVar& var, const Function& func);
+  TVM_DLL void Update(const GlobalVar& var, const Function& func);
  /*!
   * \brief Remove a function from the global environment.
   * \param var The name of the global function to update.
   */
-  void Remove(const GlobalVar& var);
+  TVM_DLL void Remove(const GlobalVar& var);
  /*!
   * \brief Lookup a global function by its variable.
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
-  GlobalVar GetGlobalVar(const std::string& str);
+  TVM_DLL GlobalVar GetGlobalVar(const std::string& str);
  /*!
   * \brief Look up a global function by its name.
   * \param str The unique string specifying the global variable.
   * \returns The global variable.
   */
-  GlobalTypeVar GetGlobalTypeVar(const std::string& str);
+  TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);
  /*!
   * \brief Lookup a global function by its variable.
   * \param var The global var to lookup.
   * \returns The function named by the variable argument.
   */
-  Function Lookup(const GlobalVar& var);
+  TVM_DLL Function Lookup(const GlobalVar& var);
  /*!
   * \brief Lookup a global function by its string name
   * \param name The name of the function.
   * \returns The function named by the argument.
   */
-  Function Lookup(const std::string& name);
+  TVM_DLL Function Lookup(const std::string& name);
  /*!
   * \brief Lookup a global type definition by its variable.
   * \param var The var of the global type definition.
   * \return The type definition.
   */
-  TypeData LookupDef(const GlobalTypeVar& var);
+  TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);
  /*!
   * \brief Lookup a global type definition by its name.
   * \param var The name of the global type definition.
   * \return The type definition.
   */
-  TypeData LookupDef(const std::string& var);
+  TVM_DLL TypeData LookupDef(const std::string& var);
  /*!
   * \brief Update the functions inside this environment by
   * functions in another environment.
   * \param other The other environment.
   */
-  void Update(const Module& other);
+  TVM_DLL void Update(const Module& other);
  /*! \brief Construct a module from a standalone expression.
   *
@@ -177,7 +177,7 @@ class ModuleNode : public RelayNode {
   *
   * \returns A module with expr set as the entry point.
   */
-  static Module FromExpr(
+  TVM_DLL static Module FromExpr(
      const Expr& expr,
      const tvm::Map<GlobalVar, Function>& global_funcs = {});
include/tvm/relay/pass.h

@@ -359,6 +359,15 @@ TVM_DLL Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);
 TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

+/*!
+ * \brief Collect the device annotation operators.
+ *
+ * \param expr The expression.
+ *
+ * \return The annotated expression to device type mapping for annotation ops.
+ */
+TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
+
 /*!
  * \brief Turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
  *
  * It will turn an expression that is in a graph form (with sharing implicit),

@@ -403,6 +412,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
 */
 TVM_DLL Expr PartialEval(const Expr& e);

+/*!
+ * \brief Bind the free variables to a Relay expression.
+ *
+ * \param expr The expression.
+ * \param bind_map The variable to expression map that will be used to help the
+ * binding.
+ *
+ * \return The updated expression.
+ */
+TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);
+
 /*! \brief A hashing structure in the style of std::hash. */
 struct StructuralHash {
   /*! \brief Hash a Relay type.
include/tvm/relay/transform.h

@@ -58,9 +58,11 @@
 #include <tvm/base.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/module.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <string>
 #include <unordered_map>

@@ -292,9 +294,9 @@ class Sequential : public Pass {
  * \param passes The passes to apply.
  * \param pass_info The pass metadata.
  */
 TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

 /*!
  * \brief The constructor of `Sequential`.
  *
  * \param passes The passes to apply.

@@ -311,7 +313,6 @@ class Sequential : public Pass {
  using ContainerType = Sequential;
 };

 /*
  * \brief Create a module pass.
  *

@@ -339,7 +340,7 @@ Pass CreateModulePass(
  * \return The created function pass.
  */
 TVM_DLL Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func,
    int opt_level,
    const std::string& name,
    const tvm::Array<tvm::Expr>& required);

@@ -451,6 +452,85 @@ TVM_DLL Pass ToGraphNormalForm();
 */
 TVM_DLL Pass PartialEval();

/*!
 * \brief Simplify certain operators during inference. For example, batch norm
 * will be unpacked into a number of simplified operators.
 *
 * \return The pass.
 */
TVM_DLL Pass SimplifyInference();

/*!
 * \brief Infer the type of an expression.
 *
 * The result of type checking is a new expression with unambiguous
 * type information filled in, as well as its checked type field
 * populated with the result type.
 *
 * \return The pass.
 */
TVM_DLL Pass InferType();

/*!
 * \brief Search for and eliminate common subexpressions. For example, if two
 * expressions evaluate to an identical value, a single variable is created
 * and the two expressions are replaced by this variable.
 *
 * \param fskip The callback that allows certain expressions to be skipped.
 *
 * \return The pass.
 */
TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);

/*!
 * \brief Combine parallel 2d convolutions into a single convolution if the
 * number of branches of this conv2d operator is not less than
 * `min_num_branches`.
 *
 * \param min_num_branches The minimum number of branches.
 *
 * \return The pass.
 */
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
 * \brief Backward fold axis scaling into weights of conv/dense operators.
 *
 * \return The pass.
 */
TVM_DLL Pass BackwardFoldScaleAxis();

/*!
 * \brief Forward fold axis scaling into weights of conv/dense operators.
 *
 * \return The pass.
 */
TVM_DLL Pass ForwardFoldScaleAxis();

/*!
 * \brief A sequential pass that executes the ForwardFoldScaleAxis and
 * BackwardFoldScaleAxis passes.
 *
 * \return The pass.
 */
TVM_DLL Pass FoldScaleAxis();

/*!
 * \brief Canonicalize some operators to simplified operators. For example,
 * bias_add can be canonicalized to expand_dims and broadcast_add.
 *
 * \return The pass.
 */
TVM_DLL Pass CanonicalizeOps();

/*!
 * \brief Alternate the layouts of operators or replace primitive operators
 * with other expressions.
 *
 * \return The pass.
 */
TVM_DLL Pass AlterOpLayout();

}  // namespace transform
}  // namespace relay
}  // namespace tvm
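Each pass declared above is also registered through the FFI (for example `relay._transform.FoldConstant`) and wrapped by `python/tvm/relay/transform.py` below. As a minimal sketch of driving one of these passes from Python, assuming a TVM checkout from around this commit:

    from tvm import relay

    # bias_add is one of the "special" operators that CanonicalizeOps
    # rewrites into expand_dims + broadcast_add.
    x = relay.var("x", shape=(1, 64, 56, 56))
    bias = relay.var("bias", shape=(64,))
    func = relay.Function([x, bias], relay.nn.bias_add(x, bias))

    mod = relay.Module.from_expr(func)
    # Passes now run at module granularity; CanonicalizeOps declares
    # InferType as required, so type inference is executed first.
    mod = relay.transform.CanonicalizeOps()(mod)
    print(mod)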
python/tvm/relay/build_module.py

@@ -20,7 +20,6 @@ from a Relay expression.
 """
 import numpy as np

 from tvm._ffi.runtime_ctypes import TVMContext
 from tvm import expr as tvm_expr
 from .. import nd as _nd, target as _target, autotvm
 from ..contrib import graph_runtime as _graph_rt

@@ -28,7 +27,6 @@ from . import _build_module
 from . import ir_pass
 from . import ty as _ty
 from . import expr as _expr
 from . import transform as _transform
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor

@@ -61,10 +59,6 @@ class BuildModule(object):
        self._get_graph_json = self.mod["get_graph_json"]
        self._get_module = self.mod["get_module"]
        self._build = self.mod["build"]
-       self._add_pass = self.mod["add_pass"]
-       self._disable_pass = self.mod["disable_pass"]
-       self._set_opt_level = self.mod["set_opt_level"]
-       self._set_fallback_device = self.mod["set_fallback_device"]
        self._set_params_func = self.mod["set_params"]
        self._get_params_func = self.mod["get_params"]

@@ -106,8 +100,9 @@ class BuildModule(object):
        """
        target = _update_target(target)

-       # Setup the build configurations passed in through `with build_config`.
-       self._setup_build_config(params)
+       # Setup the params.
+       if params:
+           self._set_params(params)

        # Build the function
        self._build(func, target, target_host)
        # Get artifacts

@@ -117,41 +112,6 @@ class BuildModule(object):
        return graph_json, mod, params

-   def _setup_build_config(self, params):
-       cfg = _transform.PassContext.current()
-
-       # Set opt_level.
-       self.set_opt_level(cfg.opt_level)
-
-       # Set fallback device if it is available.
-       if cfg.fallback_device:
-           self.set_fallback_device(cfg.fallback_device)
-
-       # Add required passes.
-       if cfg.required_pass:
-           passes = set()
-           if isinstance(cfg.required_pass, (list, tuple, set)):
-               passes = set(cfg.required_pass)
-           else:
-               raise TypeError("add_pass must be list, tuple, or set, but " +
-                               "got {}".format(type(cfg.required_pass)))
-           for pass_name in passes:
-               self.add_pass(pass_name)
-
-       # Add disabled passes.
-       if cfg.disabled_pass:
-           passes = set()
-           if isinstance(cfg.disabled_pass, (list, tuple, set)):
-               passes = set(cfg.disabled_pass)
-           else:
-               raise TypeError("disable_pass must be list, tuple, or set, " +
-                               "but got {}".format(type(cfg.disabled_pass)))
-           for pass_name in passes:
-               self.disable_pass(pass_name)
-
-       if params:
-           self._set_params(params)

    def _set_params(self, params):
        inputs = {}
        for name, param in params.items():

@@ -160,28 +120,6 @@ class BuildModule(object):
            inputs[name] = _expr.const(param)
        self._set_params_func(inputs)

-   def add_pass(self, pass_name):
-       """Add a pass to the pass list.
-
-       Parameters
-       ----------
-       pass_name : str
-           The name of the pass that will be added to the list of passes used
-           for optimizations.
-       """
-       self._add_pass(pass_name)
-
-   def disable_pass(self, pass_name):
-       """Add a pass to the disabled pass list.
-
-       Parameters
-       ----------
-       pass_name : str
-           The name of a pass. This pass will be added to the list of passes
-           that are disabled during optimization.
-       """
-       self._disable_pass(pass_name)

    def get_json(self):
        """Return the json file of the built program."""
        return self._get_graph_json()

@@ -198,32 +136,6 @@ class BuildModule(object):
            ret[key] = value.data
        return ret

-   def set_opt_level(self, level):
-       """Set the optimization level.
-
-       Parameters
-       ----------
-       level : int
-           The optimization level for build.
-       """
-       self._set_opt_level(level)
-
-   def set_fallback_device(self, fallback_device):
-       """Set the fallback device for heterogeneous execution.
-
-       Parameters
-       ----------
-       fallback_device : str or tvm.TVMContext
-           The fallback device used for heterogeneous execution.
-       """
-       if isinstance(fallback_device, (int, str)):
-           fallback_device = _nd.context(fallback_device)
-       if not isinstance(fallback_device, TVMContext):
-           raise TypeError("fallback_device is expected to be str, int, or " +
-                           "TVMContext but received: {}".format(type(fallback_device)))
-       self._set_fallback_device(fallback_device.device_type)

 def build(func, target=None, target_host=None, params=None):
    """Helper function that builds a Relay function to run on TVM graph
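With `_setup_build_config`, `add_pass`, `disable_pass`, `set_opt_level`, and `set_fallback_device` removed, pass selection is read from the `PassContext` that `relay.build_config` installs. A hedged sketch of the resulting workflow, with names as of this era of the API:

    from tvm import relay

    x = relay.var("x", shape=(1, 16))
    func = relay.Function([x], relay.nn.relu(x))

    # opt_level, required_pass, and disabled_pass now travel through the
    # PassContext rather than through setter methods on BuildModule.
    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
        graph_json, lib, params = relay.build(func, target="llvm")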
python/tvm/relay/transform.py

@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=no-else-return
 # pylint: disable=unidiomatic-typecheck
+# pylint: disable=invalid-name
 """
 This file contains the pass manager for Relay which exposes different
 granularity of interfaces for users to implement and use passes more

@@ -394,3 +395,201 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
    if pass_func:
        return create_function_pass(pass_func)
    return create_function_pass


def InferType():
    """Infer the type of an expr.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered type inference pass.
    """
    return _transform.InferType()


def FoldScaleAxis():
    """Fold the scaling of axis into weights of conv2d/dense. This pass will
    invoke both forward and backward scale folding.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to fold expressions.

    Note
    ----
    Internally, we will call backward_fold_scale_axis before using
    forward_fold_scale_axis, as backward folding targets the common
    conv-bn pattern.
    """
    return _transform.FoldScaleAxis()


def SimplifyInference():
    """Simplify the data-flow graph for the inference phase. A simplified
    expression that is semantically equal to the input expression will be
    returned.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to perform operator simplification.
    """
    return _transform.SimplifyInference()


def CanonicalizeOps():
    """Canonicalize special operators to basic operators.
    This can simplify subsequent analysis, e.g. expanding bias_add to
    expand_dims and broadcast_add.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass performing the canonicalization.
    """
    return _transform.CanonicalizeOps()


def DeadCodeElimination():
    """Remove expressions that do not affect the program result (dead code).

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that eliminates dead code in a Relay program.
    """
    return _transform.DeadCodeElimination()


def FoldConstant():
    """Fold the constant expressions in expr.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for constant folding.
    """
    return _transform.FoldConstant()


def FuseOps(fuse_opt_level=-1):
    """Fuse operators in an expr to a larger operator according to some rules.

    Parameters
    ----------
    fuse_opt_level : int
        The level of fuse optimization. -1 indicates that the level will be
        inferred from the pass context.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for operator fusion.
    """
    return _transform.FuseOps(fuse_opt_level)


def CombineParallelConv2D(min_num_branches=3):
    """Combine multiple conv2d operators into one.

    Parameters
    ----------
    min_num_branches : int
        The minimum number of required parallel branches for performing this
        optimization.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that combines parallel conv2d operators.
    """
    return _transform.CombineParallelConv2D(min_num_branches)


def AlterOpLayout():
    """Alternate the layouts of operators or replace primitive operators with
    other expressions.
    This pass can be used for computing convolution in custom layouts or
    other general weight pre-transformation.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that alters the layout of operators.
    """
    return _transform.AlterOpLayout()


def RewriteAnnotatedOps(fallback_device):
    """Rewrite the annotated program where annotation operators, e.g.
    `on_device`, mark on which device an expression should be scheduled.
    This pass helps heterogeneous execution, where different operators may
    need to be allocated on various devices.

    Parameters
    ----------
    fallback_device : int
        The fallback device type. It is also used as the default device for
        operators without an annotated device.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that rewrites an expression with annotated
        `on_device` operators.
    """
    return _transform.RewriteDeviceAnnotation(fallback_device)


def ToANormalForm():
    """Turn a Graph Normal Form expression into an A Normal Form expression.
    The scope of the root expression is the global scope.
    The scope of any non-root expression is the least common ancestor of all
    its scopes. Values are ordered by post-DFS order in each scope.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that transforms an expression into A Normal Form.
    """
    return _transform.ToANormalForm()


def ToGraphNormalForm():
    """Turn an A Normal Form expression into a Graph Normal Form expression.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that transforms an expression into Graph Normal
        Form.
    """
    return _transform.ToGraphNormalForm()


def EliminateCommonSubexpr(fskip=None):
    """Eliminate common subexpressions.

    Parameters
    ----------
    fskip : Callable
        The callback function that decides whether an expression should be
        skipped.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that eliminates common subexpressions.
    """
    return _transform.EliminateCommonSubexpr(fskip)


def PartialEvaluate():
    """Evaluate the static fragment of the code.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that performs partial evaluation on an expression.
    """
    return _transform.PartialEvaluate()
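These wrappers compose into module-level pipelines. A small sketch, assuming a checkout from around this commit, that folds the constant subexpression below:

    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(1, 8))
    y = x + (relay.const(1.0) + relay.const(2.0))  # the inner add is foldable
    mod = relay.Module.from_expr(relay.Function([x], y))

    seq = transform.Sequential([
        transform.InferType(),
        transform.FoldConstant(),
        transform.EliminateCommonSubexpr(),
    ])
    # opt_level gates which member passes actually execute.
    with relay.build_config(opt_level=3):
        mod = seq(mod)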
src/relay/backend/build_module.cc

@@ -23,12 +23,8 @@
 */
 #include <tvm/build_module.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
 #include <vector>
 #include <string>
 #include <tvm/relay/transform.h>
 #include <memory>

 #include "utils.h"

@@ -38,39 +34,7 @@ namespace relay {
 namespace backend {

 using TargetsMap = Map<tvm::Integer, tvm::Target>;

-/*!
- * \brief A data structure to map the names of specific optimizations to
- * numeric optimization levels
- */
-struct OptPassLevel {
-  static const std::unordered_map<std::string, int> _data;
-  /*!
-   * \brief Get level for an optimization pass
-   *
-   * \param key pass name
-   * \return int level
-   */
-  int operator[](const std::string& key) const {
-    auto it = _data.find(key);
-    if (it == _data.end()) {
-      return -1;
-    }
-    return it->second;
-  }
-};
-
-const std::unordered_map<std::string, int> OptPassLevel::_data = {
-    {"SimplifyInference", 0},
-    {"OpFusion", 1},
-    {"FoldConstant", 2},
-    {"CombineParallelConv2D", 4},
-    {"FoldScaleAxis", 3},
-    {"AlterOpLayout", 3},
-    {"CanonicalizeOps", 3},
-    {"EliminateCommonSubexpr", 3}
-};
+using namespace tvm::relay::transform;

 /*!
  * \brief Output of building module

@@ -83,27 +47,6 @@ struct BuildOutput {
 };

-/*!
- * \brief Relay building config
- */
-struct RelayBuildConfig {
-  int opt_level{2};
-  int fallback_device{static_cast<int>(kDLCPU)};
-  std::unordered_set<std::string> enabled_pass;
-  std::unordered_set<std::string> disabled_pass;
-  OptPassLevel OPT_PASS_LEVEL;
-  inline bool pass_enabled(const std::string& pass_name) const {
-    if (disabled_pass.count(pass_name)) {
-      return false;
-    }
-    if (enabled_pass.count(pass_name)) {
-      return true;
-    }
-    return opt_level >= OPT_PASS_LEVEL[pass_name];
-  }
-};

 /*!
  * \brief GraphCodegen module wrapper
  */

@@ -156,18 +99,6 @@ struct GraphCodegen {
 };

-template<typename R, typename... Args>
-R CallPackedFunc(const std::string& name, Args... args) {
-  auto pf = GetPackedFunc(name);
-  return (*pf)(std::forward<Args>(args)...);
-}
-
-template<typename... Args>
-Function CallPackedFunc(const std::string& name, Args... args) {
-  auto pf = GetPackedFunc(name);
-  return (*pf)(std::forward<Args>(args)...);
-}

 /*!
  * \brief Relay build module
  */

@@ -203,28 +134,6 @@ class RelayBuildModule : public runtime::ModuleNode {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->GetParams();
      });
-   } else if (name == "set_opt_level") {
-     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-       CHECK_EQ(args.num_args, 1);
-       int level = args[0];
-       this->SetOptLevel(level);
-     });
-   } else if (name == "set_fallback_device") {
-     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-       CHECK_EQ(args.num_args, 1);
-       int dev = args[0];
-       this->SetFallBackDev(dev);
-     });
-   } else if (name == "add_pass") {
-     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-       std::string pass_name = args[0];
-       this->AddPass(pass_name);
-     });
-   } else if (name == "disable_pass") {
-     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-       std::string pass_name = args[0];
-       this->DisablePass(pass_name);
-     });
    } else if (name == "set_params") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Map<std::string, Constant> params = args[0];

@@ -246,30 +155,7 @@ class RelayBuildModule : public runtime::ModuleNode {
  const std::string& GetGraphJSON() {
    return ret_.graph_json;
  }

-  /*!
-   * \brief Add extra pass into build cfg
-   *
-   * \param pass_name name of pass
-   */
-  void AddPass(const std::string& pass_name) {
-    cfg_.enabled_pass.insert(pass_name);
-  }
-  /*!
-   * \brief Disable a specific pass in cfg
-   *
-   * \param pass_name name of pass
-   */
-  void DisablePass(const std::string& pass_name) {
-    cfg_.disabled_pass.insert(pass_name);
-  }
-  /*!
-   * \brief Set the Fallback device
-   *
-   * \param device name
-   */
-  void SetFallBackDev(int dev) {
-    cfg_.fallback_device = dev;
-  }

  /*!
   * \brief Get the Module object
   */

@@ -316,15 +202,6 @@ class RelayBuildModule : public runtime::ModuleNode {
  }

-  /*!
-   * \brief Set the optimization level
-   *
-   * \param level
-   */
-  void SetOptLevel(char level) {
-    cfg_.opt_level = level;
-  }

  /*!
   * \brief type key
   *
   * \return const char*

@@ -345,7 +222,7 @@ class RelayBuildModule : public runtime::ModuleNode {
             const tvm::Target& target_host) {
    targets_ = targets;
    target_host_ = target_host;
-   BuildRelay(func, cfg_, params_);
+   BuildRelay(func, params_);
  }

 protected:

@@ -378,85 +255,81 @@ class RelayBuildModule : public runtime::ModuleNode {
      if (repeat_var.count(arg)) {
        LOG(FATAL) << "Multiple args in the function have name " << kv.first;
      }
-     auto e = CallPackedFunc<Expr>("relay._make.Constant", kv.second);
-     bind_dict[arg] = e;
+     bind_dict[arg] = ConstantNode::make(kv.second);
    }
-   return CallPackedFunc("relay._expr.Bind", func,
-                         tvm::Map<relay::Var, Expr>(bind_dict));
+   Expr bound_expr = relay::Bind(func, bind_dict);
+   Function ret = Downcast<Function>(bound_expr);
+   CHECK(ret.defined())
+       << "The returning type is expected to be a Relay Function." << "\n";
+   return ret;
  }

  /*!
-  * \brief Optimize Relay function
+  * \brief Optimize a Relay module.
   *
-  * \param func Input function
-  * \param target target device
-  * \param cfg Relay build config
-  * \param params params dict
-  * \return relay::Function
+  * \param relay_module The input Relay module where optimization will be
+  * applied on.
+  * \param targets The device type to `Target` mapping.
+  * \param params The param name to value mapping.
+  *
+  * \return relay::Module The updated Relay module after optimization.
   */
- relay::Function Optimize(relay::Function func,
-                          const TargetsMap& targets,
-                          const RelayBuildConfig& cfg,
-                          const std::unordered_map<std::string, runtime::NDArray>& params) {
-   if (params.size()) {
-     func = BindParamsByName(func, params);
-   }
-   if (cfg.pass_enabled("SimplifyInference")) {
-     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-     func = CallPackedFunc("relay._ir_pass.simplify_inference", func);
-   }
-   if (cfg.pass_enabled("EliminateCommonSubexpr")) {
-     auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
-       Expr expr = args[0];
-       if (expr.as<CallNode>()) {
-         auto call_node = expr.as<CallNode>();
-         auto op_node = call_node->op.as<OpNode>();
-         if (op_node->name == "cast") {
-           auto attrs = call_node->attrs.as<CastAttrs>();
-           if (attrs->dtype == HalideIR::Int(32)) {
-             *rv = true;
-           }
-         }
-       }
-       *rv = false;
-     });
-     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-     func = CallPackedFunc("relay._ir_pass.eliminate_common_subexpr", func, fskip);
-   }
-   if (cfg.pass_enabled("CombineParallelConv2D")) {
-     const int min_num_branches = 3;
-     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-     func = CallPackedFunc("relay._ir_pass.CombineParallelConv2D", func, min_num_branches);
-   }
-   if (cfg.pass_enabled("FoldConstant")) {
-     func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
-   }
-   if (cfg.pass_enabled("FoldScaleAxis")) {
-     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-     func = CallPackedFunc("relay._ir_pass.backward_fold_scale_axis", func);
-     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-     func = CallPackedFunc("relay._ir_pass.forward_fold_scale_axis", func);
-     func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
-   }
-   if (cfg.pass_enabled("CanonicalizeOps")) {
-     func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-     func = CallPackedFunc("relay._ir_pass.canonicalize_ops", func);
-   }
-   if (cfg.pass_enabled("AlterOpLayout")) {
-     if (targets.size() == 1) {
-       func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-       for (const auto& kv : targets) {
-         With<Target> tctx(kv.second);
-         func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func);
-       }
-     } else {
-       LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous"
-                    << " execution yet.";
-     }
-   }
-   if (cfg.pass_enabled("FoldConstant")) {
-     func = CallPackedFunc("relay._ir_pass.FoldConstant", func);
-   }
-   return func;
+ relay::Module Optimize(relay::Module relay_module,
+                        const TargetsMap& targets,
+                        const std::unordered_map<std::string, runtime::NDArray>& params) {
+   Array<Pass> pass_seqs;
+   pass_seqs.push_back(transform::SimplifyInference());
+   PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+     Expr expr = args[0];
+     if (expr.as<CallNode>()) {
+       auto call_node = expr.as<CallNode>();
+       auto op_node = call_node->op.as<OpNode>();
+       if (op_node->name == "cast") {
+         auto attrs = call_node->attrs.as<CastAttrs>();
+         if (attrs->dtype == HalideIR::Int(32)) {
+           *rv = true;
+         }
+       }
+     }
+     *rv = false;
+   });
+   pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+   pass_seqs.push_back(transform::CombineParallelConv2D(3));
+   pass_seqs.push_back(transform::FoldConstant());
+   pass_seqs.push_back(transform::FoldScaleAxis());
+   pass_seqs.push_back(transform::CanonicalizeOps());
+
+   // The alter-layout transformation is only applied to homogeneous
+   // execution for now.
+   if (targets.size() == 1) {
+     pass_seqs.push_back(transform::AlterOpLayout());
+   }
+   pass_seqs.push_back(transform::FoldConstant());
+
+   // Create a sequential pass and perform optimizations.
+   transform::Pass seq = transform::Sequential(pass_seqs);
+   if (targets.size() == 1) {
+     for (const auto& kv : targets) {
+       With<Target> tctx(kv.second);
+       relay_module = seq(relay_module);
+     }
+   } else {
+     relay_module = seq(relay_module);
+   }
+
+   // Handle heterogeneous compilation.
+   transform::PassContext pass_ctx = PassContext::Current();
+   if (targets_.size() > 1) {
+     relay_module = RunDeviceAnnotationPass(relay_module,
+                                            pass_ctx->fallback_device);
+   }
+
+   // Fuse the operations if it is needed.
+   relay_module = transform::FuseOps()(relay_module);
+   relay_module = transform::InferType()(relay_module);
+   return relay_module;
  }

@@ -470,54 +343,58 @@ class RelayBuildModule : public runtime::ModuleNode {
    if (name == "gpu") return Target::Create("cuda");
    return Target::Create(name);
  }

  /*!
   * \brief Update the target and fallback device required for heterogeneous
   * compilation. CPU is used as the fallback device if it wasn't provided.
   * Meanwhile, a CPU device type and "llvm" pair will be added to the target
   * dictionary in this case.
   *
-  * \param targets dictionary
-  * \param cfg
-  * \return Map<tvm::Integer, tvm::Target>
+  * \param fallback_device The fallback device for heterogeneous execution.
   */
- TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets,
-                                      const RelayBuildConfig& cfg) {
-   TargetsMap device_target = targets;
+ void UpdateHeterogeneousInputs(int fallback_device) {
    std::unordered_map<int64_t, tvm::Target> tmp_map;
-   for (const auto& kv : targets) {
+   for (const auto& kv : targets_) {
      tmp_map[kv.first->value] = kv.second;
    }
-   if (tmp_map.count(cfg.fallback_device) == 0) {
-     device_target.Set(cfg.fallback_device,
-                       CreateDefaultTarget(cfg.fallback_device));
+   if (tmp_map.count(fallback_device) == 0) {
+     targets_.Set(fallback_device, CreateDefaultTarget(fallback_device));
    }
-   return device_target;
  }

  /*!
   * \brief Execute the device annotation passes to update the input program and
   * target information.
   *
-  * \param func
-  * \param cfg
-  * \param targets_map_ptr
-  * \return Function
+  * \param relay_module The input Relay module.
+  * \param fallback_device The fallback device for heterogeneous execution.
+  *
+  * \return updated_module The updated module after device annotation.
   */
- Function RunDeviceAnnotationPass(Function func,
-                                  const RelayBuildConfig& cfg,
-                                  TargetsMap* targets_map_ptr) {
-   func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-   func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func,
-                         cfg.fallback_device);
-   auto device_map = CallPackedFunc<Map<Expr, Integer>>(
-       "relay._ir_pass.CollectDeviceInfo", func, nullptr);
-   if (device_map.size() == 0) {
-     auto annotation_map = CallPackedFunc<Map<Expr, Integer>>(
-         "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr);
-     if (annotation_map.size() == 0) {
-       targets_map_ptr->Set(0, CreateDefaultTarget(cfg.fallback_device));
+ relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module,
+                                       int fallback_device) {
+   UpdateHeterogeneousInputs(fallback_device);
+   auto rewrite = transform::RewriteAnnotatedOps(fallback_device);
+   auto updated_module = rewrite(relay_module);
+   CHECK(updated_module.defined());
+
+   tvm::Map<Expr, Integer> device_map;
+   for (const auto& it : updated_module->functions) {
+     device_map = relay::CollectDeviceInfo(it.second);
+     if (!device_map.empty()) break;
+   }
+
+   if (device_map.empty()) {
+     tvm::Map<Expr, Integer> annotation_map;
+     for (const auto& it : relay_module->functions) {
+       annotation_map = relay::CollectDeviceAnnotationOps(it.second);
+       if (!annotation_map.empty()) break;
+     }
+     // No op is annotated, so everything falls back to the default device.
+     if (annotation_map.empty()) {
+       targets_.Set(0, CreateDefaultTarget(fallback_device));
      } else {
        // All ops are annotated to the same device type.
        int64_t dev_type = -1;
        for (auto kv : annotation_map) {
          dev_type = kv.second->value;

@@ -531,47 +408,42 @@ class RelayBuildModule : public runtime::ModuleNode {
                     << "found. Please check the "
                     << "RewriteAnnotation pass.";
        }
-       targets_map_ptr->Set(0, CreateDefaultTarget(dev_type));
+       targets_.Set(0, CreateDefaultTarget(dev_type));
      }
    }
-   return func;
+   return updated_module;
  }

  /*!
   * \brief Build relay function to runtime module
   *
   * \param func Relay Function
-  * \param cfg Relay build config
   * \param params parameters
   */
- void BuildRelay(Function func,
-                 const RelayBuildConfig& cfg,
-                 const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
-   // convert
-   tvm_cfg_ = BuildConfig::Create();
-   TargetsMap device_target;
-   if (targets_.size() > 1) {
-     device_target = UpdateHeterogeneousInputs(targets_, cfg);
-   } else {
-     device_target = targets_;
-   }
-   func = Optimize(func, targets_, cfg, params);
-   if (device_target.size() > 1) {
-     func = RunDeviceAnnotationPass(func, cfg, &device_target);
-   }
-   // TODO(@jroesch): use the passes directly.
-   func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
-   func = CallPackedFunc("relay._ir_pass.FuseOps", func, cfg.opt_level, nullptr);
-   func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr);
+ void BuildRelay(Function func,
+                 const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
+   if (params.size()) {
+     func = BindParamsByName(func, params);
+   }
+
+   // Perform Module->Module optimizations.
+   relay::Module relay_module = relay::ModuleNode::FromExpr(func);
+   relay_module = Optimize(relay_module, targets_, params);
+   CHECK(relay_module.defined());
+   // Get the updated function.
+   func = relay_module->Lookup(relay_module->entry_func->name_hint);

    // Generate code for the updated function.
    graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
-   graph_codegen_->Init(nullptr, device_target);
+   graph_codegen_->Init(nullptr, targets_);
    graph_codegen_->Codegen(func);

    ret_.graph_json = graph_codegen_->GetJSON();
    ret_.params = graph_codegen_->GetParams();
-   ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_);
+   ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
+                         BuildConfig::Current());
  }

 protected:

@@ -580,14 +452,10 @@ class RelayBuildModule : public runtime::ModuleNode {
  TargetsMap targets_;
  /*! \brief target host device */
  tvm::Target target_host_;
- /*! \brief frontend optimization configure */
- RelayBuildConfig cfg_;
  /*! \brief parameters */
  std::unordered_map<std::string, runtime::NDArray> params_;
  /*! \brief building output */
  BuildOutput ret_;
- /*! \brief tvm building cfg */
- BuildConfig tvm_cfg_;
 };

 runtime::Module RelayBuildCreate() {
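The heterogeneous path above (`UpdateHeterogeneousInputs` plus `RunDeviceAnnotationPass`) is driven by `on_device` annotations in the input program. A hedged sketch of such an input from Python — the `relay.annotation.on_device` call and the dictionary form of `target` are assumed to match the API of this era:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4,))
    y = relay.var("y", shape=(4,))
    # Schedule the addition on the GPU; unannotated ops use the fallback device.
    add = relay.annotation.on_device(x + y, tvm.gpu(0))
    func = relay.Function([x, y], relay.subtract(add, y))

    with relay.build_config(opt_level=2, fallback_device=tvm.cpu(0)):
        graph_json, lib, params = relay.build(
            func, target={"cpu": "llvm", "gpu": "cuda"})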
src/relay/pass/alter_op_layout.cc

@@ -27,6 +27,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>
 #include <tvm/tvm.h>
 #include <tuple>
 #include <vector>

@@ -338,17 +339,35 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
 // Limitations:
 // 1. the altered op should have the same number of arguments as the previous one
 // 2. do not support nested tuple arguments

-TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
-  auto fcontext = [&](const Call& call) -> NodeRef {
-    return transformMemorizer;
-  };
-  *ret = ForwardRewrite(args[0], AlterOpLayoutRewrite, fcontext);
-});
+Expr AlterOpLayout(const Expr& expr) {
+  TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
+  auto fcontext = [&](const Call& call) -> NodeRef {
+    return transformMemorizer;
+  };
+  return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+}
+
+TVM_REGISTER_API("relay._ir_pass.AlterOpLayout")
+.set_body_typed(AlterOpLayout);

 }  // namespace alter_op_layout

namespace transform {

Pass AlterOpLayout() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
      };
  return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.AlterOpLayout")
.set_body_typed(AlterOpLayout);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
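`AlterOpLayout` consults the target that is current when it runs, which is why the build path wraps it in `With<Target>`. From Python the equivalent is a target scope; a sketch assuming a module containing a convolution whose layout can be altered:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    data = relay.var("data", shape=(1, 3, 224, 224))
    weight = relay.var("weight", shape=(16, 3, 3, 3))
    conv = relay.nn.conv2d(data, weight, channels=16,
                           kernel_size=(3, 3), padding=(1, 1))
    mod = relay.Module.from_expr(relay.Function([data, weight], conv))

    # The target in scope decides which layouts the pass prefers.
    with tvm.target.create("llvm"):
        mod = transform.AlterOpLayout()(mod)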
src/relay/pass/canonicalize_ops.cc

@@ -26,6 +26,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
 #include "pattern_util.h"

 namespace tvm {

@@ -63,5 +64,21 @@ Expr CanonicalizeOps(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
 .set_body_typed(CanonicalizeOps);

namespace transform {

Pass CanonicalizeOps() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(CanonicalizeOps(f));
      };
  return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.CanonicalizeOps")
.set_body_typed(CanonicalizeOps);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
src/relay/pass/combine_parallel_conv2d.cc

@@ -38,6 +38,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
 #include <unordered_map>
 #include <unordered_set>
 #include "./expr_subst.h"

@@ -357,5 +358,21 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
 TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
 .set_body_typed(CombineParallelConv2D);

namespace transform {

Pass CombineParallelConv2D(uint64_t min_num_branches) {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(CombineParallelConv2D(f, min_num_branches));
      };
  return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.CombineParallelConv2D")
.set_body_typed(CombineParallelConv2D);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
src/relay/pass/dead_code.cc

@@ -158,9 +158,12 @@ Pass DeadCodeElimination() {
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(DeadCodeElimination(f));
      };
- return CreateFunctionPass(pass_func, 1, "dead_code_elimination", {});
+ return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {});
 }

 TVM_REGISTER_API("relay._transform.DeadCodeElimination")
 .set_body_typed(DeadCodeElimination);

 }  // namespace transform
 }  // namespace relay
src/relay/pass/device_annotation.cc

@@ -35,6 +35,7 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
 #include <memory>
 #include <unordered_map>

@@ -564,11 +565,14 @@ Pass RewriteAnnotatedOps(int fallback_device) {
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(RewriteAnnotatedOps(f, fallback_device));
      };
- return CreateFunctionPass(pass_func, 1, "rewrite_annotated_ops", {});
+ return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
+                           {ir::StringImm::make("InferType")});
 }

 TVM_REGISTER_API("relay._transform.RewriteDeviceAnnotation")
 .set_body_typed(RewriteAnnotatedOps);

 }  // namespace transform
 }  // namespace relay
 }  // namespace tvm
src/relay/pass/eliminate_common_subexpr.cc

@@ -29,6 +29,7 @@
 */
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
 #include <unordered_map>
 #include "./pattern_util.h"

@@ -87,5 +88,21 @@ Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
 TVM_REGISTER_API("relay._ir_pass.eliminate_common_subexpr")
 .set_body_typed<Expr(Expr, PackedFunc)>(EliminateCommonSubexpr);

namespace transform {

Pass EliminateCommonSubexpr(PackedFunc fskip) {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
      };
  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
                            {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr")
.set_body_typed(EliminateCommonSubexpr);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
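The `fskip` hook exempts chosen expressions from common-subexpression elimination; the driver in build_module.cc uses it to skip `cast`-to-int32 nodes. The same callback can be supplied from Python, since a plain function converts to a PackedFunc at the FFI boundary. A sketch:

    from tvm import relay
    from tvm.relay import transform

    def fskip(expr):
        # Keep identical int32 casts separate instead of merging them.
        if isinstance(expr, relay.expr.Call) and expr.op.name == "cast":
            return expr.attrs.dtype == "int32"
        return False

    x = relay.var("x", shape=(4,), dtype="float32")
    body = relay.add(relay.cast(x, "int32"), relay.cast(x, "int32"))
    mod = relay.Module.from_expr(relay.Function([x], body))
    mod = transform.EliminateCommonSubexpr(fskip)(mod)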
src/relay/pass/fold_constant.cc

@@ -26,6 +26,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/transform.h>

 namespace tvm {
 namespace relay {

@@ -220,11 +221,14 @@ namespace transform {
 Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(FoldConstant(f));
      };
- return CreateFunctionPass(pass_func, 1, "fold_constant", {});
+ return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
 }

 TVM_REGISTER_API("relay._transform.FoldConstant")
 .set_body_typed(FoldConstant);

 }  // namespace transform
 }  // namespace relay
src/relay/pass/fold_scale_axis.cc

@@ -29,6 +29,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
 #include "pattern_util.h"
 #include "pass_util.h"

@@ -530,7 +531,7 @@ RELAY_REGISTER_OP("nn.conv2d")
 .set_attr<FForwardRewrite>("FScaleAxisForwardRewrite", Conv2DForwardRewrite);

-Expr ForwardFoldScaleAxis(Expr data) {
+Expr ForwardFoldScaleAxis(const Expr& data) {
  auto message = ForwardPrep().Prepare(data);
  auto fcontext = [&](const Call& call) -> NodeRef {
    auto it = message.find(call.get());

@@ -942,7 +943,7 @@ RELAY_REGISTER_OP("nn.conv2d")
 RELAY_REGISTER_OP("nn.conv2d")
 .set_attr<FBackwardTransform>("FScaleAxisBackwardTransform", Conv2DBackwardTransform);

-Expr BackwardFoldScaleAxis(Expr data) {
+Expr BackwardFoldScaleAxis(const Expr& data) {
  return make_node<BackwardTransformerNode>()->Fold(data);
 }

@@ -950,5 +951,42 @@ TVM_REGISTER_API("relay._ir_pass.backward_fold_scale_axis")
 .set_body_typed<Expr(Expr)>(BackwardFoldScaleAxis);

 }  // namespace fold_scale_axis

namespace transform {

Pass ForwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(
            relay::fold_scale_axis::ForwardFoldScaleAxis(f));
      };
  return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
                            {ir::StringImm::make("InferType")});
}

Pass BackwardFoldScaleAxis() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(
            relay::fold_scale_axis::BackwardFoldScaleAxis(f));
      };
  return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
                            {ir::StringImm::make("InferType")});
}

Pass FoldScaleAxis() {
  // FoldScaleAxis pass contains the following three passes. Therefore, we can
  // register it as a sequential pass.
  Pass pass = Sequential(
      {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()},
      "FoldScaleAxis");
  return pass;
}

TVM_REGISTER_API("relay._transform.FoldScaleAxis")
.set_body_typed(FoldScaleAxis);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
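Because `FoldScaleAxis` is registered as a `Sequential` of the two directional folds plus constant folding, a single Python call covers the whole pipeline. A sketch on the pattern the pass targets, with illustrative shapes:

    import numpy as np
    from tvm import relay
    from tvm.relay import transform

    data = relay.var("data", shape=(1, 3, 32, 32))
    weight = relay.const(np.random.rand(8, 3, 3, 3).astype("float32"))
    in_scale = relay.const(np.random.rand(3, 1, 1).astype("float32"))
    # The channel-wise multiply in front of the conv2d gets folded into weight.
    conv = relay.nn.conv2d(relay.multiply(data, in_scale), weight,
                           channels=8, kernel_size=(3, 3))
    mod = relay.Module.from_expr(relay.Function([data], conv))
    mod = transform.FoldScaleAxis()(mod)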
src/relay/pass/forward_rewrite.cc

@@ -220,7 +220,7 @@ Pass ForwardRewrite(const std::string& rewrite_map_attr_name,
                     fcontext, fmulti_ref_trigger));
      };
- return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+ return CreateFunctionPass(pass_func, 1, "ForwardRewrite", {});
 }

 Pass ForwardRewrite(const FForwardRewrite& rewrite_func,

@@ -233,7 +233,7 @@ Pass ForwardRewrite(const FForwardRewrite& rewrite_func,
                     fcontext, fmulti_ref_trigger));
      };
- return CreateFunctionPass(pass_func, 1, "forward_rewrite", {});
+ return CreateFunctionPass(pass_func, 1, "ForwardRewriteFunc", {});
 }

 }  // namespace transform
src/relay/pass/fuse_ops.cc

@@ -29,6 +29,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
 #include "./pattern_util.h"
 #include "../../common/arena.h"

@@ -973,9 +974,13 @@ Pass FuseOps(int fuse_opt_level) {
        int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
        return Downcast<Function>(FuseOps(f, opt_level, m));
      };
- return CreateFunctionPass(pass_func, 1, "fuse_ops", {});
+ return CreateFunctionPass(pass_func, 1, "FuseOps",
+                           {ir::StringImm::make("InferType")});
 }

 TVM_REGISTER_API("relay._transform.FuseOps")
 .set_body_typed(FuseOps);

 }  // namespace transform
 }  // namespace relay
src/relay/pass/partial_eval.cc

@@ -797,9 +797,7 @@ Expr PartialEval(const Expr& e) {
 }

 TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  *ret = PartialEval(args[0]);
-});
+.set_body_typed(PartialEval);

 namespace transform {

@@ -808,9 +806,12 @@ Pass PartialEval() {
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(PartialEval(f));
      };
- return CreateFunctionPass(pass_func, 1, "partial_eval", {});
+ return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {});
 }

 TVM_REGISTER_API("relay._transform.PartialEvaluate")
 .set_body_typed(PartialEval);

 }  // namespace transform
 }  // namespace relay
src/relay/pass/pass_manager.cc
View file @
bb48a45b
...
...
@@ -37,42 +37,46 @@ namespace transform {
using
tvm
::
IRPrinter
;
/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*/
class
OptPassLevel
{
public
:
/*!
* \brief Get level for an optimization pass
*
* \param key pass name
* \return int level
*/
int
operator
[](
const
std
::
string
&
key
)
const
{
const
auto
data
=
CreateMap
();
auto
it
=
data
.
find
(
key
);
if
(
it
==
data
.
end
())
{
return
-
1
;
}
return
it
->
second
;
namespace
{
// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
// handled because we need to register the pass for Python invocation anyway.
Pass
GetPass
(
const
std
::
string
&
pass_name
)
{
if
(
pass_name
==
"InferType"
)
{
return
InferType
();
}
else
if
(
pass_name
==
"AlterOpLayout"
)
{
return
AlterOpLayout
();
}
else
if
(
pass_name
==
"CanonicalizeOps"
)
{
return
CanonicalizeOps
();
}
else
if
(
pass_name
==
"CombineParallelConv2d"
)
{
return
CombineParallelConv2D
();
}
else
if
(
pass_name
==
"DeadCodeElimination"
)
{
return
DeadCodeElimination
();
}
else
if
(
pass_name
==
"EliminateCommonSubexpr"
)
{
return
DeadCodeElimination
();
}
else
if
(
pass_name
==
"FoldConstant"
)
{
return
FoldConstant
();
}
else
if
(
pass_name
==
"BackwardFoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"ForwardFoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"FoldScaleAxis"
)
{
return
FoldScaleAxis
();
}
else
if
(
pass_name
==
"PartialEvaluate"
)
{
return
SimplifyInference
();
}
else
if
(
pass_name
==
"SimplifyInference"
)
{
return
SimplifyInference
();
}
else
if
(
pass_name
==
"ToANormalForm"
)
{
return
ToANormalForm
();
}
else
if
(
pass_name
==
"ToGraphNormalForm"
)
{
return
ToGraphNormalForm
();
}
else
{
LOG
(
FATAL
)
<<
pass_name
<<
" has not been registered yet."
<<
"
\n
"
;
return
Pass
(
nullptr
);
}
}
private
:
static
const
std
::
unordered_map
<
std
::
string
,
int
>
CreateMap
()
{
const
std
::
unordered_map
<
std
::
string
,
int
>
m
=
{
{
"SimplifyInference"
,
0
},
{
"OpFusion"
,
1
},
{
"FoldConstant"
,
2
},
{
"CombineParallelConv2D"
,
3
},
{
"FoldScaleAxis"
,
3
},
{
"AlterOpLayout"
,
3
},
{
"CanonicalizeOps"
,
3
},
{
"EliminateCommonSubexpr"
,
3
}
};
return
m
;
}
};
}
// namespace
struct
RelayPassContextThreadLocalEntry
{
/*! \brief The default pass context. */
...
...
@@ -246,12 +250,6 @@ class SequentialNode : public PassNode {
/* \brief The pass meta data.*/
PassInfo
pass_info
;
/*!
* \brief A helper struct to get the optimization pass name to opt level
* mapping.
*/
OptPassLevel
opt_pass_level
;
/*! \brief A list of passes that used to compose a sequential pass. */
tvm
::
Array
<
Pass
>
passes
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
...
...
@@ -300,7 +298,7 @@ class SequentialNode : public PassNode {
const
Array
<
tvm
::
Expr
>&
disabled
)
const
;
std
::
unordered_set
<
std
::
string
>
RequiredPasses
(
const
Array
<
tvm
::
Expr
>&
disabl
ed
)
const
;
const
Array
<
tvm
::
Expr
>&
requir
ed
)
const
;
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
...
...
@@ -338,14 +336,25 @@ ModulePass ModulePassNode::make(
}
// Module -> Module optimizations.
// TODO(zhiics) Check and handle the required passes.
Module
ModulePassNode
::
operator
()(
const
Module
&
mod
,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
CHECK
(
mod
.
defined
());
auto
updated_mod
=
pass_func
(
mod
,
pass_ctx
);
Module
updated_mod
=
mod
;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
auto
pass
=
GetPass
(
name
->
value
);
updated_mod
=
pass
(
updated_mod
,
pass_ctx
);
}
updated_mod
=
pass_func
(
updated_mod
,
pass_ctx
);
CHECK
(
updated_mod
.
defined
());
return
updated_mod
;
}
...
...
@@ -365,12 +374,26 @@ Module FunctionPassNode::operator()(const Module& mod,
const
PassContext
&
pass_ctx
)
const
{
PassInfo
pass_info
=
Info
();
CHECK
(
mod
.
defined
());
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
DLOG
(
INFO
)
<<
"Executing module pass : "
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
<<
"
\n
"
;
Module
updated_mod
=
mod
;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for
(
const
auto
&
it
:
pass_info
->
required
)
{
const
auto
*
name
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
name
);
auto
pass
=
GetPass
(
name
->
value
);
updated_mod
=
pass
(
updated_mod
,
pass_ctx
);
}
Module
new_mod
=
ModuleNode
::
make
({},
mod
->
type_definitions
);
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
mod
,
pass_ctx
);
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
updated_mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
}
...
...
@@ -418,7 +441,7 @@ std::unordered_set<std::string> SequentialNode::DisabledPasses(
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
disabled
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"
disabled passes
must be string."
;
CHECK
(
str
)
<<
"
Disabled pass name
must be string."
;
ret
.
emplace
(
str
->
value
);
}
return
ret
;
...
...
@@ -429,7 +452,7 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
std
::
unordered_set
<
std
::
string
>
ret
;
for
(
const
auto
&
it
:
required
)
{
const
auto
*
str
=
it
.
as
<
tvm
::
ir
::
StringImm
>
();
CHECK
(
str
)
<<
"
disabled passes
must be string."
;
CHECK
(
str
)
<<
"
Required pass name
must be string."
;
ret
.
emplace
(
str
->
value
);
}
return
ret
;
...
...
@@ -439,7 +462,7 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const {
PassContext
ctx
=
PassContext
::
Current
();
auto
required
=
RequiredPasses
(
ctx
->
required_pass
);
auto
disabled
=
DisabledPasses
(
ctx
->
requir
ed_pass
);
auto
disabled
=
DisabledPasses
(
ctx
->
disabl
ed_pass
);
if
(
disabled
.
count
(
pass_name
))
{
return
false
;
...
...
@@ -448,29 +471,27 @@ bool SequentialNode::PassEnabled(const std::string& pass_name) const {
if
(
required
.
count
(
pass_name
))
{
return
true
;
}
return
ctx
->
opt_level
>=
opt_pass_level
[
pass_name
];
const
Pass
pass
=
GetPass
(
pass_name
);
PassInfo
info
=
pass
->
Info
();
return
ctx
->
opt_level
>=
info
->
opt_level
;
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem need
ed
to be handled in the future.
// ordering problem need
s
to be handled in the future.
Module
SequentialNode
::
operator
()(
const
Module
&
module
,
const
PassContext
&
pass_ctx
)
const
{
int
opt_level
=
pass_ctx
->
opt_level
;
auto
disabled
=
DisabledPasses
(
pass_ctx
->
disabled_pass
);
Module
mod
=
module
;
for
(
const
Pass
&
pass
:
passes
)
{
CHECK
(
pass
.
defined
())
<<
"Found undefined pass for optimization."
;
PassInfo
info
=
pass
->
Info
();
const
auto
&
pass_name
=
info
->
name
;
const
auto
&
pass_opt_level
=
info
->
opt_level
;
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
if
(
pass_opt_level
>
opt_level
||
disabled
.
count
(
pass_name
))
{
continue
;
// Execute the pass if it is enabled.
if
(
PassEnabled
(
pass_name
))
{
mod
=
pass
(
mod
,
pass_ctx
);
}
const
auto
*
pn
=
pass
.
operator
->
();
mod
=
(
*
pn
)(
mod
,
pass_ctx
);
}
return
mod
;
}
...
...
@@ -525,15 +546,17 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
TVM_REGISTER_API
(
"relay._transform.RunPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
args
[
0
].
operator
Pass
()(
args
[
1
]);
Pass
pass
=
args
[
0
];
Module
mod
=
args
[
1
];
*
ret
=
pass
(
mod
);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ModulePassNode
>
([](
const
ModulePassNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
const
PassInfo
Node
*
pn
=
node
->
Info
().
operator
->
();
p
->
stream
<<
"Run Module pass: "
<<
pn
->
name
<<
" at the optimization level "
<<
pn
->
opt_level
;
const
PassInfo
info
=
node
->
Info
();
p
->
stream
<<
"Run Module pass: "
<<
info
->
name
<<
" at the optimization level "
<<
info
->
opt_level
;
});
TVM_REGISTER_NODE_TYPE
(
FunctionPassNode
);
...
@@ -544,9 +567,9 @@ TVM_REGISTER_API("relay._transform.CreateFunctionPass")
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
                                    tvm::IRPrinter* p) {
-  const PassInfoNode* pn = node->Info().operator->();
-  p->stream << "Run Function pass: " << pn->name
-            << " at the optimization level " << pn->opt_level;
+  const PassInfo info = node->Info();
+  p->stream << "Run Function pass: " << info->name
+            << " at the optimization level " << info->opt_level;
 });
 
 TVM_REGISTER_NODE_TYPE(SequentialNode);
...
@@ -564,14 +587,13 @@ TVM_REGISTER_API("relay._transform.Sequential")
 TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
 .set_dispatch<SequentialNode>([](const SequentialNode* node,
                                  tvm::IRPrinter* p) {
-  const PassInfoNode* seq_pn = node->Info().operator->();
-  p->stream << "Run Sequential pass: " << seq_pn->name
-            << " at the optimization level " << seq_pn->opt_level << ". ";
+  const PassInfo info = node->Info();
+  p->stream << "Run Sequential pass: " << info->name
+            << " at the optimization level " << info->opt_level << ". ";
   p->stream << "The passes will be executed are: [";
   for (const auto& it : node->passes) {
-    const PassNode* pn = it.operator->();
-    const PassInfoNode* pass_info_node = pn->Info().operator->();
-    p->stream << pass_info_node->name << " ";
+    const PassInfo pass_info = it->Info();
+    p->stream << pass_info->name << " ";
   }
   p->stream << "]";
 });
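These printer hooks are what `repr()`/`print` on a pass object produce from Python. A sketch, with the expected output paraphrased from the dispatch above (the default name and opt level of `Sequential` are assumptions):

    from tvm import relay
    from tvm.relay import transform

    seq = transform.Sequential([transform.InferType(), transform.FoldConstant()])

    # Dispatches to the SequentialNode printer registered above; expect output
    # along the lines of:
    #   Run Sequential pass: sequential at the optimization level 2. The passes
    #   will be executed are: [InferType FoldConstant ]
    print(seq)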
...

src/relay/pass/simplify_inference.cc
@@ -24,6 +24,7 @@
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
 #include "./pattern_util.h"
 
 namespace tvm {
...
@@ -105,5 +106,21 @@ Expr SimplifyInference(const Expr& e) {
 TVM_REGISTER_API("relay._ir_pass.simplify_inference")
 .set_body_typed(SimplifyInference);
 
+namespace transform {
+
+Pass SimplifyInference() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(SimplifyInference(f));
+  };
+  return CreateFunctionPass(pass_func, 0, "SimplifyInference",
+                            {ir::StringImm::make("InferType")});
+}
+
+TVM_REGISTER_API("relay._transform.SimplifyInference")
+.set_body_typed(SimplifyInference);
+
+}  // namespace transform
 }  // namespace relay
 }  // namespace tvm
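The new `transform::SimplifyInference` pass declares `InferType` as a required pass. A minimal Python sketch, assuming the `relay.transform.SimplifyInference` wrapper mirrors this registration; it simply schedules `InferType` ahead of the pass in a `Sequential` (shapes are illustrative):

    from tvm import relay
    from tvm.relay import transform

    # A small function containing batch_norm, which SimplifyInference
    # rewrites into plain arithmetic for inference.
    x = relay.var("x", shape=(1, 4, 8, 8), dtype="float32")
    gamma = relay.var("gamma", shape=(4,))
    beta = relay.var("beta", shape=(4,))
    mean = relay.var("mean", shape=(4,))
    var = relay.var("var", shape=(4,))
    y = relay.nn.batch_norm(x, gamma, beta, mean, var)[0]
    func = relay.Function(relay.ir_pass.free_vars(y), y)
    mod = relay.Module({"main": func})

    seq = transform.Sequential([transform.InferType(),
                                transform.SimplifyInference()])
    mod = seq(mod)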
src/relay/pass/to_a_normal_form.cc
@@ -340,9 +340,12 @@ Pass ToANormalForm() {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(ToANormalForm(f, m));
   };
-  return CreateFunctionPass(pass_func, 1, "to_a_normal_form", {});
+  return CreateFunctionPass(pass_func, 1, "ToANormalForm", {});
 }
 
 TVM_REGISTER_API("relay._transform.ToANormalForm")
 .set_body_typed(ToANormalForm);
 
 }  // namespace transform
 }  // namespace relay
...
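The rename matters because `PassEnabled` and the `required_pass`/`disabled_pass` lists match passes by their registered name, which is now CamelCase. A hedged sketch, assuming `build_config` accepts `disabled_pass`:

    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(1, 2, 3), dtype="float32")
    mod = relay.Module({"main": relay.Function([x], relay.add(x, x))})

    seq = transform.Sequential([transform.InferType(),
                                transform.ToANormalForm()])

    # The entry must use the new registered name "ToANormalForm"; the old
    # snake_case name "to_a_normal_form" would no longer match anything.
    with relay.build_config(opt_level=3, disabled_pass=["ToANormalForm"]):
        mod = seq(mod)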
src/relay/pass/to_graph_normal_form.cc
@@ -86,9 +86,12 @@ Pass ToGraphNormalForm() {
       [=](Function f, Module m, PassContext pc) {
         return Downcast<Function>(ToGraphNormalForm(f));
   };
-  return CreateFunctionPass(pass_func, 1, "to_graph_normal_form", {});
+  return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {});
 }
 
 TVM_REGISTER_API("relay._transform.ToGraphNormalForm")
 .set_body_typed(ToGraphNormalForm);
 
 }  // namespace transform
 }  // namespace relay
...
src/relay/pass/type_infer.cc
@@ -43,6 +43,7 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/pass.h>
+#include <tvm/relay/transform.h>
 #include "./pass_util.h"
 #include "type_solver.h"
 #include "../ir/type_functor.h"
...
@@ -807,5 +808,23 @@ TVM_REGISTER_API("relay._ir_pass.infer_type")
 .set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr,
                                                      const Module& mod_ref) {
   return InferType(expr, mod_ref);
 });
 
+namespace transform {
+
+Pass InferType() {
+  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
+    [=](Function f, Module m, PassContext pc) {
+      return Downcast<Function>(InferType(f, m));
+  };
+  return CreateFunctionPass(pass_func, 0, "InferType", {});
+}
+
+TVM_REGISTER_API("relay._transform.InferType")
+.set_body_typed<Pass()>([]() {
+  return InferType();
+});
+
+}  // namespace transform
 }  // namespace relay
 }  // namespace tvm
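Both entry points now coexist: the legacy expression-level `relay._ir_pass.infer_type` and the new module-to-module `relay._transform.InferType` pass. A short sketch of the two call styles (the example function is illustrative):

    from tvm import relay
    from tvm.relay import ir_pass, transform

    x = relay.var("x", shape=(1, 2, 3), dtype="float32")
    f = relay.Function([x], relay.add(x, x))

    # Legacy expression-level API (backed by relay._ir_pass.infer_type).
    f_typed = ir_pass.infer_type(f)

    # New module-level pass (backed by relay._transform.InferType).
    mod = relay.Module({"main": f})
    mod = transform.InferType()(mod)
    print(mod["main"].checked_type)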
tests/cpp/relay_transform_sequential.cc
0 → 100644
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tvm.h>
TVM_REGISTER_GLOBAL("schedule")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
  *rv = topi::generic::schedule_injective(args[0], args[1]);
});

TEST(Relay, Sequential) {
  using namespace tvm;
  auto tensor_type = relay::TensorTypeNode::make({1, 2, 3}, ::tvm::Float(32));
  auto c_data =
      tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});

  // Create a function for optimization.
  auto c = relay::ConstantNode::make(c_data);
  auto a = relay::VarNode::make("a", tensor_type);
  auto x = relay::VarNode::make("x", tensor_type);
  auto add_op = relay::Op::Get("add");
  auto y = relay::CallNode::make(add_op, {c, c});
  y = relay::CallNode::make(add_op, {x, y});
  auto z = relay::CallNode::make(add_op, {y, c});
  auto z1 = relay::CallNode::make(add_op, {y, c});
  auto z2 = relay::CallNode::make(add_op, {z, z1});
  // Let expression and variable a should be dead-code eliminated.
  auto z3 = relay::LetNode::make(a, c, z2);
  relay::Function func =
      relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {});

  // Get schedule.
  auto reg = tvm::runtime::Registry::Get("relay.op._Register");
  auto sch = tvm::runtime::Registry::Get("schedule");
  if (!reg || !sch) {
    LOG(FATAL) << "Register/schedule is not defined.";
  }
  (*reg)("add", "FTVMSchedule", *sch, 10);

  // Run sequential passes.
  tvm::Array<relay::transform::Pass> pass_seqs{
      relay::transform::InferType(),
      relay::transform::DeadCodeElimination(),
      relay::transform::EliminateCommonSubexpr(),
      relay::transform::AlterOpLayout()};
  relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
  auto mod = relay::ModuleNode::FromExpr(func);
  auto pass_ctx = relay::transform::PassContext::Create();
  pass_ctx->opt_level = 3;
  pass_ctx->fallback_device = 1;
  {
    tvm::With<relay::transform::PassContext> ctx_scope(pass_ctx);
    tvm::With<tvm::Target> tctx(tvm::Target::Create("llvm"));
    mod = seq(mod);
  }

  CHECK(mod.defined());
  auto entry_func = mod->entry_func;
  CHECK(entry_func.defined());
  relay::Function f = mod->Lookup(entry_func->name_hint);
  CHECK(f.defined());

  // Expected function.
  auto c1 = relay::ConstantNode::make(c_data);
  auto x1 = relay::VarNode::make("x", tensor_type);
  auto y1 = relay::CallNode::make(add_op, {c1, c1});
  y1 = relay::CallNode::make(add_op, {x1, y1});
  auto zz = relay::CallNode::make(add_op, {y1, c1});
  zz = relay::CallNode::make(add_op, {zz, zz});
  relay::Function expected_func =
      relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});

  // Infer type for the expected function.
  auto expected = relay::InferType(expected_func, relay::Module(nullptr));
  CHECK(relay::AlphaEqual(f, expected));
}

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  testing::FLAGS_gtest_death_test_style = "threadsafe";
  return RUN_ALL_TESTS();
}
tests/python/relay/test_pass_manager.py
...
@@ -327,7 +327,8 @@ def test_sequential_pass():
     def test_only_module_pass():
         passes = [module_pass]
         sequential = _transform.Sequential(opt_level=1, passes=passes)
-        ret_mod = sequential(mod)
+        with relay.build_config(required_pass=["mod_transform"]):
+            ret_mod = sequential(mod)
         # Check the subtract function.
         sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
         check_func(new_sub, sub)
...
@@ -341,7 +342,8 @@ def test_sequential_pass():
         # Check the subtract function.
         passes = [function_pass]
         sequential = _transform.Sequential(opt_level=1, passes=passes)
-        ret_mod = sequential(mod)
+        with relay.build_config(required_pass=["func_transform"]):
+            ret_mod = sequential(mod)
         _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
         check_func(new_sub, get_ref_sub())
...
@@ -355,7 +357,9 @@ def test_sequential_pass():
         mod = relay.Module({v_sub: sub, v_log: log})
         passes = [module_pass, function_pass]
         sequential = _transform.Sequential(opt_level=1, passes=passes)
-        ret_mod = sequential(mod)
+        required = ["mod_transform", "func_transform"]
+        with relay.build_config(required_pass=required):
+            ret_mod = sequential(mod)
         # Check the abs function is added.
         abs_var, abs_func = get_var_func()
...
@@ -400,7 +404,48 @@ def test_sequential_pass():
     test_multiple_passes()
 
 
+def test_sequential_with_scoping():
+    shape = (1, 2, 3)
+    c_data = np.array(shape).astype("float32")
+    tp = relay.TensorType(shape, "float32")
+
+    def before():
+        c = relay.const(c_data)
+        x = relay.var("x", tp)
+        y = relay.add(c, c)
+        y = relay.multiply(y, relay.const(2, "float32"))
+        y = relay.add(x, y)
+        z = relay.add(y, c)
+        z1 = relay.add(y, c)
+        z2 = relay.add(z, z1)
+        return relay.Function([x], z2)
+
+    def expected():
+        x = relay.var("x", tp)
+        c_folded = (c_data + c_data) * 2
+        y = relay.add(x, relay.const(c_folded))
+        z = relay.add(y, relay.const(c_data))
+        z1 = relay.add(z, z)
+        return relay.Function([x], z1)
+
+    seq = _transform.Sequential([
+        relay.transform.InferType(),
+        relay.transform.FoldConstant(),
+        relay.transform.EliminateCommonSubexpr(),
+        relay.transform.AlterOpLayout()
+    ])
+
+    mod = relay.Module({"main": before()})
+    with relay.build_config(opt_level=3):
+        with tvm.target.create("llvm"):
+            mod = seq(mod)
+
+    zz = mod["main"]
+    zexpected = ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, zexpected)
+
+
 if __name__ == "__main__":
     test_module_pass()
     test_function_pass()
     test_sequential_pass()
+    test_sequential_with_scoping()