Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
89a88c57
Commit
89a88c57
authored
May 24, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
May 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Start porting pass to the pass manager (#3191)
parent
7e648417
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
328 additions
and
69 deletions
+328
-69
include/tvm/relay/pass.h
+87
-52
include/tvm/relay/transform.h
+114
-6
src/relay/pass/dead_code.cc
+12
-0
src/relay/pass/device_annotation.cc
+12
-0
src/relay/pass/fold_constant.cc
+12
-0
src/relay/pass/forward_rewrite.cc
+31
-0
src/relay/pass/fuse_ops.cc
+14
-0
src/relay/pass/partial_eval.cc
+12
-0
src/relay/pass/pass_manager.cc
+8
-9
src/relay/pass/to_a_normal_form.cc
+12
-0
src/relay/pass/to_graph_normal_form.cc
+12
-0
tests/python/relay/test_pass_manager.py
+2
-2
No files found.
include/tvm/relay/pass.h
View file @
89a88c57
...
...
@@ -31,6 +31,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <string>
#include <vector>
...
...
@@ -84,7 +85,8 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
*/
TVM_DLL
Kind
KindCheck
(
const
Type
&
t
,
const
Module
&
mod
);
/*! \brief Compare two expressions for structural equivalence.
/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
...
...
@@ -101,7 +103,8 @@ TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
*/
TVM_DLL
bool
AlphaEqual
(
const
Expr
&
e1
,
const
Expr
&
e2
);
/*! \brief Compare two types for structural equivalence.
/*!
* \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
...
...
@@ -119,7 +122,8 @@ TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
*/
TVM_DLL
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
);
/*! \brief Add abstraction over a function
/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
...
...
@@ -135,7 +139,8 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
*/
TVM_DLL
Expr
EtaExpand
(
const
Expr
&
e
,
const
Module
&
mod
);
/*! \brief Check that each Var is only bound once.
/*!
* \brief Check that each Var is only bound once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
*
...
...
@@ -148,7 +153,8 @@ TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);
*/
TVM_DLL
bool
WellFormed
(
const
Expr
&
expr
);
/*! \brief Get all bound variables from expression expr.
/*!
* \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
...
...
@@ -159,7 +165,8 @@ TVM_DLL bool WellFormed(const Expr& expr);
*/
TVM_DLL
tvm
::
Array
<
Var
>
BoundVars
(
const
Expr
&
expr
);
/*! \brief Get all bound variables from pattern pat.
/*!
* \brief Get all bound variables from pattern pat.
*
* Bound variables are all variables that got bound by the pat.
* They only have meaning inside that expr, and can only be used in it.
...
...
@@ -170,7 +177,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);
*/
TVM_DLL
tvm
::
Array
<
Var
>
BoundVars
(
const
Pattern
&
pat
);
/*! \brief Get free type parameters from expression expr.
/*!
* \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
* let or a function parameter in the context.
...
...
@@ -181,7 +189,8 @@ TVM_DLL tvm::Array<Var> BoundVars(const Pattern& pat);
*/
TVM_DLL
tvm
::
Array
<
Var
>
FreeVars
(
const
Expr
&
expr
);
/*! \brief Get all variables from expression expr.
/*!
* \brief Get all variables from expression expr.
*
* \param expr the expression.
*
...
...
@@ -189,7 +198,8 @@ TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);
*/
TVM_DLL
tvm
::
Array
<
Var
>
AllVars
(
const
Expr
&
expr
);
/*! \brief Get free TypeVars from expression expr.
/*!
* \brief Get free TypeVars from expression expr.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
...
...
@@ -201,7 +211,8 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
*/
TVM_DLL
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Expr
&
expr
,
const
Module
&
mod
);
/*! \brief Get free TypeVars from type t.
/*!
* \brief Get free TypeVars from type t.
*
* Free type parameters are type parameters that are not bound by a function
* type in the context.
...
...
@@ -213,7 +224,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL
tvm
::
Array
<
TypeVar
>
FreeTypeVars
(
const
Type
&
t
,
const
Module
&
mod
);
/*! \brief Get all bound type variables from expression expr.
/*!
* \brief Get all bound type variables from expression expr.
*
* Bound variables are all type variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
...
...
@@ -225,7 +237,8 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
*/
TVM_DLL
tvm
::
Array
<
TypeVar
>
BoundTypeVars
(
const
Expr
&
expr
,
const
Module
&
mod
);
/*! \brief Get all bound type variables from type t.
/*!
* \brief Get all bound type variables from type t.
*
* Bound variables are all type variables that are declared in the type.
* They only have meaning inside that type, and can only be used in it.
...
...
@@ -237,7 +250,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL
tvm
::
Array
<
TypeVar
>
BoundTypeVars
(
const
Type
&
t
,
const
Module
&
mod
);
/*! \brief Get all type variables in expression expr.
/*!
* \brief Get all type variables in expression expr.
*
* \param expr the expression.
* \param mod the module.
...
...
@@ -246,7 +260,8 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
*/
TVM_DLL
tvm
::
Array
<
TypeVar
>
AllTypeVars
(
const
Expr
&
expr
,
const
Module
&
mod
);
/*! \brief Get all type variables in type t.
/*!
* \brief Get all type variables in type t.
*
* \param t the type.
* \param mod the module.
...
...
@@ -273,22 +288,27 @@ TVM_DLL Expr DeadCodeElimination(const Expr& e);
/*!
* \brief Fold constant expressions.
*
* \param expr the expression to be optimized.
*
* \return The optimized expression.
*/
TVM_DLL
Expr
FoldConstant
(
const
Expr
&
expr
);
/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param expr The expression.
* \param fuse_opt_level Optimization level.
* \param mod the module.
*
* \return The optimized expression.
*/
TVM_DLL
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
,
const
Module
&
mod
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
...
...
@@ -298,84 +318,68 @@ TVM_DLL Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& mod);
* \return The rewritten expression.
*/
TVM_DLL
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param expr The expression.
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The rewritten expression.
*/
TVM_DLL
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
FForwardRewrite
&
rewrite_func
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
const
FForwardRewrite
&
rewrite_func
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*!
* \brief Rewrite the annotated program.
*
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The updated program.
*/
TVM_DLL
Expr
RewriteAnnotatedOps
(
const
Expr
&
expr
,
int
fallback_device
);
/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL
Map
<
Expr
,
Integer
>
CollectDeviceInfo
(
const
Expr
&
expr
);
/*! \brief A hashing structure in the style of std::hash. */
struct
StructuralHash
{
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t
operator
()(
const
Type
&
type
)
const
;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t
operator
()(
const
Expr
&
expr
)
const
;
};
/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share
*
* \param e the expression to observably share.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form
* \return expression in A-Normal Form
.
*/
TVM_DLL
Expr
ToANormalForm
(
const
Expr
&
e
,
const
Module
&
mod
);
/*! \brief Remove let binding and directly share via pointer instead.
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
...
...
@@ -386,18 +390,49 @@ TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);
*/
TVM_DLL
Expr
ToGraphNormalForm
(
const
Expr
&
e
);
/*! \brief Aggressive constant propagation/constant folding/inlining.
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression,
*
* \return the optimized expression.
*/
Expr
PartialEval
(
const
Expr
&
e
);
TVM_DLL
Expr
PartialEval
(
const
Expr
&
e
);
/*! \brief A hashing structure in the style of std::hash. */
struct
StructuralHash
{
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t
operator
()(
const
Type
&
type
)
const
;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t
operator
()(
const
Expr
&
expr
)
const
;
};
namespace
vm
{
/*! \brief Compile a module, and construct the virtual machine.
/*!
* \brief Compile a module, and construct the virtual machine.
*
* \param mod The module to compile.
*
* \return The constructed virtual machine.
*/
runtime
::
vm
::
VirtualMachine
CompileModule
(
const
Module
&
mod
);
...
...
include/tvm/relay/transform.h
View file @
89a88c57
...
...
@@ -61,6 +61,7 @@
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -198,7 +199,7 @@ class Pass;
*/
class
PassNode
:
public
RelayNode
{
public
:
/*
/*
!
* \brief Get the pass information/meta data. */
virtual
PassInfo
Info
()
const
=
0
;
...
...
@@ -300,11 +301,118 @@ Pass CreateModulePass(
*
* \return The created function pass.
*/
Pass
CreateFunctionPass
(
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>&
pass_func
,
int
opt_level
,
const
std
::
string
&
name
,
const
tvm
::
Array
<
tvm
::
Expr
>&
required
);
TVM_DLL
Pass
CreateFunctionPass
(
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>&
pass_func
,
int
opt_level
,
const
std
::
string
&
name
,
const
tvm
::
Array
<
tvm
::
Expr
>&
required
);
/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
*
* \return the pass.
*/
TVM_DLL
Pass
DeadCodeElimination
();
/*!
* \brief Fold constant expressions.
*
* \return The pass.
*/
TVM_DLL
Pass
FoldConstant
();
/*!
* \brief Fuse operations into expr into seperate functions.
*
* \param fuse_opt_level Optimization level. If it is -1 it will be inferred from pass context.
*
* \return The pass.
*/
TVM_DLL
Pass
FuseOps
(
int
fuse_opt_level
=
-
1
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL
Pass
ForwardRewrite
(
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
*
* \param rewrite_func The rewrite func that will apply to all operators.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
*
* \return The pass.
*/
TVM_DLL
Pass
ForwardRewrite
(
const
FForwardRewrite
&
rewrite_func
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*!
* \brief Rewrite the annotated program.
*
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
*
* \return The pass.
*/
TVM_DLL
Pass
RewriteAnnotatedOps
(
int
fallback_device
);
/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \return The pass.
*/
TVM_DLL
Pass
ToANormalForm
();
/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \return the expression in graph normal form.
*/
TVM_DLL
Pass
ToGraphNormalForm
();
/*!
* \brief Aggressive constant propagation/constant folding/inlining.
*
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \return the optimized expression.
*/
TVM_DLL
Pass
PartialEval
();
}
// namespace transform
}
// namespace relay
...
...
src/relay/pass/dead_code.cc
View file @
89a88c57
...
...
@@ -151,5 +151,17 @@ Expr DeadCodeElimination(const Expr& e) {
TVM_REGISTER_API
(
"relay._ir_pass.dead_code_elimination"
)
.
set_body_typed
(
DeadCodeElimination
);
namespace
transform
{
Pass
DeadCodeElimination
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
DeadCodeElimination
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"dead_code_elimination"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/device_annotation.cc
View file @
89a88c57
...
...
@@ -550,6 +550,18 @@ TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
TVM_REGISTER_API
(
"relay._ir_pass.CollectDeviceAnnotationOps"
)
.
set_body_typed
(
CollectDeviceAnnotationOps
);
namespace
transform
{
Pass
RewriteAnnotatedOps
(
int
fallback_device
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
RewriteAnnotatedOps
(
f
,
fallback_device
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"rewrite_annotated_ops"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/fold_constant.cc
View file @
89a88c57
...
...
@@ -215,5 +215,17 @@ Expr FoldConstant(const Expr& expr) {
TVM_REGISTER_API
(
"relay._ir_pass.FoldConstant"
)
.
set_body_typed
(
FoldConstant
);
namespace
transform
{
Pass
FoldConstant
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
FoldConstant
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"fold_constant"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/forward_rewrite.cc
View file @
89a88c57
...
...
@@ -206,6 +206,37 @@ Expr ForwardRewrite(const Expr& expr,
return
ForwardRewriter
(
&
rewrite_func
,
fcontext
,
fmulti_ref_trigger
).
Rewrite
(
expr
);
}
namespace
transform
{
using
std
::
function
;
Pass
ForwardRewrite
(
const
std
::
string
&
rewrite_map_attr_name
,
function
<
NodeRef
(
const
Call
&
)
>
fcontext
,
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
rewrite_map_attr_name
,
fcontext
,
fmulti_ref_trigger
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"forward_rewrite"
,
{});
}
Pass
ForwardRewrite
(
const
FForwardRewrite
&
rewrite_func
,
function
<
NodeRef
(
const
Call
&
)
>
fcontext
,
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
rewrite_func
,
fcontext
,
fmulti_ref_trigger
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"forward_rewrite"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/fuse_ops.cc
View file @
89a88c57
...
...
@@ -964,5 +964,19 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level, const Module& module) {
TVM_REGISTER_API
(
"relay._ir_pass.FuseOps"
)
.
set_body_typed
(
FuseOps
);
namespace
transform
{
Pass
FuseOps
(
int
fuse_opt_level
)
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
int
opt_level
=
fuse_opt_level
==
-
1
?
pc
->
opt_level
:
fuse_opt_level
;
return
Downcast
<
Function
>
(
FuseOps
(
f
,
opt_level
,
m
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"fuse_ops"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/partial_eval.cc
View file @
89a88c57
...
...
@@ -801,5 +801,17 @@ TVM_REGISTER_API("relay._ir_pass.partial_evaluate")
*
ret
=
PartialEval
(
args
[
0
]);
});
namespace
transform
{
Pass
PartialEval
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
PartialEval
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"partial_eval"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/pass_manager.cc
View file @
89a88c57
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -201,7 +201,7 @@ class FunctionPassNode : public PassNode {
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
;
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
;
FunctionPassNode
()
=
default
;
...
...
@@ -225,7 +225,7 @@ class FunctionPassNode : public PassNode {
PassInfo
Info
()
const
{
return
pass_info
;
}
TVM_DLL
static
FunctionPass
make
(
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
,
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
);
static
constexpr
const
char
*
_type_key
=
"relay.FunctionPass"
;
...
...
@@ -363,7 +363,7 @@ Module ModulePassNode::operator()(const Module& mod,
}
FunctionPass
FunctionPassNode
::
make
(
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>
pass_func
,
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
,
PassInfo
pass_info
)
{
auto
n
=
make_node
<
FunctionPassNode
>
();
n
->
pass_func
=
std
::
move
(
pass_func
);
...
...
@@ -383,8 +383,7 @@ Module FunctionPassNode::operator()(const Module& mod,
// Execute the pass function and return a new module.
for
(
const
auto
&
it
:
mod
->
functions
)
{
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
pass_ctx
);
auto
updated_func
=
SkipFunction
(
it
.
second
)
?
it
.
second
:
pass_func
(
it
.
second
,
mod
,
pass_ctx
);
new_mod
->
Add
(
it
.
first
,
updated_func
);
}
...
...
@@ -501,7 +500,7 @@ Pass CreateModulePass(
}
Pass
CreateFunctionPass
(
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
PassContext
)
>&
pass_func
,
const
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>&
pass_func
,
int
opt_level
,
const
std
::
string
&
name
,
const
tvm
::
Array
<
tvm
::
Expr
>&
required
)
{
...
...
@@ -589,7 +588,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
tvm
::
IRPrinter
*
p
)
{
const
PassInfoNode
*
seq_pn
=
node
->
Info
().
operator
->
();
p
->
stream
<<
"Run Sequential pass: "
<<
seq_pn
->
name
<<
" at the optimization level
. "
<<
seq_pn
->
opt_level
;
<<
" at the optimization level
"
<<
seq_pn
->
opt_level
<<
". "
;
p
->
stream
<<
"The passes will be executed are: ["
;
for
(
const
auto
&
it
:
node
->
passes
)
{
const
PassNode
*
pn
=
it
.
operator
->
();
...
...
src/relay/pass/to_a_normal_form.cc
View file @
89a88c57
...
...
@@ -333,5 +333,17 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
TVM_REGISTER_API
(
"relay._ir_pass.to_a_normal_form"
)
.
set_body_typed
(
static_cast
<
Expr
(
*
)(
const
Expr
&
,
const
Module
&
)
>
(
ToANormalForm
));
namespace
transform
{
Pass
ToANormalForm
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ToANormalForm
(
f
,
m
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"to_a_normal_form"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
src/relay/pass/to_graph_normal_form.cc
View file @
89a88c57
...
...
@@ -79,5 +79,17 @@ Expr ToGraphNormalForm(const Expr& e) {
TVM_REGISTER_API
(
"relay._ir_pass.to_graph_normal_form"
)
.
set_body_typed
(
ToGraphNormalForm
);
namespace
transform
{
Pass
ToGraphNormalForm
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ToGraphNormalForm
(
f
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"to_graph_normal_form"
,
{});
}
}
// namespace transform
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_pass_manager.py
View file @
89a88c57
...
...
@@ -204,7 +204,7 @@ def test_function_pass():
pass_ctx
=
None
@_transform.function_pass
(
opt_level
=
opt_level
,
name
=
pass_name
)
def
transform
(
expr
,
ctx
):
def
transform
(
expr
,
mod
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
def
get_ref_log
():
...
...
@@ -303,7 +303,7 @@ def test_sequential_pass():
# Register a function pass.
@_transform.function_pass
(
opt_level
=
1
)
def
func_transform
(
expr
,
ctx
):
def
func_transform
(
expr
,
mod
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
function_pass
=
func_transform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment