Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
475158f6
Commit
475158f6
authored
Dec 30, 2019
by
Zhi
Committed by
masahi
Dec 31, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594)
* Refactor to use IsOp utility * retrigger CI
parent
35af4c8b
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
133 additions
and
95 deletions
+133
-95
include/tvm/relay/op.h
+2
-3
src/relay/backend/compile_engine.cc
+9
-6
src/relay/backend/interpreter.cc
+14
-9
src/relay/pass/canonicalize_cast.cc
+8
-4
src/relay/pass/canonicalize_ops.cc
+8
-2
src/relay/pass/combine_parallel_op.cc
+12
-12
src/relay/pass/combine_parallel_op.h
+6
-6
src/relay/pass/fold_constant.cc
+23
-10
src/relay/pass/fuse_ops.cc
+3
-2
src/relay/pass/partial_eval.cc
+14
-16
src/relay/pass/quantize/calibrate.cc
+5
-3
src/relay/pass/simplify_inference.cc
+18
-13
src/relay/pass/util.cc
+11
-9
No files found.
include/tvm/relay/op.h
View file @
475158f6
...
...
@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return
map_
.
get
<
ValueType
>
(
expr
,
def_value
);
}
/*!
* \brief Check that an expression is a "primtive operator".
* \brief Check that an expression is a "prim
i
tive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* matches the form of prim
i
tive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
...
...
src/relay/backend/compile_engine.cc
View file @
475158f6
...
...
@@ -21,6 +21,8 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compilation engine.
*/
#include "compile_engine.h"
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
...
...
@@ -29,6 +31,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <utility>
...
...
@@ -38,7 +41,6 @@
#include <vector>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "compile_engine.h"
namespace
tvm
{
namespace
relay
{
...
...
@@ -102,7 +104,7 @@ class ScheduleGetter :
public
ExprFunctor
<
Array
<
Tensor
>
(
const
Expr
&
)
>
{
public
:
explicit
ScheduleGetter
(
Target
target
)
:
target_
(
target
)
{}
:
target_
(
target
)
,
device_copy_op_
(
Op
::
Get
(
"device_copy"
))
{}
std
::
pair
<
Schedule
,
CachedFunc
>
Create
(
const
Function
&
prim_func
)
{
static
auto
fschedule
=
...
...
@@ -250,11 +252,9 @@ class ScheduleGetter :
CHECK
(
call_node
->
op
.
as
<
OpNode
>
())
<<
"Primitive function only allows call into primitive ops"
;
Op
op
=
Downcast
<
Op
>
(
call_node
->
op
);
// Check if the op is a device copy op.
bool
is_copy_op
=
op
.
same_as
(
Op
::
Get
(
"device_copy"
));
Array
<
Tensor
>
outputs
;
// Skip fcompute for device copy operators as it is not registered.
if
(
is_copy_op
)
{
if
(
op
==
device_copy_op_
)
{
const
auto
*
copy_input
=
inputs
[
0
].
operator
->
();
outputs
.
push_back
(
TensorNode
::
make
(
copy_input
->
shape
,
copy_input
->
dtype
,
Operation
(),
0
));
...
...
@@ -282,7 +282,7 @@ class ScheduleGetter :
}
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
if
(
is_copy_op
)
{
if
(
op
==
device_copy_op_
)
{
readable_name_stream_
.
str
(
std
::
string
());
readable_name_stream_
<<
"__copy"
;
}
else
{
...
...
@@ -332,6 +332,9 @@ class ScheduleGetter :
std
::
ostringstream
readable_name_stream_
;
std
::
unordered_map
<
Expr
,
Array
<
Tensor
>
,
NodeHash
,
NodeEqual
>
memo_
;
Array
<
Operation
>
scalars_
;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const
Op
&
device_copy_op_
;
};
// Creates shape function from functor.
...
...
src/relay/backend/interpreter.cc
View file @
475158f6
...
...
@@ -246,10 +246,12 @@ class Interpreter :
public
ExprFunctor
<
Value
(
const
Expr
&
n
)
>
,
PatternFunctor
<
bool
(
const
Pattern
&
p
,
const
Value
&
v
)
>
{
public
:
Interpreter
(
Module
mod
,
DLContext
context
,
Target
target
)
:
mod_
(
mod
),
context_
(
context
),
target_
(
target
)
{
Interpreter
(
Module
mod
,
DLContext
context
,
Target
target
)
:
mod_
(
mod
),
context_
(
context
),
target_
(
target
),
debug_op_
(
Op
::
Get
(
"debug"
)),
shape_of_op_
(
Op
::
Get
(
"shape_of"
))
{
engine_
=
CompileEngine
::
Global
();
}
...
...
@@ -263,7 +265,7 @@ class Interpreter :
stack_
.
current_frame
().
locals
.
Set
(
id
,
v
);
}
inline
Value
Lookup
(
const
Var
&
local
)
{
Value
Lookup
(
const
Var
&
local
)
{
return
stack_
.
Lookup
(
local
);
}
...
...
@@ -307,7 +309,7 @@ class Interpreter :
return
TupleValueNode
::
make
(
values
);
}
inline
Value
MakeClosure
(
const
Function
&
func
,
Var
letrec_name
=
Var
())
{
Value
MakeClosure
(
const
Function
&
func
,
Var
letrec_name
=
Var
())
{
tvm
::
Map
<
Var
,
Value
>
captured_mod
;
Array
<
Var
>
free_vars
=
FreeVars
(
func
);
...
...
@@ -454,9 +456,9 @@ class Interpreter :
Value
InvokePrimitiveOp
(
const
Function
&
func
,
const
Array
<
Value
>&
args
)
{
auto
call_node
=
func
->
body
.
as
<
CallNode
>
();
const
auto
*
call_node
=
func
->
body
.
as
<
CallNode
>
();
if
(
call_node
&&
call_node
->
op
==
Op
::
Get
(
"debug"
)
)
{
if
(
call_node
&&
call_node
->
op
==
debug_op_
)
{
auto
dattrs
=
call_node
->
attrs
.
as
<
DebugAttrs
>
();
auto
interp_state
=
this
->
get_state
(
call_node
->
args
[
0
]);
...
...
@@ -540,7 +542,7 @@ class Interpreter :
Array
<
Shape
>
out_shapes
;
auto
ret_type
=
func
->
body
->
checked_type
();
bool
is_dyn
=
IsDynamic
(
func
->
checked_type
());
if
(
call_node
->
op
==
Op
::
Get
(
"shape_of"
)
)
{
if
(
call_node
->
op
==
shape_of_op_
)
{
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn
=
false
;
...
...
@@ -782,6 +784,9 @@ class Interpreter :
Stack
stack_
;
// Backend compile engine.
CompileEngine
engine_
;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const
Op
&
debug_op_
;
const
Op
&
shape_of_op_
;
};
...
...
src/relay/pass/canonicalize_cast.cc
View file @
475158f6
...
...
@@ -62,6 +62,8 @@ namespace relay {
// \endcode
class
CastCanonicalizer
:
public
ExprMutator
{
public
:
CastCanonicalizer
()
:
cast_op_
(
Op
::
Get
(
"cast"
))
{}
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
auto
fpattern
=
Op
::
GetAttr
<
TOpPattern
>
(
"TOpPattern"
);
...
...
@@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {
private
:
std
::
unordered_map
<
const
Node
*
,
size_t
>
ref_counter_
;
// cast op is frequently checked for equivalence. Therefore, we cache it to
// reduce lookup overhead.
const
Op
&
cast_op_
;
Expr
GetNewCallArg
(
const
Expr
&
e
)
{
// if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
static
auto
&
cast
=
Op
::
Get
(
"cast"
);
Expr
new_expr
=
this
->
VisitExpr
(
e
);
if
(
const
CallNode
*
call
=
e
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
cast
)
)
{
if
(
call
->
op
==
cast_op_
)
{
auto
attrs
=
call
->
attrs
.
as
<
CastAttrs
>
();
const
auto
*
from_type
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
CHECK
(
from_type
);
...
...
@@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
if
(
++
ref_counter_
[
call
]
>
1
)
{
const
CallNode
*
new_call
=
new_expr
.
as
<
CallNode
>
();
CHECK
(
new_call
);
CHECK
(
new_call
->
op
.
same_as
(
cast
)
);
CHECK
(
new_call
->
op
==
cast_op_
);
return
CallNode
::
make
(
new_call
->
op
,
new_call
->
args
,
new_call
->
attrs
,
new_call
->
type_args
);
}
...
...
src/relay/pass/canonicalize_ops.cc
View file @
475158f6
...
...
@@ -24,6 +24,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
...
...
@@ -33,10 +34,11 @@ namespace relay {
class
BiasAddSimplifier
:
public
ExprMutator
{
public
:
BiasAddSimplifier
()
:
bias_add_op_
(
Op
::
Get
(
"nn.bias_add"
))
{}
Expr
VisitExpr_
(
const
CallNode
*
n
)
{
static
const
Op
&
bias_add
=
Op
::
Get
(
"nn.bias_add"
);
auto
new_n
=
ExprMutator
::
VisitExpr_
(
n
);
if
(
n
->
op
.
same_as
(
bias_add
)
)
{
if
(
n
->
op
==
bias_add_op_
)
{
Call
call
=
Downcast
<
Call
>
(
new_n
);
CHECK_EQ
(
call
->
args
.
size
(),
2
);
const
BiasAddAttrs
*
param
=
call
->
attrs
.
as
<
BiasAddAttrs
>
();
...
...
@@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
}
return
new_n
;
}
private
:
// Cache the bias_add for equivalence checking.
const
Op
&
bias_add_op_
;
};
Expr
CanonicalizeOps
(
const
Expr
&
e
)
{
...
...
src/relay/pass/combine_parallel_op.cc
View file @
475158f6
...
...
@@ -27,29 +27,30 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "
./
expr_subst.h"
#include "
./
pattern_util.h"
#include "
./
combine_parallel_op.h"
#include "expr_subst.h"
#include "pattern_util.h"
#include "combine_parallel_op.h"
namespace
tvm
{
namespace
relay
{
BranchGroupFinder
::
BranchGroupFinder
(
const
std
::
string
&
op_name
,
BranchGroupFinder
::
BranchGroupFinder
(
const
Op
&
op
,
FIsSupportedOp
fis_supported_op
,
FAreCompatibleOps
fare_compatible_ops
)
:
op_name_
(
op_name
),
:
cached_op_
(
op
),
fis_supported_op_
(
fis_supported_op
),
fare_compatible_ops_
(
fare_compatible_ops
)
{
}
std
::
vector
<
Group
>
BranchGroupFinder
::
Find
(
const
Expr
&
expr
)
{
const
Op
&
op
=
Op
::
Get
(
op_name_
);
this
->
VisitExpr
(
expr
);
std
::
vector
<
Group
>
groups
;
...
...
@@ -57,7 +58,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const
auto
&
children
=
children_map_
.
at
(
root
);
size_t
ngroups
=
groups
.
size
();
for
(
const
CallNode
*
child
:
children
)
{
if
(
!
child
->
op
.
same_as
(
op
)
)
continue
;
if
(
child
->
op
!=
cached_op_
)
continue
;
auto
&&
branch
=
CreateBranch
(
child
);
// add the branch to a group, or create a new group
...
...
@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
}
void
BranchGroupFinder
::
VisitExpr_
(
const
CallNode
*
n
)
{
const
Op
&
op
=
Op
::
Get
(
op_name_
);
ExprVisitor
::
VisitExpr_
(
n
);
if
(
n
->
op
.
same_as
(
op
)
&&
fis_supported_op_
(
n
))
{
if
(
n
->
op
==
cached_op_
&&
fis_supported_op_
(
n
))
{
op_roots_
.
insert
(
n
->
args
[
0
]);
children_map_
[
n
->
args
[
0
]].
push_back
(
n
);
}
else
{
...
...
@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
}
ParallelOpCombiner
::
ParallelOpCombiner
(
const
std
::
string
&
op_name
,
uint64_t
min_num_branches
)
:
op_name_
(
op_name
),
:
cached_op_
(
Op
::
Get
(
op_name
)
),
min_num_branches_
(
min_num_branches
)
{
}
Expr
ParallelOpCombiner
::
Combine
(
const
Expr
&
expr
)
{
auto
groups
=
BranchGroupFinder
(
op_name
_
,
auto
groups
=
BranchGroupFinder
(
cached_op
_
,
[
&
](
const
CallNode
*
n
)
{
return
IsSupportedOp
(
n
);
},
...
...
src/relay/pass/combine_parallel_op.h
View file @
475158f6
...
...
@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
public
:
/*
* \brief Constructor
* \param op
_name name of op to start
each group
* \param op
The op that indicates the start of
each group
* \param fis_supported_op function that returns true if op
* is supported for combining
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
*/
BranchGroupFinder
(
const
std
::
string
&
op_name
,
BranchGroupFinder
(
const
Op
&
op
,
FIsSupportedOp
fis_supported_op
,
FAreCompatibleOps
fare_compatible_ops
);
...
...
@@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
std
::
vector
<
Group
>
Find
(
const
Expr
&
expr
);
private
:
/* \brief
name of op to find parallel branches for
*/
std
::
string
op_name
_
;
/* \brief
Cache the op for finding parallel branches
*/
const
Op
&
cached_op
_
;
/* \brief function to return true if op is eligible to be combined,
* false otherwise
...
...
@@ -205,8 +205,8 @@ class ParallelOpCombiner {
ExprSubstMap
*
subst_map
)
=
0
;
private
:
/* \brief
name of
op to be combined */
std
::
string
op_name
_
;
/* \brief
Cache the
op to be combined */
const
Op
&
cached_op
_
;
/* \brief minimum number of parallel branches to combine */
uint64_t
min_num_branches_
;
...
...
src/relay/pass/fold_constant.cc
View file @
475158f6
...
...
@@ -22,6 +22,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
...
...
@@ -33,7 +34,6 @@ namespace relay {
using
FInterpreter
=
runtime
::
TypedPackedFunc
<
Value
(
Expr
)
>
;
class
ConstantChecker
:
private
ExprVisitor
{
public
:
// Check whether an expression is constant. The results are memoized.
...
...
@@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant")
class
ConstantFolder
:
public
ExprMutator
{
public
:
explicit
ConstantFolder
(
FInterpreter
executor
,
Module
module
)
:
executor_
(
executor
),
module_
(
module
)
{
}
:
executor_
(
executor
),
module_
(
module
),
shape_of_op_
(
Op
::
Get
(
"shape_of"
)),
invoke_tvm_op_
(
Op
::
Get
(
"memory.invoke_tvm_op"
)),
shape_func_op_
(
Op
::
Get
(
"memory.shape_func"
)),
alloc_tensor_op_
(
Op
::
Get
(
"memory.alloc_tensor"
)),
alloc_storage_op_
(
Op
::
Get
(
"memory.alloc_storage"
)),
cast_op_
(
Op
::
Get
(
"cast"
))
{}
Expr
VisitExpr_
(
const
LetNode
*
op
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
...
...
@@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if
(
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
return
res
;
// Try to evaluate shape_of op
if
(
call
->
op
.
same_as
(
Op
::
Get
(
"shape_of"
))
)
{
if
(
call
->
op
==
shape_of_op_
)
{
return
EvaluateShapeOf
(
res
,
origin_args
,
call
->
attrs
);
}
// We should think about potentially constant evaluation over these ops too.
if
(
call
->
op
.
same_as
(
Op
::
Get
(
"memory.invoke_tvm_op"
))
||
call
->
op
.
same_as
(
Op
::
Get
(
"memory.shape_func"
))
||
call
->
op
.
same_as
(
Op
::
Get
(
"memory.alloc_tensor"
))
||
call
->
op
.
same_as
(
Op
::
Get
(
"memory.alloc_storage"
))
)
{
if
(
call
->
op
==
invoke_tvm_op_
||
call
->
op
==
shape_func_op_
||
call
->
op
==
alloc_tensor_op_
||
call
->
op
==
alloc_storage_op_
)
{
return
GetRef
<
Call
>
(
call
);
}
...
...
@@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator {
// Module
Module
module_
;
// Cache the following ops for equivalence checking in this pass.
const
Op
&
shape_of_op_
;
const
Op
&
invoke_tvm_op_
;
const
Op
&
shape_func_op_
;
const
Op
&
alloc_tensor_op_
;
const
Op
&
alloc_storage_op_
;
const
Op
&
cast_op_
;
// Convert value to expression.
Expr
ValueToExpr
(
Value
value
)
{
if
(
const
auto
*
val
=
value
.
as
<
TensorValueNode
>
())
{
...
...
@@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator {
// Cast the constant into correct dtype
auto
cast_attrs
=
make_node
<
CastAttrs
>
();
cast_attrs
->
dtype
=
param
->
dtype
;
static
const
Op
&
cast_op
=
Op
::
Get
(
"cast"
);
Expr
ret
=
CallNode
::
make
(
cast_op
,
{
shape
},
Attrs
(
cast_attrs
),
{});
Expr
ret
=
CallNode
::
make
(
cast_op_
,
{
shape
},
Attrs
(
cast_attrs
),
{});
return
ConstEvaluate
(
ret
);
}
};
...
...
src/relay/pass/fuse_ops.cc
View file @
475158f6
...
...
@@ -78,6 +78,8 @@ using common::LinkedList;
constexpr
uint32_t
kMaxFusedOps
=
256
;
static
const
Op
&
stop_fusion_op
=
Op
::
Get
(
"annotation.stop_fusion"
);
/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
...
...
@@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator {
// Transform calls.
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
const
Op
&
stop_fusion
=
Op
::
Get
(
"annotation.stop_fusion"
);
if
(
call
->
op
.
as
<
OpNode
>
())
{
static
auto
fnoncomputational
=
Op
::
GetAttr
<
TNonComputational
>
(
"TNonComputational"
);
...
...
@@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator {
// If it is a primitive op call
// then we must have a group assignment for it already.
CHECK
(
gmap_
.
count
(
call
));
if
(
call
->
op
.
same_as
(
stop_fusion
)
)
{
if
(
call
->
op
==
stop_fusion_op
)
{
return
ExprMutator
::
VisitExpr
(
call
->
args
[
0
]);
}
auto
*
ret_group
=
gmap_
.
at
(
call
)
->
FindRoot
();
...
...
src/relay/pass/partial_eval.cc
View file @
475158f6
...
...
@@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
TVM_REGISTER_NODE_TYPE
(
WithFuncIdAttrs
);
Op
WithFuncIdOp
()
{
static
const
Op
&
op
=
Op
::
Get
(
"annotation.with_funcid"
);
return
op
;
}
Expr
MkWithFuncId
(
const
Expr
&
expr
,
FuncId
fid
)
{
auto
attrs
=
make_node
<
WithFuncIdAttrs
>
();
attrs
->
fid
=
fid
;
return
CallNode
::
make
(
WithFuncIdOp
(),
{
expr
},
Attrs
(
attrs
),
{});
}
RELAY_REGISTER_OP
(
"annotation.with_funcid"
)
.
describe
(
R"code(Annotate a function with a funcid.)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
1
)
.
add_argument
(
"func"
,
"Function"
,
"The input data."
);
// Cache with_funcid op to reduce lookup overhead during traversal.
static
const
Op
&
with_funcid_op
=
Op
::
Get
(
"annotation.with_funcid"
);
Expr
MkWithFuncId
(
const
Expr
&
expr
,
FuncId
fid
)
{
auto
attrs
=
make_node
<
WithFuncIdAttrs
>
();
attrs
->
fid
=
fid
;
return
CallNode
::
make
(
with_funcid_op
,
{
expr
},
Attrs
(
attrs
),
{});
}
Expr
StripWithFuncId
(
const
Expr
&
e
);
Function
AsFunc
(
const
Expr
&
e
)
{
if
(
e
.
as
<
FunctionNode
>
())
{
return
Downcast
<
Function
>
(
e
);
}
else
if
(
const
CallNode
*
c
=
e
.
as
<
CallNode
>
())
{
CHECK
(
c
->
op
.
same_as
(
WithFuncIdOp
())
);
CHECK
(
c
->
op
==
with_funcid_op
);
CHECK_EQ
(
c
->
args
.
size
(),
1
);
return
AsFunc
(
c
->
args
[
0
]);
}
else
{
...
...
@@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic
VisitExpr
(
const
Expr
&
e
,
LetList
*
ll
,
const
Var
&
name
)
{
if
(
const
CallNode
*
c
=
e
.
as
<
CallNode
>
())
{
if
(
c
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
c
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
c
->
args
.
size
(),
1
);
return
VisitExpr
(
c
->
args
[
0
],
ll
,
name
);
}
...
...
@@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
PStatic
VisitExpr_
(
const
CallNode
*
op
,
LetList
*
ll
)
final
{
if
(
op
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
op
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
op
->
args
.
size
(),
1
);
return
VisitExpr
(
op
->
args
[
0
],
ll
);
}
...
...
@@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
explicit
RegisterFuncIdVisitor
(
PartialEvaluator
*
pe
)
:
pe
(
pe
)
{
}
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
op
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
op
->
args
.
size
(),
1
);
CHECK
(
op
->
attrs
.
defined
());
CHECK
(
op
->
attrs
.
as
<
WithFuncIdAttrs
>
());
...
...
@@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) {
Expr
StripWithFuncId
(
const
Expr
&
e
)
{
struct
StripWithFuncIdMutator
:
ExprMutator
,
PatternMutator
{
Expr
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
op
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
op
->
args
.
size
(),
1
);
return
VisitExpr
(
op
->
args
[
0
]);
}
else
{
...
...
src/relay/pass/quantize/calibrate.cc
View file @
475158f6
...
...
@@ -25,15 +25,17 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include "./quantize.h"
namespace
tvm
{
namespace
relay
{
namespace
quantize
{
class
StatsCollector
:
private
ExprMutator
{
public
:
StatsCollector
()
:
simulated_quantize_op_
(
Op
::
Get
(
"relay.op.annotation.simulated_quantize"
))
{}
Expr
Collect
(
const
Expr
&
expr
)
{
auto
new_e
=
this
->
Mutate
(
expr
);
const
FunctionNode
*
func
=
new_e
.
as
<
FunctionNode
>
();
...
...
@@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator {
private
:
Array
<
Expr
>
profile_data_
;
const
Op
&
simulated_quantize_op_
;
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
const
Op
&
simulated_quantize
=
Op
::
Get
(
"relay.op.annotation.simulated_quantize"
);
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
call
);
const
CallNode
*
new_call
=
new_e
.
as
<
CallNode
>
();
CHECK
(
new_call
);
if
(
new_call
->
op
.
same_as
(
simulated_quantize
)
)
{
if
(
new_call
->
op
==
simulated_quantize_op_
)
{
auto
attrs
=
new_call
->
attrs
.
as
<
SimulatedQuantizeAttrs
>
();
// rewrite the annotation
auto
new_attrs
=
make_node
<
SimulatedQuantizeAttrs
>
();
...
...
src/relay/pass/simplify_inference.cc
View file @
475158f6
...
...
@@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
return
out
;
}
Expr
InstanceNormToInferUnpack
(
const
Attrs
attrs
,
Expr
data
,
Expr
gamma
,
...
...
@@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
return
out
;
}
class
InferenceSimplifier
:
public
ExprMutator
{
public
:
Expr
VisitExpr_
(
const
TupleGetItemNode
*
n
)
final
{
static
const
Op
&
batch_norm
=
Op
::
Get
(
"nn.batch_norm"
);
static
const
Op
&
dropout
=
Op
::
Get
(
"nn.dropout"
);
InferenceSimplifier
()
:
batch_norm_op_
(
Op
::
Get
(
"nn.batch_norm"
)),
dropout_op_
(
Op
::
Get
(
"nn.dropout"
)),
instance_norm_op_
(
Op
::
Get
(
"nn.instance_norm"
)),
layer_norm_op_
(
Op
::
Get
(
"nn.layer_norm"
))
{}
Expr
VisitExpr_
(
const
TupleGetItemNode
*
n
)
final
{
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
n
);
const
auto
*
new_n
=
new_e
.
as
<
TupleGetItemNode
>
();
if
(
new_n
->
index
!=
0
)
{
return
new_e
;
}
if
(
const
auto
*
call
=
new_n
->
tuple
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
batch_norm
)
)
{
if
(
call
->
op
==
batch_norm_op_
)
{
return
BatchNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
call
->
args
[
3
],
call
->
args
[
4
],
ty_map_
.
at
(
call
->
args
[
0
]));
}
else
if
(
call
->
op
.
same_as
(
dropout
)
)
{
}
else
if
(
call
->
op
==
dropout_op_
)
{
return
call
->
args
[
0
];
}
}
...
...
@@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator {
}
Expr
VisitExpr_
(
const
CallNode
*
n
)
{
static
const
Op
&
batch_norm
=
Op
::
Get
(
"nn.batch_norm"
);
static
const
Op
&
instance_norm
=
Op
::
Get
(
"nn.instance_norm"
);
static
const
Op
&
layer_norm
=
Op
::
Get
(
"nn.layer_norm"
);
auto
new_n
=
ExprMutator
::
VisitExpr_
(
n
);
if
(
n
->
op
.
same_as
(
batch_norm
)
)
{
if
(
n
->
op
==
batch_norm_op_
)
{
ty_map_
[
new_n
.
as
<
CallNode
>
()
->
args
[
0
]]
=
n
->
args
[
0
]
->
checked_type
();
}
else
if
(
n
->
op
.
same_as
(
layer_norm
)
)
{
}
else
if
(
n
->
op
==
layer_norm_op_
)
{
const
auto
*
call
=
new_n
.
as
<
CallNode
>
();
return
LayerNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
n
->
args
[
0
]
->
checked_type
());
}
else
if
(
n
->
op
.
same_as
(
instance_norm
)
)
{
}
else
if
(
n
->
op
==
instance_norm_op_
)
{
const
auto
*
call
=
new_n
.
as
<
CallNode
>
();
return
InstanceNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
n
->
args
[
0
]
->
checked_type
());
...
...
@@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator {
}
private
:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const
Op
&
batch_norm_op_
;
const
Op
&
dropout_op_
;
const
Op
&
instance_norm_op_
;
const
Op
&
layer_norm_op_
;
std
::
unordered_map
<
Expr
,
Type
,
NodeHash
,
NodeEqual
>
ty_map_
;
};
...
...
src/relay/pass/util.cc
View file @
475158f6
...
...
@@ -25,6 +25,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h"
...
...
@@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
return
true
;
}
// Cache the operators that are checked recursively to reduce lookup overhead.
static
const
auto
&
expand_dims_op
=
Op
::
Get
(
"expand_dims"
);
static
const
auto
&
reshape_op
=
Op
::
Get
(
"reshape"
);
static
const
auto
&
transpose_op
=
Op
::
Get
(
"transpose"
);
static
const
auto
&
squeeze_op
=
Op
::
Get
(
"squeeze"
);
bool
IsAllPositiveConstant
(
const
Expr
&
expr
)
{
// peel through a few common transform ops.
static
const
auto
&
expand_dims
=
Op
::
Get
(
"expand_dims"
);
static
const
auto
&
reshape
=
Op
::
Get
(
"reshape"
);
static
const
auto
&
transpose
=
Op
::
Get
(
"transpose"
);
static
const
auto
&
squeeze
=
Op
::
Get
(
"squeeze"
);
if
(
const
auto
*
constant
=
expr
.
as
<
ConstantNode
>
())
{
const
auto
&
tensor
=
constant
->
data
;
const
auto
&
dtype
=
tensor
->
dtype
;
...
...
@@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) {
}
}
else
if
(
const
auto
*
op
=
expr
.
as
<
CallNode
>
())
{
// tail recursion.
if
(
op
->
op
.
same_as
(
expand_dims
)
||
op
->
op
.
same_as
(
reshape
)
||
op
->
op
.
same_as
(
transpose
)
||
op
->
op
.
same_as
(
squeeze
)
)
{
if
(
op
->
op
==
expand_dims_op
||
op
->
op
==
reshape_op
||
op
->
op
==
transpose_op
||
op
->
op
==
squeeze_op
)
{
return
IsAllPositiveConstant
(
op
->
args
[
0
]);
}
else
{
return
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment