Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
475158f6
Commit
475158f6
authored
Dec 30, 2019
by
Zhi
Committed by
masahi
Dec 31, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594)
* Refactor to use IsOp utility * retrigger CI
parent
35af4c8b
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
133 additions
and
95 deletions
+133
-95
include/tvm/relay/op.h
+2
-3
src/relay/backend/compile_engine.cc
+9
-6
src/relay/backend/interpreter.cc
+14
-9
src/relay/pass/canonicalize_cast.cc
+8
-4
src/relay/pass/canonicalize_ops.cc
+8
-2
src/relay/pass/combine_parallel_op.cc
+12
-12
src/relay/pass/combine_parallel_op.h
+6
-6
src/relay/pass/fold_constant.cc
+23
-10
src/relay/pass/fuse_ops.cc
+3
-2
src/relay/pass/partial_eval.cc
+14
-16
src/relay/pass/quantize/calibrate.cc
+5
-3
src/relay/pass/simplify_inference.cc
+18
-13
src/relay/pass/util.cc
+11
-9
No files found.
include/tvm/relay/op.h
View file @
475158f6
...
@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
...
@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return
map_
.
get
<
ValueType
>
(
expr
,
def_value
);
return
map_
.
get
<
ValueType
>
(
expr
,
def_value
);
}
}
/*!
/*!
* \brief Check that an expression is a "primtive operator".
* \brief Check that an expression is a "prim
i
tive operator".
*
*
* Will return true if the expression is an operator which
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* matches the form of prim
i
tive operators registered directly
* by the Relay codebase.
* by the Relay codebase.
*
*
* That is the arguments are all type variables, and there is a single
* That is the arguments are all type variables, and there is a single
...
...
src/relay/backend/compile_engine.cc
View file @
475158f6
...
@@ -21,6 +21,8 @@
...
@@ -21,6 +21,8 @@
* \file relay/backend/compile_engine.cc
* \file relay/backend/compile_engine.cc
* \brief Internal compilation engine.
* \brief Internal compilation engine.
*/
*/
#include "compile_engine.h"
#include <tvm/schedule.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
#include <tvm/operation.h>
...
@@ -29,6 +31,7 @@
...
@@ -29,6 +31,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <topi/tags.h>
#include <utility>
#include <utility>
...
@@ -38,7 +41,6 @@
...
@@ -38,7 +41,6 @@
#include <vector>
#include <vector>
#include <unordered_map>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "../ir/type_functor.h"
#include "compile_engine.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -102,7 +104,7 @@ class ScheduleGetter :
...
@@ -102,7 +104,7 @@ class ScheduleGetter :
public
ExprFunctor
<
Array
<
Tensor
>
(
const
Expr
&
)
>
{
public
ExprFunctor
<
Array
<
Tensor
>
(
const
Expr
&
)
>
{
public
:
public
:
explicit
ScheduleGetter
(
Target
target
)
explicit
ScheduleGetter
(
Target
target
)
:
target_
(
target
)
{}
:
target_
(
target
)
,
device_copy_op_
(
Op
::
Get
(
"device_copy"
))
{}
std
::
pair
<
Schedule
,
CachedFunc
>
Create
(
const
Function
&
prim_func
)
{
std
::
pair
<
Schedule
,
CachedFunc
>
Create
(
const
Function
&
prim_func
)
{
static
auto
fschedule
=
static
auto
fschedule
=
...
@@ -250,11 +252,9 @@ class ScheduleGetter :
...
@@ -250,11 +252,9 @@ class ScheduleGetter :
CHECK
(
call_node
->
op
.
as
<
OpNode
>
())
CHECK
(
call_node
->
op
.
as
<
OpNode
>
())
<<
"Primitive function only allows call into primitive ops"
;
<<
"Primitive function only allows call into primitive ops"
;
Op
op
=
Downcast
<
Op
>
(
call_node
->
op
);
Op
op
=
Downcast
<
Op
>
(
call_node
->
op
);
// Check if the op is a device copy op.
bool
is_copy_op
=
op
.
same_as
(
Op
::
Get
(
"device_copy"
));
Array
<
Tensor
>
outputs
;
Array
<
Tensor
>
outputs
;
// Skip fcompute for device copy operators as it is not registered.
// Skip fcompute for device copy operators as it is not registered.
if
(
is_copy_op
)
{
if
(
op
==
device_copy_op_
)
{
const
auto
*
copy_input
=
inputs
[
0
].
operator
->
();
const
auto
*
copy_input
=
inputs
[
0
].
operator
->
();
outputs
.
push_back
(
TensorNode
::
make
(
copy_input
->
shape
,
copy_input
->
dtype
,
outputs
.
push_back
(
TensorNode
::
make
(
copy_input
->
shape
,
copy_input
->
dtype
,
Operation
(),
0
));
Operation
(),
0
));
...
@@ -282,7 +282,7 @@ class ScheduleGetter :
...
@@ -282,7 +282,7 @@ class ScheduleGetter :
}
}
// Set the name to `__copy`. It will be detected in graph runtime to perform
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
// data copy across devices.
if
(
is_copy_op
)
{
if
(
op
==
device_copy_op_
)
{
readable_name_stream_
.
str
(
std
::
string
());
readable_name_stream_
.
str
(
std
::
string
());
readable_name_stream_
<<
"__copy"
;
readable_name_stream_
<<
"__copy"
;
}
else
{
}
else
{
...
@@ -332,6 +332,9 @@ class ScheduleGetter :
...
@@ -332,6 +332,9 @@ class ScheduleGetter :
std
::
ostringstream
readable_name_stream_
;
std
::
ostringstream
readable_name_stream_
;
std
::
unordered_map
<
Expr
,
Array
<
Tensor
>
,
NodeHash
,
NodeEqual
>
memo_
;
std
::
unordered_map
<
Expr
,
Array
<
Tensor
>
,
NodeHash
,
NodeEqual
>
memo_
;
Array
<
Operation
>
scalars_
;
Array
<
Operation
>
scalars_
;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const
Op
&
device_copy_op_
;
};
};
// Creates shape function from functor.
// Creates shape function from functor.
...
...
src/relay/backend/interpreter.cc
View file @
475158f6
...
@@ -246,10 +246,12 @@ class Interpreter :
...
@@ -246,10 +246,12 @@ class Interpreter :
public
ExprFunctor
<
Value
(
const
Expr
&
n
)
>
,
public
ExprFunctor
<
Value
(
const
Expr
&
n
)
>
,
PatternFunctor
<
bool
(
const
Pattern
&
p
,
const
Value
&
v
)
>
{
PatternFunctor
<
bool
(
const
Pattern
&
p
,
const
Value
&
v
)
>
{
public
:
public
:
Interpreter
(
Module
mod
,
Interpreter
(
Module
mod
,
DLContext
context
,
Target
target
)
DLContext
context
,
:
mod_
(
mod
),
Target
target
)
context_
(
context
),
:
mod_
(
mod
),
context_
(
context
),
target_
(
target
)
{
target_
(
target
),
debug_op_
(
Op
::
Get
(
"debug"
)),
shape_of_op_
(
Op
::
Get
(
"shape_of"
))
{
engine_
=
CompileEngine
::
Global
();
engine_
=
CompileEngine
::
Global
();
}
}
...
@@ -263,7 +265,7 @@ class Interpreter :
...
@@ -263,7 +265,7 @@ class Interpreter :
stack_
.
current_frame
().
locals
.
Set
(
id
,
v
);
stack_
.
current_frame
().
locals
.
Set
(
id
,
v
);
}
}
inline
Value
Lookup
(
const
Var
&
local
)
{
Value
Lookup
(
const
Var
&
local
)
{
return
stack_
.
Lookup
(
local
);
return
stack_
.
Lookup
(
local
);
}
}
...
@@ -307,7 +309,7 @@ class Interpreter :
...
@@ -307,7 +309,7 @@ class Interpreter :
return
TupleValueNode
::
make
(
values
);
return
TupleValueNode
::
make
(
values
);
}
}
inline
Value
MakeClosure
(
const
Function
&
func
,
Var
letrec_name
=
Var
())
{
Value
MakeClosure
(
const
Function
&
func
,
Var
letrec_name
=
Var
())
{
tvm
::
Map
<
Var
,
Value
>
captured_mod
;
tvm
::
Map
<
Var
,
Value
>
captured_mod
;
Array
<
Var
>
free_vars
=
FreeVars
(
func
);
Array
<
Var
>
free_vars
=
FreeVars
(
func
);
...
@@ -454,9 +456,9 @@ class Interpreter :
...
@@ -454,9 +456,9 @@ class Interpreter :
Value
InvokePrimitiveOp
(
const
Function
&
func
,
Value
InvokePrimitiveOp
(
const
Function
&
func
,
const
Array
<
Value
>&
args
)
{
const
Array
<
Value
>&
args
)
{
auto
call_node
=
func
->
body
.
as
<
CallNode
>
();
const
auto
*
call_node
=
func
->
body
.
as
<
CallNode
>
();
if
(
call_node
&&
call_node
->
op
==
Op
::
Get
(
"debug"
)
)
{
if
(
call_node
&&
call_node
->
op
==
debug_op_
)
{
auto
dattrs
=
call_node
->
attrs
.
as
<
DebugAttrs
>
();
auto
dattrs
=
call_node
->
attrs
.
as
<
DebugAttrs
>
();
auto
interp_state
=
this
->
get_state
(
call_node
->
args
[
0
]);
auto
interp_state
=
this
->
get_state
(
call_node
->
args
[
0
]);
...
@@ -540,7 +542,7 @@ class Interpreter :
...
@@ -540,7 +542,7 @@ class Interpreter :
Array
<
Shape
>
out_shapes
;
Array
<
Shape
>
out_shapes
;
auto
ret_type
=
func
->
body
->
checked_type
();
auto
ret_type
=
func
->
body
->
checked_type
();
bool
is_dyn
=
IsDynamic
(
func
->
checked_type
());
bool
is_dyn
=
IsDynamic
(
func
->
checked_type
());
if
(
call_node
->
op
==
Op
::
Get
(
"shape_of"
)
)
{
if
(
call_node
->
op
==
shape_of_op_
)
{
// The output shape of shape_of must be static since Relay doesn't support
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
// dynamic rank tensors.
is_dyn
=
false
;
is_dyn
=
false
;
...
@@ -782,6 +784,9 @@ class Interpreter :
...
@@ -782,6 +784,9 @@ class Interpreter :
Stack
stack_
;
Stack
stack_
;
// Backend compile engine.
// Backend compile engine.
CompileEngine
engine_
;
CompileEngine
engine_
;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const
Op
&
debug_op_
;
const
Op
&
shape_of_op_
;
};
};
...
...
src/relay/pass/canonicalize_cast.cc
View file @
475158f6
...
@@ -62,6 +62,8 @@ namespace relay {
...
@@ -62,6 +62,8 @@ namespace relay {
// \endcode
// \endcode
class
CastCanonicalizer
:
public
ExprMutator
{
class
CastCanonicalizer
:
public
ExprMutator
{
public
:
public
:
CastCanonicalizer
()
:
cast_op_
(
Op
::
Get
(
"cast"
))
{}
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
auto
fpattern
=
Op
::
GetAttr
<
TOpPattern
>
(
"TOpPattern"
);
static
auto
fpattern
=
Op
::
GetAttr
<
TOpPattern
>
(
"TOpPattern"
);
...
@@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {
...
@@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {
private
:
private
:
std
::
unordered_map
<
const
Node
*
,
size_t
>
ref_counter_
;
std
::
unordered_map
<
const
Node
*
,
size_t
>
ref_counter_
;
// cast op is frequently checked for equivalence. Therefore, we cache it to
// reduce lookup overhead.
const
Op
&
cast_op_
;
Expr
GetNewCallArg
(
const
Expr
&
e
)
{
Expr
GetNewCallArg
(
const
Expr
&
e
)
{
// if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
// if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
static
auto
&
cast
=
Op
::
Get
(
"cast"
);
Expr
new_expr
=
this
->
VisitExpr
(
e
);
Expr
new_expr
=
this
->
VisitExpr
(
e
);
if
(
const
CallNode
*
call
=
e
.
as
<
CallNode
>
())
{
if
(
const
CallNode
*
call
=
e
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
cast
)
)
{
if
(
call
->
op
==
cast_op_
)
{
auto
attrs
=
call
->
attrs
.
as
<
CastAttrs
>
();
auto
attrs
=
call
->
attrs
.
as
<
CastAttrs
>
();
const
auto
*
from_type
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
from_type
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
CHECK
(
from_type
);
CHECK
(
from_type
);
...
@@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
...
@@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
if
(
++
ref_counter_
[
call
]
>
1
)
{
if
(
++
ref_counter_
[
call
]
>
1
)
{
const
CallNode
*
new_call
=
new_expr
.
as
<
CallNode
>
();
const
CallNode
*
new_call
=
new_expr
.
as
<
CallNode
>
();
CHECK
(
new_call
);
CHECK
(
new_call
);
CHECK
(
new_call
->
op
.
same_as
(
cast
)
);
CHECK
(
new_call
->
op
==
cast_op_
);
return
CallNode
::
make
(
new_call
->
op
,
new_call
->
args
,
new_call
->
attrs
,
return
CallNode
::
make
(
new_call
->
op
,
new_call
->
args
,
new_call
->
attrs
,
new_call
->
type_args
);
new_call
->
type_args
);
}
}
...
...
src/relay/pass/canonicalize_ops.cc
View file @
475158f6
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
*/
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
#include "pattern_util.h"
...
@@ -33,10 +34,11 @@ namespace relay {
...
@@ -33,10 +34,11 @@ namespace relay {
class
BiasAddSimplifier
:
public
ExprMutator
{
class
BiasAddSimplifier
:
public
ExprMutator
{
public
:
public
:
BiasAddSimplifier
()
:
bias_add_op_
(
Op
::
Get
(
"nn.bias_add"
))
{}
Expr
VisitExpr_
(
const
CallNode
*
n
)
{
Expr
VisitExpr_
(
const
CallNode
*
n
)
{
static
const
Op
&
bias_add
=
Op
::
Get
(
"nn.bias_add"
);
auto
new_n
=
ExprMutator
::
VisitExpr_
(
n
);
auto
new_n
=
ExprMutator
::
VisitExpr_
(
n
);
if
(
n
->
op
.
same_as
(
bias_add
)
)
{
if
(
n
->
op
==
bias_add_op_
)
{
Call
call
=
Downcast
<
Call
>
(
new_n
);
Call
call
=
Downcast
<
Call
>
(
new_n
);
CHECK_EQ
(
call
->
args
.
size
(),
2
);
CHECK_EQ
(
call
->
args
.
size
(),
2
);
const
BiasAddAttrs
*
param
=
call
->
attrs
.
as
<
BiasAddAttrs
>
();
const
BiasAddAttrs
*
param
=
call
->
attrs
.
as
<
BiasAddAttrs
>
();
...
@@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
...
@@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
}
}
return
new_n
;
return
new_n
;
}
}
private
:
// Cache the bias_add for equivalence checking.
const
Op
&
bias_add_op_
;
};
};
Expr
CanonicalizeOps
(
const
Expr
&
e
)
{
Expr
CanonicalizeOps
(
const
Expr
&
e
)
{
...
...
src/relay/pass/combine_parallel_op.cc
View file @
475158f6
...
@@ -27,29 +27,30 @@
...
@@ -27,29 +27,30 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/transform.h>
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "
./
expr_subst.h"
#include "expr_subst.h"
#include "
./
pattern_util.h"
#include "pattern_util.h"
#include "
./
combine_parallel_op.h"
#include "combine_parallel_op.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
BranchGroupFinder
::
BranchGroupFinder
(
const
std
::
string
&
op_name
,
BranchGroupFinder
::
BranchGroupFinder
(
const
Op
&
op
,
FIsSupportedOp
fis_supported_op
,
FIsSupportedOp
fis_supported_op
,
FAreCompatibleOps
fare_compatible_ops
)
FAreCompatibleOps
fare_compatible_ops
)
:
op_name_
(
op_name
),
:
cached_op_
(
op
),
fis_supported_op_
(
fis_supported_op
),
fis_supported_op_
(
fis_supported_op
),
fare_compatible_ops_
(
fare_compatible_ops
)
{
fare_compatible_ops_
(
fare_compatible_ops
)
{
}
}
std
::
vector
<
Group
>
BranchGroupFinder
::
Find
(
const
Expr
&
expr
)
{
std
::
vector
<
Group
>
BranchGroupFinder
::
Find
(
const
Expr
&
expr
)
{
const
Op
&
op
=
Op
::
Get
(
op_name_
);
this
->
VisitExpr
(
expr
);
this
->
VisitExpr
(
expr
);
std
::
vector
<
Group
>
groups
;
std
::
vector
<
Group
>
groups
;
...
@@ -57,7 +58,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
...
@@ -57,7 +58,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const
auto
&
children
=
children_map_
.
at
(
root
);
const
auto
&
children
=
children_map_
.
at
(
root
);
size_t
ngroups
=
groups
.
size
();
size_t
ngroups
=
groups
.
size
();
for
(
const
CallNode
*
child
:
children
)
{
for
(
const
CallNode
*
child
:
children
)
{
if
(
!
child
->
op
.
same_as
(
op
)
)
continue
;
if
(
child
->
op
!=
cached_op_
)
continue
;
auto
&&
branch
=
CreateBranch
(
child
);
auto
&&
branch
=
CreateBranch
(
child
);
// add the branch to a group, or create a new group
// add the branch to a group, or create a new group
...
@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
...
@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
}
}
void
BranchGroupFinder
::
VisitExpr_
(
const
CallNode
*
n
)
{
void
BranchGroupFinder
::
VisitExpr_
(
const
CallNode
*
n
)
{
const
Op
&
op
=
Op
::
Get
(
op_name_
);
ExprVisitor
::
VisitExpr_
(
n
);
ExprVisitor
::
VisitExpr_
(
n
);
if
(
n
->
op
.
same_as
(
op
)
&&
fis_supported_op_
(
n
))
{
if
(
n
->
op
==
cached_op_
&&
fis_supported_op_
(
n
))
{
op_roots_
.
insert
(
n
->
args
[
0
]);
op_roots_
.
insert
(
n
->
args
[
0
]);
children_map_
[
n
->
args
[
0
]].
push_back
(
n
);
children_map_
[
n
->
args
[
0
]].
push_back
(
n
);
}
else
{
}
else
{
...
@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
...
@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
}
}
ParallelOpCombiner
::
ParallelOpCombiner
(
const
std
::
string
&
op_name
,
uint64_t
min_num_branches
)
ParallelOpCombiner
::
ParallelOpCombiner
(
const
std
::
string
&
op_name
,
uint64_t
min_num_branches
)
:
op_name_
(
op_name
),
:
cached_op_
(
Op
::
Get
(
op_name
)
),
min_num_branches_
(
min_num_branches
)
{
min_num_branches_
(
min_num_branches
)
{
}
}
Expr
ParallelOpCombiner
::
Combine
(
const
Expr
&
expr
)
{
Expr
ParallelOpCombiner
::
Combine
(
const
Expr
&
expr
)
{
auto
groups
=
BranchGroupFinder
(
op_name
_
,
auto
groups
=
BranchGroupFinder
(
cached_op
_
,
[
&
](
const
CallNode
*
n
)
{
[
&
](
const
CallNode
*
n
)
{
return
IsSupportedOp
(
n
);
return
IsSupportedOp
(
n
);
},
},
...
...
src/relay/pass/combine_parallel_op.h
View file @
475158f6
...
@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
...
@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
public
:
public
:
/*
/*
* \brief Constructor
* \brief Constructor
* \param op
_name name of op to start
each group
* \param op
The op that indicates the start of
each group
* \param fis_supported_op function that returns true if op
* \param fis_supported_op function that returns true if op
* is supported for combining
* is supported for combining
* \param fare_compatible_ops function that returns true if
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
* two ops are compatible for combining
*/
*/
BranchGroupFinder
(
const
std
::
string
&
op_name
,
BranchGroupFinder
(
const
Op
&
op
,
FIsSupportedOp
fis_supported_op
,
FIsSupportedOp
fis_supported_op
,
FAreCompatibleOps
fare_compatible_ops
);
FAreCompatibleOps
fare_compatible_ops
);
...
@@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
...
@@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
std
::
vector
<
Group
>
Find
(
const
Expr
&
expr
);
std
::
vector
<
Group
>
Find
(
const
Expr
&
expr
);
private
:
private
:
/* \brief
name of op to find parallel branches for
*/
/* \brief
Cache the op for finding parallel branches
*/
std
::
string
op_name
_
;
const
Op
&
cached_op
_
;
/* \brief function to return true if op is eligible to be combined,
/* \brief function to return true if op is eligible to be combined,
* false otherwise
* false otherwise
...
@@ -205,8 +205,8 @@ class ParallelOpCombiner {
...
@@ -205,8 +205,8 @@ class ParallelOpCombiner {
ExprSubstMap
*
subst_map
)
=
0
;
ExprSubstMap
*
subst_map
)
=
0
;
private
:
private
:
/* \brief
name of
op to be combined */
/* \brief
Cache the
op to be combined */
std
::
string
op_name
_
;
const
Op
&
cached_op
_
;
/* \brief minimum number of parallel branches to combine */
/* \brief minimum number of parallel branches to combine */
uint64_t
min_num_branches_
;
uint64_t
min_num_branches_
;
...
...
src/relay/pass/fold_constant.cc
View file @
475158f6
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
*/
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/transform.h>
...
@@ -33,7 +34,6 @@ namespace relay {
...
@@ -33,7 +34,6 @@ namespace relay {
using
FInterpreter
=
runtime
::
TypedPackedFunc
<
Value
(
Expr
)
>
;
using
FInterpreter
=
runtime
::
TypedPackedFunc
<
Value
(
Expr
)
>
;
class
ConstantChecker
:
private
ExprVisitor
{
class
ConstantChecker
:
private
ExprVisitor
{
public
:
public
:
// Check whether an expression is constant. The results are memoized.
// Check whether an expression is constant. The results are memoized.
...
@@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant")
...
@@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant")
class
ConstantFolder
:
public
ExprMutator
{
class
ConstantFolder
:
public
ExprMutator
{
public
:
public
:
explicit
ConstantFolder
(
FInterpreter
executor
,
Module
module
)
explicit
ConstantFolder
(
FInterpreter
executor
,
Module
module
)
:
executor_
(
executor
),
module_
(
module
)
{
:
executor_
(
executor
),
}
module_
(
module
),
shape_of_op_
(
Op
::
Get
(
"shape_of"
)),
invoke_tvm_op_
(
Op
::
Get
(
"memory.invoke_tvm_op"
)),
shape_func_op_
(
Op
::
Get
(
"memory.shape_func"
)),
alloc_tensor_op_
(
Op
::
Get
(
"memory.alloc_tensor"
)),
alloc_storage_op_
(
Op
::
Get
(
"memory.alloc_storage"
)),
cast_op_
(
Op
::
Get
(
"cast"
))
{}
Expr
VisitExpr_
(
const
LetNode
*
op
)
final
{
Expr
VisitExpr_
(
const
LetNode
*
op
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
value
=
this
->
Mutate
(
op
->
value
);
...
@@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator {
...
@@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
// skip stateful ops.
if
(
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
return
res
;
if
(
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
return
res
;
// Try to evaluate shape_of op
// Try to evaluate shape_of op
if
(
call
->
op
.
same_as
(
Op
::
Get
(
"shape_of"
))
)
{
if
(
call
->
op
==
shape_of_op_
)
{
return
EvaluateShapeOf
(
res
,
origin_args
,
call
->
attrs
);
return
EvaluateShapeOf
(
res
,
origin_args
,
call
->
attrs
);
}
}
// We should think about potentially constant evaluation over these ops too.
// We should think about potentially constant evaluation over these ops too.
if
(
call
->
op
.
same_as
(
Op
::
Get
(
"memory.invoke_tvm_op"
))
||
if
(
call
->
op
==
invoke_tvm_op_
||
call
->
op
.
same_as
(
Op
::
Get
(
"memory.shape_func"
))
||
call
->
op
==
shape_func_op_
||
call
->
op
.
same_as
(
Op
::
Get
(
"memory.alloc_tensor"
))
||
call
->
op
==
alloc_tensor_op_
||
call
->
op
.
same_as
(
Op
::
Get
(
"memory.alloc_storage"
))
)
{
call
->
op
==
alloc_storage_op_
)
{
return
GetRef
<
Call
>
(
call
);
return
GetRef
<
Call
>
(
call
);
}
}
...
@@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator {
...
@@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator {
// Module
// Module
Module
module_
;
Module
module_
;
// Cache the following ops for equivalence checking in this pass.
const
Op
&
shape_of_op_
;
const
Op
&
invoke_tvm_op_
;
const
Op
&
shape_func_op_
;
const
Op
&
alloc_tensor_op_
;
const
Op
&
alloc_storage_op_
;
const
Op
&
cast_op_
;
// Convert value to expression.
// Convert value to expression.
Expr
ValueToExpr
(
Value
value
)
{
Expr
ValueToExpr
(
Value
value
)
{
if
(
const
auto
*
val
=
value
.
as
<
TensorValueNode
>
())
{
if
(
const
auto
*
val
=
value
.
as
<
TensorValueNode
>
())
{
...
@@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator {
...
@@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator {
// Cast the constant into correct dtype
// Cast the constant into correct dtype
auto
cast_attrs
=
make_node
<
CastAttrs
>
();
auto
cast_attrs
=
make_node
<
CastAttrs
>
();
cast_attrs
->
dtype
=
param
->
dtype
;
cast_attrs
->
dtype
=
param
->
dtype
;
static
const
Op
&
cast_op
=
Op
::
Get
(
"cast"
);
Expr
ret
=
CallNode
::
make
(
cast_op_
,
{
shape
},
Attrs
(
cast_attrs
),
{});
Expr
ret
=
CallNode
::
make
(
cast_op
,
{
shape
},
Attrs
(
cast_attrs
),
{});
return
ConstEvaluate
(
ret
);
return
ConstEvaluate
(
ret
);
}
}
};
};
...
...
src/relay/pass/fuse_ops.cc
View file @
475158f6
...
@@ -78,6 +78,8 @@ using common::LinkedList;
...
@@ -78,6 +78,8 @@ using common::LinkedList;
constexpr
uint32_t
kMaxFusedOps
=
256
;
constexpr
uint32_t
kMaxFusedOps
=
256
;
static
const
Op
&
stop_fusion_op
=
Op
::
Get
(
"annotation.stop_fusion"
);
/*!
/*!
* \brief Indexed data flow graph in forward direction.
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
* This is a temporary data structure used for operator fusion analysis.
...
@@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator {
...
@@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator {
// Transform calls.
// Transform calls.
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
const
Op
&
stop_fusion
=
Op
::
Get
(
"annotation.stop_fusion"
);
if
(
call
->
op
.
as
<
OpNode
>
())
{
if
(
call
->
op
.
as
<
OpNode
>
())
{
static
auto
fnoncomputational
=
static
auto
fnoncomputational
=
Op
::
GetAttr
<
TNonComputational
>
(
"TNonComputational"
);
Op
::
GetAttr
<
TNonComputational
>
(
"TNonComputational"
);
...
@@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator {
...
@@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator {
// If it is a primitive op call
// If it is a primitive op call
// then we must have a group assignment for it already.
// then we must have a group assignment for it already.
CHECK
(
gmap_
.
count
(
call
));
CHECK
(
gmap_
.
count
(
call
));
if
(
call
->
op
.
same_as
(
stop_fusion
)
)
{
if
(
call
->
op
==
stop_fusion_op
)
{
return
ExprMutator
::
VisitExpr
(
call
->
args
[
0
]);
return
ExprMutator
::
VisitExpr
(
call
->
args
[
0
]);
}
}
auto
*
ret_group
=
gmap_
.
at
(
call
)
->
FindRoot
();
auto
*
ret_group
=
gmap_
.
at
(
call
)
->
FindRoot
();
...
...
src/relay/pass/partial_eval.cc
View file @
475158f6
...
@@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
...
@@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
TVM_REGISTER_NODE_TYPE
(
WithFuncIdAttrs
);
TVM_REGISTER_NODE_TYPE
(
WithFuncIdAttrs
);
Op
WithFuncIdOp
()
{
static
const
Op
&
op
=
Op
::
Get
(
"annotation.with_funcid"
);
return
op
;
}
Expr
MkWithFuncId
(
const
Expr
&
expr
,
FuncId
fid
)
{
auto
attrs
=
make_node
<
WithFuncIdAttrs
>
();
attrs
->
fid
=
fid
;
return
CallNode
::
make
(
WithFuncIdOp
(),
{
expr
},
Attrs
(
attrs
),
{});
}
RELAY_REGISTER_OP
(
"annotation.with_funcid"
)
RELAY_REGISTER_OP
(
"annotation.with_funcid"
)
.
describe
(
R"code(Annotate a function with a funcid.)code"
.
describe
(
R"code(Annotate a function with a funcid.)code"
TVM_ADD_FILELINE
)
TVM_ADD_FILELINE
)
.
set_num_inputs
(
1
)
.
set_num_inputs
(
1
)
.
add_argument
(
"func"
,
"Function"
,
"The input data."
);
.
add_argument
(
"func"
,
"Function"
,
"The input data."
);
// Cache with_funcid op to reduce lookup overhead during traversal.
static
const
Op
&
with_funcid_op
=
Op
::
Get
(
"annotation.with_funcid"
);
Expr
MkWithFuncId
(
const
Expr
&
expr
,
FuncId
fid
)
{
auto
attrs
=
make_node
<
WithFuncIdAttrs
>
();
attrs
->
fid
=
fid
;
return
CallNode
::
make
(
with_funcid_op
,
{
expr
},
Attrs
(
attrs
),
{});
}
Expr
StripWithFuncId
(
const
Expr
&
e
);
Expr
StripWithFuncId
(
const
Expr
&
e
);
Function
AsFunc
(
const
Expr
&
e
)
{
Function
AsFunc
(
const
Expr
&
e
)
{
if
(
e
.
as
<
FunctionNode
>
())
{
if
(
e
.
as
<
FunctionNode
>
())
{
return
Downcast
<
Function
>
(
e
);
return
Downcast
<
Function
>
(
e
);
}
else
if
(
const
CallNode
*
c
=
e
.
as
<
CallNode
>
())
{
}
else
if
(
const
CallNode
*
c
=
e
.
as
<
CallNode
>
())
{
CHECK
(
c
->
op
.
same_as
(
WithFuncIdOp
())
);
CHECK
(
c
->
op
==
with_funcid_op
);
CHECK_EQ
(
c
->
args
.
size
(),
1
);
CHECK_EQ
(
c
->
args
.
size
(),
1
);
return
AsFunc
(
c
->
args
[
0
]);
return
AsFunc
(
c
->
args
[
0
]);
}
else
{
}
else
{
...
@@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
...
@@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic
VisitExpr
(
const
Expr
&
e
,
LetList
*
ll
,
const
Var
&
name
)
{
PStatic
VisitExpr
(
const
Expr
&
e
,
LetList
*
ll
,
const
Var
&
name
)
{
if
(
const
CallNode
*
c
=
e
.
as
<
CallNode
>
())
{
if
(
const
CallNode
*
c
=
e
.
as
<
CallNode
>
())
{
if
(
c
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
c
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
c
->
args
.
size
(),
1
);
CHECK_EQ
(
c
->
args
.
size
(),
1
);
return
VisitExpr
(
c
->
args
[
0
],
ll
,
name
);
return
VisitExpr
(
c
->
args
[
0
],
ll
,
name
);
}
}
...
@@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
...
@@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
}
PStatic
VisitExpr_
(
const
CallNode
*
op
,
LetList
*
ll
)
final
{
PStatic
VisitExpr_
(
const
CallNode
*
op
,
LetList
*
ll
)
final
{
if
(
op
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
op
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
op
->
args
.
size
(),
1
);
CHECK_EQ
(
op
->
args
.
size
(),
1
);
return
VisitExpr
(
op
->
args
[
0
],
ll
);
return
VisitExpr
(
op
->
args
[
0
],
ll
);
}
}
...
@@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
...
@@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
explicit
RegisterFuncIdVisitor
(
PartialEvaluator
*
pe
)
:
pe
(
pe
)
{
}
explicit
RegisterFuncIdVisitor
(
PartialEvaluator
*
pe
)
:
pe
(
pe
)
{
}
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
op
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
op
->
args
.
size
(),
1
);
CHECK_EQ
(
op
->
args
.
size
(),
1
);
CHECK
(
op
->
attrs
.
defined
());
CHECK
(
op
->
attrs
.
defined
());
CHECK
(
op
->
attrs
.
as
<
WithFuncIdAttrs
>
());
CHECK
(
op
->
attrs
.
as
<
WithFuncIdAttrs
>
());
...
@@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) {
...
@@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) {
Expr
StripWithFuncId
(
const
Expr
&
e
)
{
Expr
StripWithFuncId
(
const
Expr
&
e
)
{
struct
StripWithFuncIdMutator
:
ExprMutator
,
PatternMutator
{
struct
StripWithFuncIdMutator
:
ExprMutator
,
PatternMutator
{
Expr
VisitExpr_
(
const
CallNode
*
op
)
final
{
Expr
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
WithFuncIdOp
())
)
{
if
(
op
->
op
==
with_funcid_op
)
{
CHECK_EQ
(
op
->
args
.
size
(),
1
);
CHECK_EQ
(
op
->
args
.
size
(),
1
);
return
VisitExpr
(
op
->
args
[
0
]);
return
VisitExpr
(
op
->
args
[
0
]);
}
else
{
}
else
{
...
...
src/relay/pass/quantize/calibrate.cc
View file @
475158f6
...
@@ -25,15 +25,17 @@
...
@@ -25,15 +25,17 @@
*/
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include "./quantize.h"
#include "./quantize.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
namespace
quantize
{
namespace
quantize
{
class
StatsCollector
:
private
ExprMutator
{
class
StatsCollector
:
private
ExprMutator
{
public
:
public
:
StatsCollector
()
:
simulated_quantize_op_
(
Op
::
Get
(
"relay.op.annotation.simulated_quantize"
))
{}
Expr
Collect
(
const
Expr
&
expr
)
{
Expr
Collect
(
const
Expr
&
expr
)
{
auto
new_e
=
this
->
Mutate
(
expr
);
auto
new_e
=
this
->
Mutate
(
expr
);
const
FunctionNode
*
func
=
new_e
.
as
<
FunctionNode
>
();
const
FunctionNode
*
func
=
new_e
.
as
<
FunctionNode
>
();
...
@@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator {
...
@@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator {
private
:
private
:
Array
<
Expr
>
profile_data_
;
Array
<
Expr
>
profile_data_
;
const
Op
&
simulated_quantize_op_
;
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
static
const
Op
&
simulated_quantize
=
Op
::
Get
(
"relay.op.annotation.simulated_quantize"
);
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
call
);
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
call
);
const
CallNode
*
new_call
=
new_e
.
as
<
CallNode
>
();
const
CallNode
*
new_call
=
new_e
.
as
<
CallNode
>
();
CHECK
(
new_call
);
CHECK
(
new_call
);
if
(
new_call
->
op
.
same_as
(
simulated_quantize
)
)
{
if
(
new_call
->
op
==
simulated_quantize_op_
)
{
auto
attrs
=
new_call
->
attrs
.
as
<
SimulatedQuantizeAttrs
>
();
auto
attrs
=
new_call
->
attrs
.
as
<
SimulatedQuantizeAttrs
>
();
// rewrite the annotation
// rewrite the annotation
auto
new_attrs
=
make_node
<
SimulatedQuantizeAttrs
>
();
auto
new_attrs
=
make_node
<
SimulatedQuantizeAttrs
>
();
...
...
src/relay/pass/simplify_inference.cc
View file @
475158f6
...
@@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
...
@@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
return
out
;
return
out
;
}
}
Expr
InstanceNormToInferUnpack
(
const
Attrs
attrs
,
Expr
InstanceNormToInferUnpack
(
const
Attrs
attrs
,
Expr
data
,
Expr
data
,
Expr
gamma
,
Expr
gamma
,
...
@@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
...
@@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
return
out
;
return
out
;
}
}
class
InferenceSimplifier
:
public
ExprMutator
{
class
InferenceSimplifier
:
public
ExprMutator
{
public
:
public
:
Expr
VisitExpr_
(
const
TupleGetItemNode
*
n
)
final
{
InferenceSimplifier
()
static
const
Op
&
batch_norm
=
Op
::
Get
(
"nn.batch_norm"
);
:
batch_norm_op_
(
Op
::
Get
(
"nn.batch_norm"
)),
static
const
Op
&
dropout
=
Op
::
Get
(
"nn.dropout"
);
dropout_op_
(
Op
::
Get
(
"nn.dropout"
)),
instance_norm_op_
(
Op
::
Get
(
"nn.instance_norm"
)),
layer_norm_op_
(
Op
::
Get
(
"nn.layer_norm"
))
{}
Expr
VisitExpr_
(
const
TupleGetItemNode
*
n
)
final
{
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
n
);
Expr
new_e
=
ExprMutator
::
VisitExpr_
(
n
);
const
auto
*
new_n
=
new_e
.
as
<
TupleGetItemNode
>
();
const
auto
*
new_n
=
new_e
.
as
<
TupleGetItemNode
>
();
if
(
new_n
->
index
!=
0
)
{
if
(
new_n
->
index
!=
0
)
{
return
new_e
;
return
new_e
;
}
}
if
(
const
auto
*
call
=
new_n
->
tuple
.
as
<
CallNode
>
())
{
if
(
const
auto
*
call
=
new_n
->
tuple
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
batch_norm
)
)
{
if
(
call
->
op
==
batch_norm_op_
)
{
return
BatchNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
return
BatchNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
call
->
args
[
3
],
call
->
args
[
4
],
ty_map_
.
at
(
call
->
args
[
0
]));
call
->
args
[
3
],
call
->
args
[
4
],
ty_map_
.
at
(
call
->
args
[
0
]));
}
else
if
(
call
->
op
.
same_as
(
dropout
)
)
{
}
else
if
(
call
->
op
==
dropout_op_
)
{
return
call
->
args
[
0
];
return
call
->
args
[
0
];
}
}
}
}
...
@@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator {
...
@@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator {
}
}
Expr
VisitExpr_
(
const
CallNode
*
n
)
{
Expr
VisitExpr_
(
const
CallNode
*
n
)
{
static
const
Op
&
batch_norm
=
Op
::
Get
(
"nn.batch_norm"
);
static
const
Op
&
instance_norm
=
Op
::
Get
(
"nn.instance_norm"
);
static
const
Op
&
layer_norm
=
Op
::
Get
(
"nn.layer_norm"
);
auto
new_n
=
ExprMutator
::
VisitExpr_
(
n
);
auto
new_n
=
ExprMutator
::
VisitExpr_
(
n
);
if
(
n
->
op
.
same_as
(
batch_norm
)
)
{
if
(
n
->
op
==
batch_norm_op_
)
{
ty_map_
[
new_n
.
as
<
CallNode
>
()
->
args
[
0
]]
=
n
->
args
[
0
]
->
checked_type
();
ty_map_
[
new_n
.
as
<
CallNode
>
()
->
args
[
0
]]
=
n
->
args
[
0
]
->
checked_type
();
}
else
if
(
n
->
op
.
same_as
(
layer_norm
)
)
{
}
else
if
(
n
->
op
==
layer_norm_op_
)
{
const
auto
*
call
=
new_n
.
as
<
CallNode
>
();
const
auto
*
call
=
new_n
.
as
<
CallNode
>
();
return
LayerNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
return
LayerNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
n
->
args
[
0
]
->
checked_type
());
call
->
args
[
2
],
n
->
args
[
0
]
->
checked_type
());
}
else
if
(
n
->
op
.
same_as
(
instance_norm
)
)
{
}
else
if
(
n
->
op
==
instance_norm_op_
)
{
const
auto
*
call
=
new_n
.
as
<
CallNode
>
();
const
auto
*
call
=
new_n
.
as
<
CallNode
>
();
return
InstanceNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
return
InstanceNormToInferUnpack
(
call
->
attrs
,
call
->
args
[
0
],
call
->
args
[
1
],
call
->
args
[
2
],
n
->
args
[
0
]
->
checked_type
());
call
->
args
[
2
],
n
->
args
[
0
]
->
checked_type
());
...
@@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator {
...
@@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator {
}
}
private
:
private
:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const
Op
&
batch_norm_op_
;
const
Op
&
dropout_op_
;
const
Op
&
instance_norm_op_
;
const
Op
&
layer_norm_op_
;
std
::
unordered_map
<
Expr
,
Type
,
NodeHash
,
NodeEqual
>
ty_map_
;
std
::
unordered_map
<
Expr
,
Type
,
NodeHash
,
NodeEqual
>
ty_map_
;
};
};
...
...
src/relay/pass/util.cc
View file @
475158f6
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
*/
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "pass_util.h"
#include "../ir/type_functor.h"
#include "../ir/type_functor.h"
...
@@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
...
@@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
return
true
;
return
true
;
}
}
// Cache the operators that are checked recursively to reduce lookup overhead.
static
const
auto
&
expand_dims_op
=
Op
::
Get
(
"expand_dims"
);
static
const
auto
&
reshape_op
=
Op
::
Get
(
"reshape"
);
static
const
auto
&
transpose_op
=
Op
::
Get
(
"transpose"
);
static
const
auto
&
squeeze_op
=
Op
::
Get
(
"squeeze"
);
bool
IsAllPositiveConstant
(
const
Expr
&
expr
)
{
bool
IsAllPositiveConstant
(
const
Expr
&
expr
)
{
// peel through a few common transform ops.
// peel through a few common transform ops.
static
const
auto
&
expand_dims
=
Op
::
Get
(
"expand_dims"
);
static
const
auto
&
reshape
=
Op
::
Get
(
"reshape"
);
static
const
auto
&
transpose
=
Op
::
Get
(
"transpose"
);
static
const
auto
&
squeeze
=
Op
::
Get
(
"squeeze"
);
if
(
const
auto
*
constant
=
expr
.
as
<
ConstantNode
>
())
{
if
(
const
auto
*
constant
=
expr
.
as
<
ConstantNode
>
())
{
const
auto
&
tensor
=
constant
->
data
;
const
auto
&
tensor
=
constant
->
data
;
const
auto
&
dtype
=
tensor
->
dtype
;
const
auto
&
dtype
=
tensor
->
dtype
;
...
@@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) {
...
@@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) {
}
}
}
else
if
(
const
auto
*
op
=
expr
.
as
<
CallNode
>
())
{
}
else
if
(
const
auto
*
op
=
expr
.
as
<
CallNode
>
())
{
// tail recursion.
// tail recursion.
if
(
op
->
op
.
same_as
(
expand_dims
)
||
if
(
op
->
op
==
expand_dims_op
||
op
->
op
.
same_as
(
reshape
)
||
op
->
op
==
reshape_op
||
op
->
op
.
same_as
(
transpose
)
||
op
->
op
==
transpose_op
||
op
->
op
.
same_as
(
squeeze
)
)
{
op
->
op
==
squeeze_op
)
{
return
IsAllPositiveConstant
(
op
->
args
[
0
]);
return
IsAllPositiveConstant
(
op
->
args
[
0
]);
}
else
{
}
else
{
return
false
;
return
false
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment