Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
59c70a0e
Unverified
Commit
59c70a0e
authored
Nov 15, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 15, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][[PASS] Consolidate ForwardRewrite pass. (#2124)
parent
de3b63a4
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
319 additions
and
178 deletions
+319
-178
include/tvm/relay/expr.h
+26
-0
include/tvm/relay/op.h
+40
-0
include/tvm/relay/op_attr_types.h
+19
-0
include/tvm/relay/pass.h
+11
-0
include/tvm/runtime/packed_func.h
+2
-0
src/relay/pass/fold_scale_axis.cc
+89
-178
src/relay/pass/forward_rewrite.cc
+132
-0
No files found.
include/tvm/relay/expr.h
View file @
59c70a0e
...
...
@@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
TupleGetItem
,
TupleGetItemNode
,
Expr
);
/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class
TempExprNode
:
public
ExprNode
{
public
:
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual
Expr
Realize
()
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"relay.TempExpr"
;
TVM_DECLARE_BASE_NODE_INFO
(
TempExprNode
,
ExprNode
);
};
RELAY_DEFINE_NODE_REF
(
TempExpr
,
TempExprNode
,
Expr
);
// implementataions
template
<
typename
TTypeNode
>
inline
const
TTypeNode
*
ExprNode
::
type_as
()
const
{
...
...
include/tvm/relay/op.h
View file @
59c70a0e
...
...
@@ -276,6 +276,16 @@ class GenericOpMap {
*/
template
<
typename
ValueType
>
inline
ValueType
get
(
const
Op
&
op
,
ValueType
def_value
)
const
;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template
<
typename
ValueType
>
inline
ValueType
get
(
const
Expr
&
expr
,
ValueType
def_value
)
const
;
private
:
friend
class
OpRegistry
;
...
...
@@ -313,6 +323,14 @@ class OpMap {
* \return the const reference to the content value.
*/
inline
ValueType
get
(
const
Op
&
op
,
ValueType
def_value
)
const
;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
*/
inline
ValueType
get
(
const
Expr
&
expr
,
ValueType
def_value
)
const
;
private
:
friend
class
Op
;
...
...
@@ -497,6 +515,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
}
template
<
typename
ValueType
>
inline
ValueType
GenericOpMap
::
get
(
const
Expr
&
expr
,
ValueType
value
)
const
{
CHECK
(
expr
.
defined
());
if
(
const
OpNode
*
op
=
expr
.
as
<
OpNode
>
())
{
const
uint32_t
idx
=
op
->
index_
;
if
(
idx
<
data_
.
size
()
&&
data_
[
idx
].
second
!=
0
)
{
return
data_
[
idx
].
first
;
}
else
{
return
value
;
}
}
else
{
return
value
;
}
}
template
<
typename
ValueType
>
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
&
op
)
const
{
return
map_
.
count
(
op
);
}
...
...
@@ -505,12 +538,19 @@ template <typename ValueType>
inline
ValueType
OpMap
<
ValueType
>::
operator
[](
const
Op
&
op
)
const
{
return
map_
[
op
];
}
template
<
typename
ValueType
>
inline
ValueType
OpMap
<
ValueType
>::
get
(
const
Op
&
op
,
ValueType
def_value
)
const
{
return
map_
.
get
<
ValueType
>
(
op
,
def_value
);
}
template
<
typename
ValueType
>
inline
ValueType
OpMap
<
ValueType
>::
get
(
const
Expr
&
expr
,
ValueType
def_value
)
const
{
return
map_
.
get
<
ValueType
>
(
expr
,
def_value
);
}
/*!
* \brief Check that an expression is a "primtive operator".
*
...
...
include/tvm/relay/op_attr_types.h
View file @
59c70a0e
...
...
@@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
Schedule
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
outs
,
const
Target
&
target
)
>
;
/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
* We can make use of the op and type information.
* \param new_args The new arguments (some of them could be TempExpr).
* \param ctx Optional context information about ref_call.
* \return The rewriten result call, can also return nullptr,
* which indicate the rewriter should use the default fallback
* rule that realizes all its input and compose the call.
*
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
using
FForwardRewrite
=
runtime
::
TypedPackedFunc
<
Expr
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_args
,
const
NodeRef
&
ctx
)
>
;
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
include/tvm/relay/pass.h
View file @
59c70a0e
...
...
@@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
*/
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \return The rewritten expression.
*/
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
);
/*! \brief A hashing structure in the style of std::hash. */
struct
StructuralHash
{
...
...
include/tvm/runtime/packed_func.h
View file @
59c70a0e
...
...
@@ -73,6 +73,8 @@ class PackedFunc {
using
FType
=
std
::
function
<
void
(
TVMArgs
args
,
TVMRetValue
*
rv
)
>
;
/*! \brief default constructor */
PackedFunc
()
{}
/*! \brief constructor from null */
PackedFunc
(
std
::
nullptr_t
null
)
{}
// NOLINT(*)
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
...
...
src/relay/pass/fold_scale_axis.cc
View file @
59c70a0e
...
...
@@ -88,23 +88,6 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) {
}
/*!
* \param Get function from op_map.
* \param op_map The OpMap.
* \param op The operator being called.
* \tparam ValueType the content value type.
* \return The result value map.
*/
template
<
typename
ValueType
>
ValueType
GetFunc
(
const
OpMap
<
ValueType
>&
op_map
,
const
Expr
&
op
)
{
if
(
const
OpNode
*
opnode
=
op
.
as
<
OpNode
>
())
{
return
op_map
.
get
(
GetRef
<
Op
>
(
opnode
),
ValueType
());
}
else
{
return
ValueType
();
}
}
/*!
* \brief Preparation function for pass scale forward.
* \param call The call node.
* \param out_scale_axes Possible scaling on axes of the output.
...
...
@@ -114,7 +97,7 @@ using FForwardPrep = runtime::TypedPackedFunc<
Array
<
AxesSet
>
(
const
Call
&
call
,
const
AxesSet
&
out_scale_axes
)
>
;
/*! \brief Axis scale tuple. */
class
S
TupleNode
:
public
Node
{
class
S
caledExprNode
:
public
TempExpr
Node
{
public
:
/*! \brief The value */
Expr
value
;
...
...
@@ -123,29 +106,26 @@ class STupleNode : public Node {
/*! \brief The scaling factor */
Expr
scale
=
NullValue
<
Expr
>
();
Expr
Realize
()
const
final
{
CHECK
(
!
axes
.
defined
())
<<
"outstanding scale"
;
return
value
;
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"value"
,
&
value
);
v
->
Visit
(
"axes"
,
&
axes
);
v
->
Visit
(
"scale"
,
&
scale
);
}
static
constexpr
const
char
*
_type_key
=
"relay.fold_scale_axis.S
TupleNode
"
;
TVM_DECLARE_NODE_TYPE_INFO
(
S
TupleNode
,
Node
);
static
constexpr
const
char
*
_type_key
=
"relay.fold_scale_axis.S
caledExpr
"
;
TVM_DECLARE_NODE_TYPE_INFO
(
S
caledExprNode
,
TempExpr
Node
);
};
RELAY_DEFINE_NODE_REF
(
STuple
,
STupleNode
,
NodeRef
);
/*!
* \brief The transform function, transform an old call to
* a new one given the new args.
* \param ref_call Reference call node that represent the op and the types.
* \param expected_out_axes The scale axes allowed in the output.
* \param sargs The input arguments.
*/
using
FForwardTransform
=
TypedPackedFunc
<
STuple
(
const
Call
&
ref_call
,
const
AxesSet
&
expected_out_axes
,
const
Array
<
STuple
>&
sargs
)
>
;
using
FForwardRewrite
=
TypedPackedFunc
<
Expr
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_args
,
const
AxesSet
&
expeced_out_axes
)
>
;
//----------------------------------------------
// Generic Visitors for FScaleAxisForward
...
...
@@ -219,7 +199,7 @@ class ForwardPrep : private ExprVisitor {
out_axes
=
NullValue
<
AxesSet
>
();
}
// pass the message back to all the children it references.
auto
f
=
GetFunc
(
fprep
,
call
->
op
);
auto
f
=
fprep
.
get
(
call
->
op
,
nullptr
);
if
(
f
!=
nullptr
)
{
Array
<
AxesSet
>
in_axes
=
f
(
GetRef
<
Call
>
(
call
),
out_axes
);
CHECK_EQ
(
in_axes
.
size
(),
call
->
args
.
size
());
...
...
@@ -261,87 +241,6 @@ class ForwardPrep : private ExprVisitor {
}
};
class
ForwardTransformer
:
private
ExprMutator
{
public
:
// Transform expression.
Expr
Fold
(
Expr
expr
)
{
expected_scale_axes_
=
ForwardPrep
().
Prepare
(
expr
);
return
this
->
Mutate
(
expr
);
}
private
:
// Valid axes on each node.
std
::
unordered_map
<
const
Node
*
,
AxesSet
>
expected_scale_axes_
;
std
::
unordered_map
<
const
Node
*
,
STuple
>
scale_memo_
;
// If user simply call mutate,
// then only Expr is returned and we cannot
// accept outstanding scales.
Expr
VisitExpr
(
const
Expr
&
expr
)
final
{
Expr
res
=
ExprMutator
::
VisitExpr
(
expr
);
CHECK
(
!
scale_memo_
.
count
(
expr
.
get
()))
<<
"Outstanding scale"
;
return
res
;
}
STuple
GetSTuple
(
const
Expr
&
expr
)
{
Expr
res
=
ExprMutator
::
VisitExpr
(
expr
);
auto
it
=
scale_memo_
.
find
(
expr
.
get
());
if
(
it
!=
scale_memo_
.
end
())
{
CHECK
(
it
->
second
->
value
.
same_as
(
res
));
return
it
->
second
;
}
else
{
auto
node
=
make_node
<
STupleNode
>
();
node
->
value
=
res
;
return
STuple
(
node
);
}
}
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
static
const
auto
&
ftransform
=
Op
::
GetAttr
<
FForwardTransform
>
(
"FScaleAxisForwardTransform"
);
auto
new_op
=
this
->
Mutate
(
call_node
->
op
);
bool
has_scale
=
false
;
bool
unchanged
=
call_node
->
op
.
same_as
(
new_op
);
Array
<
STuple
>
call_sargs
;
Array
<
Expr
>
call_args
;
for
(
auto
arg
:
call_node
->
args
)
{
STuple
new_sarg
=
this
->
GetSTuple
(
arg
);
unchanged
&=
new_sarg
->
value
.
same_as
(
arg
);
if
(
new_sarg
->
axes
.
defined
())
has_scale
=
true
;
call_sargs
.
push_back
(
new_sarg
);
call_args
.
push_back
(
new_sarg
->
value
);
}
// get expected scale axes.
AxesSet
expected_out_axes
;
auto
axis_it
=
expected_scale_axes_
.
find
(
call_node
);
if
(
axis_it
!=
expected_scale_axes_
.
end
())
{
expected_out_axes
=
axis_it
->
second
;
}
// propagation function
auto
f
=
GetFunc
(
ftransform
,
call_node
->
op
);
if
(
f
!=
nullptr
)
{
STuple
sret
=
f
(
GetRef
<
Call
>
(
call_node
),
expected_out_axes
,
call_sargs
);
if
(
sret
.
defined
())
{
if
(
sret
->
axes
.
defined
())
{
scale_memo_
[
call_node
]
=
sret
;
}
return
sret
->
value
;
}
}
// normal path
CHECK
(
!
has_scale
)
<<
"Outstanding scale, on op="
<<
call_node
->
op
;
if
(
unchanged
)
{
return
GetRef
<
Expr
>
(
call_node
);
}
else
{
return
CallNode
::
make
(
new_op
,
call_args
,
call_node
->
attrs
,
call_node
->
type_args
);
}
}
};
//----------------------------------------------
// Per operator defs for FScaleAxisForward
//----------------------------------------------
...
...
@@ -351,30 +250,31 @@ Array<AxesSet> ReluForwardPrep(const Call& call, AxesSet out) {
return
{
out
};
}
STuple
ReluForwardTransform
(
const
Call
&
ref_call
,
const
AxesSet
&
expected_axes
,
const
Array
<
STuple
>&
sargs
)
{
if
(
!
sargs
[
0
]
->
axes
.
defined
())
return
STuple
();
Expr
ReluForwardRewrite
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_args
,
const
AxesSet
&
expected_axes
)
{
const
auto
*
input
=
new_args
[
0
].
as
<
ScaledExprNode
>
();
if
(
input
==
nullptr
)
return
Expr
(
nullptr
);
// return transformed conv2d
auto
rnode
=
make_node
<
S
Tuple
Node
>
();
auto
rnode
=
make_node
<
S
caledExpr
Node
>
();
rnode
->
value
=
CallNode
::
make
(
ref_call
->
op
,
{
sargs
[
0
]
->
value
},
ref_call
->
attrs
,
ref_call
->
type_args
);
rnode
->
scale
=
sargs
[
0
]
->
scale
;
rnode
->
axes
=
sargs
[
0
]
->
axes
;
return
STuple
(
rnode
);
ref_call
->
op
,
{
input
->
value
},
ref_call
->
attrs
,
ref_call
->
type_args
);
rnode
->
scale
=
input
->
scale
;
rnode
->
axes
=
input
->
axes
;
return
Expr
(
rnode
);
}
RELAY_REGISTER_OP
(
"nn.relu"
)
.
set_attr
<
FForwardPrep
>
(
"FScaleAxisForwardPrep"
,
ReluForwardPrep
);
RELAY_REGISTER_OP
(
"nn.relu"
)
.
set_attr
<
FForward
Transform
>
(
"FScaleAxisForwardTransform"
,
ReluForwardTransform
);
.
set_attr
<
FForward
Rewrite
>
(
"FScaleAxisForwardRewrite"
,
ReluForwardRewrite
);
RELAY_REGISTER_OP
(
"nn.leaky_relu"
)
.
set_attr
<
FForwardPrep
>
(
"FScaleAxisForwardPrep"
,
ReluForwardPrep
);
RELAY_REGISTER_OP
(
"nn.leaky_relu"
)
.
set_attr
<
FForward
Transform
>
(
"FScaleAxisForwardTransform"
,
ReluForwardTransform
);
.
set_attr
<
FForward
Rewrite
>
(
"FScaleAxisForwardRewrite"
,
ReluForwardRewrite
);
// AddSub
Array
<
AxesSet
>
AddSubForwardPrep
(
const
Call
&
call
,
AxesSet
out_axes
)
{
...
...
@@ -391,69 +291,69 @@ Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
}
}
STuple
AddSubForwardTransform
(
const
Call
&
ref_call
,
const
AxesSet
&
expected_out_axe
s
,
const
Array
<
STuple
>&
sarg
s
)
{
if
(
!
sargs
[
0
]
->
axes
.
defined
()
&&
!
sargs
[
1
]
->
axes
.
defined
())
{
return
STuple
();
}
Expr
AddSubForwardRewrite
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_arg
s
,
const
AxesSet
&
expected_out_axe
s
)
{
const
auto
*
slhs
=
new_args
[
0
].
as
<
ScaledExprNode
>
();
const
auto
*
srhs
=
new_args
[
1
].
as
<
ScaledExprNode
>
();
if
(
!
slhs
&&
!
srhs
)
return
Expr
();
const
auto
*
tlhs
=
ref_call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
ref_call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
auto
rnode
=
make_node
<
ScaledExprNode
>
();
auto
rnode
=
make_node
<
STupleNode
>
();
if
(
sargs
[
0
]
->
axes
.
defined
())
{
CHECK
(
!
sargs
[
1
]
->
axes
.
defined
());
CHECK
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
sargs
[
0
]
->
axes
));
if
(
slhs
!=
nullptr
)
{
CHECK
(
srhs
==
nullptr
);
CHECK
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
slhs
->
axes
));
Expr
scale
=
ExpandBiasToMatchAxis
(
s
args
[
0
]
->
scale
,
tlhs
->
shape
.
size
(),
sargs
[
0
]
->
axes
);
Expr
rhs
=
Divide
(
sargs
[
1
]
->
value
,
scale
);
rnode
->
value
=
CallNode
::
make
(
ref_call
->
op
,
{
s
args
[
0
]
->
value
,
rhs
},
s
lhs
->
scale
,
tlhs
->
shape
.
size
(),
slhs
->
axes
);
Expr
rhs
=
Divide
(
new_args
[
1
]
,
scale
);
rnode
->
value
=
CallNode
::
make
(
ref_call
->
op
,
{
s
lhs
->
value
,
rhs
},
ref_call
->
attrs
,
ref_call
->
type_args
);
rnode
->
scale
=
s
args
[
0
]
->
scale
;
rnode
->
axes
=
s
args
[
0
]
->
axes
;
rnode
->
scale
=
s
lhs
->
scale
;
rnode
->
axes
=
s
lhs
->
axes
;
}
else
{
CHECK
(
sargs
[
1
]
->
axes
.
defined
());
CHECK
(
sargs
[
0
]
->
axes
.
defined
());
CHECK
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
sargs
[
1
]
->
axes
));
CHECK
(
slhs
!=
nullptr
);
CHECK
(
MatchBroadcastToLeftAxes
(
trhs
,
tlhs
,
srhs
->
axes
));
Expr
scale
=
ExpandBiasToMatchAxis
(
s
args
[
1
]
->
scale
,
trhs
->
shape
.
size
(),
sargs
[
1
]
->
axes
);
Expr
lhs
=
Divide
(
sargs
[
0
]
->
value
,
scale
);
rnode
->
value
=
CallNode
::
make
(
ref_call
->
op
,
{
lhs
,
s
args
[
1
]
->
value
},
s
rhs
->
scale
,
trhs
->
shape
.
size
(),
srhs
->
axes
);
Expr
lhs
=
Divide
(
new_args
[
0
]
,
scale
);
rnode
->
value
=
CallNode
::
make
(
ref_call
->
op
,
{
lhs
,
s
rhs
->
value
},
ref_call
->
attrs
,
ref_call
->
type_args
);
rnode
->
scale
=
s
args
[
1
]
->
scale
;
rnode
->
axes
=
s
args
[
1
]
->
axes
;
rnode
->
scale
=
s
rhs
->
scale
;
rnode
->
axes
=
s
rhs
->
axes
;
}
return
STuple
(
rnode
);
return
Expr
(
rnode
);
}
RELAY_REGISTER_OP
(
"add"
)
.
set_attr
<
FForwardPrep
>
(
"FScaleAxisForwardPrep"
,
AddSubForwardPrep
);
RELAY_REGISTER_OP
(
"add"
)
.
set_attr
<
FForward
Transform
>
(
"FScaleAxisForwardTransform"
,
AddSubForwardTransform
);
.
set_attr
<
FForward
Rewrite
>
(
"FScaleAxisForwardRewrite"
,
AddSubForwardRewrite
);
RELAY_REGISTER_OP
(
"subtract"
)
.
set_attr
<
FForwardPrep
>
(
"FScaleAxisForwardPrep"
,
AddSubForwardPrep
);
RELAY_REGISTER_OP
(
"subtract"
)
.
set_attr
<
FForward
Transform
>
(
"FScaleAxisForwardTransform"
,
AddSubForwardTransform
);
.
set_attr
<
FForward
Rewrite
>
(
"FScaleAxisForwardRewrite"
,
AddSubForwardRewrite
);
// Producer operators
// Multiply produces the scale-axis pair.
STuple
MultiplyForwardTransform
(
const
Call
&
ref_call
,
const
AxesSet
&
expected_out_axe
s
,
const
Array
<
STuple
>&
sarg
s
)
{
if
(
!
expected_out_axes
.
defined
())
return
STuple
();
Expr
MultiplyForwardRewrite
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_arg
s
,
const
AxesSet
&
expected_out_axe
s
)
{
if
(
!
expected_out_axes
.
defined
())
return
Expr
();
// TODO(tvm-team) allow same axes accumulation
// not as important because it is less common in nn.
CHECK
(
!
sargs
[
0
]
->
axes
.
defined
());
CHECK
(
!
sargs
[
1
]
->
axes
.
defined
());
const
auto
*
slhs
=
new_args
[
0
].
as
<
ScaledExprNode
>
();
const
auto
*
srhs
=
new_args
[
1
].
as
<
ScaledExprNode
>
();
CHECK
(
!
slhs
&&
!
srhs
);
const
auto
*
tlhs
=
ref_call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
ref_call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
Expr
lhs
=
sargs
[
0
]
->
value
;
Expr
rhs
=
sargs
[
1
]
->
value
;
auto
rnode
=
make_node
<
STupleNode
>
();
Expr
lhs
=
new_args
[
0
];
Expr
rhs
=
new_args
[
1
];
auto
rnode
=
make_node
<
ScaledExprNode
>
();
if
(
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
expected_out_axes
,
&
rhs
))
{
rnode
->
value
=
lhs
;
rnode
->
scale
=
rhs
;
...
...
@@ -463,11 +363,11 @@ STuple MultiplyForwardTransform(const Call& ref_call,
rnode
->
scale
=
lhs
;
rnode
->
axes
=
expected_out_axes
;
}
return
STuple
(
rnode
);
return
Expr
(
rnode
);
}
RELAY_REGISTER_OP
(
"multiply"
)
.
set_attr
<
FForward
Transform
>
(
"FScaleAxisForwardTransform"
,
MultiplyForwardTransform
);
.
set_attr
<
FForward
Rewrite
>
(
"FScaleAxisForwardRewrite"
,
MultiplyForwardRewrite
);
// Consumer operators
// Conv2D send out requirement of axis folding.
...
...
@@ -500,13 +400,14 @@ Array<AxesSet> Conv2DForwardPrep(const Call& call, AxesSet out) {
}
// Conv2D consumes the scale axis during transformation.
STuple
Conv2DForwardTransform
(
const
Call
&
ref_call
,
const
AxesSet
&
expected_axe
s
,
const
Array
<
STuple
>&
sarg
s
)
{
Expr
Conv2DForwardRewrite
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_arg
s
,
const
AxesSet
&
expected_axe
s
)
{
// if data do not have scale, normal transform path.
STuple
sdata
=
sargs
[
0
];
if
(
!
sdata
->
scale
.
defined
())
return
STuple
();
CHECK
(
sdata
->
axes
.
defined
());
const
auto
*
sdata
=
new_args
[
0
].
as
<
ScaledExprNode
>
();
const
auto
*
sweight
=
new_args
[
1
].
as
<
ScaledExprNode
>
();
if
(
sdata
==
nullptr
)
return
Expr
();
if
(
sweight
!=
nullptr
)
return
Expr
();
const
auto
*
param
=
ref_call
->
attrs
.
as
<
Conv2DAttrs
>
();
CHECK
(
param
!=
nullptr
);
Layout
data_layout
(
param
->
data_layout
);
...
...
@@ -524,7 +425,8 @@ STuple Conv2DForwardTransform(const Call& ref_call,
// Check it must be depthwise or full conv2d.
bool
is_depthwise_conv2d
=
IsDepthwiseConv2D
(
ref_call
,
param
,
weight_layout
);
CHECK
(
param
->
groups
==
1
||
is_depthwise_conv2d
);
Expr
weight
=
sargs
[
1
]
->
value
;
Expr
weight
=
new_args
[
1
];
// match the ic_axis
if
(
is_depthwise_conv2d
)
{
...
...
@@ -537,21 +439,30 @@ STuple Conv2DForwardTransform(const Call& ref_call,
weight
=
Multiply
(
weight
,
scale
);
}
// return transformed conv2d
auto
rnode
=
make_node
<
STupleNode
>
();
rnode
->
value
=
CallNode
::
make
(
return
CallNode
::
make
(
ref_call
->
op
,
{
sdata
->
value
,
weight
},
ref_call
->
attrs
,
ref_call
->
type_args
);
return
STuple
(
rnode
);
}
RELAY_REGISTER_OP
(
"nn.conv2d"
)
.
set_attr
<
FForwardPrep
>
(
"FScaleAxisForwardPrep"
,
Conv2DForwardPrep
);
RELAY_REGISTER_OP
(
"nn.conv2d"
)
.
set_attr
<
FForward
Transform
>
(
"FScaleAxisForwardTransform"
,
Conv2DForwardTransform
);
.
set_attr
<
FForward
Rewrite
>
(
"FScaleAxisForwardRewrite"
,
Conv2DForwardRewrite
);
Expr
ForwardFoldScaleAxis
(
Expr
data
)
{
return
ForwardTransformer
().
Fold
(
data
);
auto
expected_scale_axes
=
ForwardPrep
().
Prepare
(
data
);
auto
fcontext
=
[
&
](
const
Call
&
call
)
->
NodeRef
{
auto
it
=
expected_scale_axes
.
find
(
call
.
get
());
if
(
it
!=
expected_scale_axes
.
end
())
{
return
it
->
second
;
}
else
{
return
NodeRef
(
nullptr
);
}
};
return
ForwardRewrite
(
data
,
"FScaleAxisForwardRewrite"
,
fcontext
);
}
// Expose the FoldScaleAxisFoward
...
...
@@ -602,7 +513,7 @@ class BackwardPrep : private ExprVisitor {
ExprVisitor
::
VisitExpr_
(
call
);
static
const
auto
&
fprep
=
Op
::
GetAttr
<
FBackwardPrep
>
(
"FScaleAxisBackwardPrep"
);
auto
f
=
GetFunc
(
fprep
,
call
->
op
);
auto
f
=
fprep
.
get
(
call
->
op
,
nullptr
);
if
(
f
==
nullptr
)
return
;
auto
rit
=
ref_counter_
.
find
(
call
);
CHECK
(
rit
!=
ref_counter_
.
end
());
...
...
@@ -705,7 +616,7 @@ Expr BackwardTransformerNode::Transform(
const
CallNode
*
call_node
,
AxesSet
axes
,
Expr
scale
)
{
static
const
auto
&
ftransform
=
Op
::
GetAttr
<
FBackwardTransform
>
(
"FScaleAxisBackwardTransform"
);
auto
f
=
GetFunc
(
ftransform
,
call_node
->
op
);
auto
f
=
ftransform
.
get
(
call_node
->
op
,
nullptr
);
if
(
f
!=
nullptr
)
{
return
f
(
GetRef
<
Call
>
(
call_node
),
axes
,
...
...
src/relay/pass/forward_rewrite.cc
0 → 100644
View file @
59c70a0e
/*!
* Copyright (c) 2018 by Contributors
*
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
namespace
tvm
{
namespace
relay
{
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf.
class
TempRealizer
:
private
ExprMutator
{
public
:
Expr
Realize
(
Expr
expr
)
{
return
VisitExpr
(
expr
);
}
private
:
Expr
VisitExpr
(
const
Expr
&
expr
)
final
{
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
else
{
Expr
res
;
if
(
const
auto
*
temp
=
expr
.
as_derived
<
TempExprNode
>
())
{
res
=
temp
->
Realize
();
}
else
{
res
=
ExprFunctor
::
VisitExpr
(
expr
);
}
memo_
[
res
]
=
res
;
return
res
;
}
}
};
class
ForwardRewriter
:
private
ExprMutator
{
public
:
ForwardRewriter
(
const
OpMap
<
FForwardRewrite
>&
rewrite_map
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
)
:
rewrite_map_
(
rewrite_map
),
fcontext_
(
fcontext
)
{
}
// Transform expression.
Expr
Rewrite
(
Expr
expr
)
{
return
this
->
VisitExpr
(
expr
);
}
private
:
// The rewrite rule.
const
OpMap
<
FForwardRewrite
>&
rewrite_map_
;
// The context.
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext_
{
nullptr
};
// internal realizer
TempRealizer
realizer_
;
Expr
VisitExpr
(
const
Expr
&
expr
)
final
{
// by default always realize.
return
realizer_
.
Realize
(
ExprMutator
::
VisitExpr
(
expr
));
}
// Visit and allow non-realized version.
Expr
GetTempExpr
(
const
Expr
&
expr
)
{
return
ExprMutator
::
VisitExpr
(
expr
);
}
// Automatic fold TupleGetItem.
Expr
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
Expr
tuple
=
this
->
GetTempExpr
(
op
->
tuple
);
if
(
const
auto
*
ptuple
=
tuple
.
as
<
TupleNode
>
())
{
return
ptuple
->
fields
[
op
->
index
];
}
else
{
if
(
tuple
.
same_as
(
op
->
tuple
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
TupleGetItemNode
::
make
(
tuple
,
op
->
index
);
}
}
}
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
const
Call
&
ref_call
=
GetRef
<
Call
>
(
call_node
);
PackedFunc
frewrite
=
rewrite_map_
.
get
(
call_node
->
op
,
nullptr
);
auto
new_op
=
this
->
Mutate
(
call_node
->
op
);
bool
unchanged
=
call_node
->
op
.
same_as
(
new_op
);
Array
<
Expr
>
call_args
;
for
(
auto
arg
:
call_node
->
args
)
{
Expr
new_arg
=
this
->
GetTempExpr
(
arg
);
if
(
frewrite
==
nullptr
)
{
new_arg
=
realizer_
.
Realize
(
new_arg
);
}
unchanged
&=
new_arg
.
same_as
(
arg
);
call_args
.
push_back
(
new_arg
);
}
// try to rewrite.
if
(
frewrite
!=
nullptr
)
{
Expr
res
=
frewrite
(
ref_call
,
call_args
,
fcontext_
!=
nullptr
?
fcontext_
(
ref_call
)
:
NodeRef
(
nullptr
));
if
(
res
.
defined
())
return
res
;
// abort, use old rule
for
(
size_t
i
=
0
;
i
<
call_args
.
size
();
++
i
)
{
Expr
arg
=
call_args
[
i
];
Expr
new_arg
=
realizer_
.
Realize
(
arg
);
if
(
!
arg
.
same_as
(
new_arg
))
{
call_args
.
Set
(
i
,
new_arg
);
unchanged
=
false
;
}
}
}
if
(
unchanged
)
return
ref_call
;
return
CallNode
::
make
(
new_op
,
call_args
,
call_node
->
attrs
,
call_node
->
type_args
);
}
};
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
)
{
auto
rewrite_map
=
Op
::
GetAttr
<
FForwardRewrite
>
(
rewrite_map_name
);
return
ForwardRewriter
(
rewrite_map
,
fcontext
).
Rewrite
(
expr
);
}
}
// namespace relay
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment