Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
59c70a0e
Unverified
Commit
59c70a0e
authored
Nov 15, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 15, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][[PASS] Consolidate ForwardRewrite pass. (#2124)
parent
de3b63a4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
230 additions
and
0 deletions
+230
-0
include/tvm/relay/expr.h
+26
-0
include/tvm/relay/op.h
+40
-0
include/tvm/relay/op_attr_types.h
+19
-0
include/tvm/relay/pass.h
+11
-0
include/tvm/runtime/packed_func.h
+2
-0
src/relay/pass/fold_scale_axis.cc
+0
-0
src/relay/pass/forward_rewrite.cc
+132
-0
No files found.
include/tvm/relay/expr.h
View file @
59c70a0e
...
@@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {
...
@@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF
(
TupleGetItem
,
TupleGetItemNode
,
Expr
);
RELAY_DEFINE_NODE_REF
(
TupleGetItem
,
TupleGetItemNode
,
Expr
);
/*!
* \brief Base class of the temporary expression.
*
* TempExprs are pass specific expression that can be
* useful to define intermediate result in the
* rewriting pass such as layout or type transformation.
*
* Subclass TempExprNode allows us to pattern match on
* specific kind TempExpr and use them for expression rewriting.
*
* TempExpr should only be used within a pass,
*/
class
TempExprNode
:
public
ExprNode
{
public
:
/*!
* \brief Convert the expression to a normal(non-temp) Expr.
* \return The corresponding normal(non-temp) expression.
*/
virtual
Expr
Realize
()
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"relay.TempExpr"
;
TVM_DECLARE_BASE_NODE_INFO
(
TempExprNode
,
ExprNode
);
};
RELAY_DEFINE_NODE_REF
(
TempExpr
,
TempExprNode
,
Expr
);
// implementataions
// implementataions
template
<
typename
TTypeNode
>
template
<
typename
TTypeNode
>
inline
const
TTypeNode
*
ExprNode
::
type_as
()
const
{
inline
const
TTypeNode
*
ExprNode
::
type_as
()
const
{
...
...
include/tvm/relay/op.h
View file @
59c70a0e
...
@@ -276,6 +276,16 @@ class GenericOpMap {
...
@@ -276,6 +276,16 @@ class GenericOpMap {
*/
*/
template
<
typename
ValueType
>
template
<
typename
ValueType
>
inline
ValueType
get
(
const
Op
&
op
,
ValueType
def_value
)
const
;
inline
ValueType
get
(
const
Op
&
op
,
ValueType
def_value
)
const
;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
* \tparam ValueType The content value type.
*/
template
<
typename
ValueType
>
inline
ValueType
get
(
const
Expr
&
expr
,
ValueType
def_value
)
const
;
private
:
private
:
friend
class
OpRegistry
;
friend
class
OpRegistry
;
...
@@ -313,6 +323,14 @@ class OpMap {
...
@@ -313,6 +323,14 @@ class OpMap {
* \return the const reference to the content value.
* \return the const reference to the content value.
*/
*/
inline
ValueType
get
(
const
Op
&
op
,
ValueType
def_value
)
const
;
inline
ValueType
get
(
const
Op
&
op
,
ValueType
def_value
)
const
;
/*!
* \brief get the corresponding value element at op with default value.
* \param expr The key to the map
* \param def_value The default value when the key does not exist
* or if expr is not an Op.
* \return the const reference to the content value.
*/
inline
ValueType
get
(
const
Expr
&
expr
,
ValueType
def_value
)
const
;
private
:
private
:
friend
class
Op
;
friend
class
Op
;
...
@@ -497,6 +515,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
...
@@ -497,6 +515,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const {
}
}
template
<
typename
ValueType
>
template
<
typename
ValueType
>
inline
ValueType
GenericOpMap
::
get
(
const
Expr
&
expr
,
ValueType
value
)
const
{
CHECK
(
expr
.
defined
());
if
(
const
OpNode
*
op
=
expr
.
as
<
OpNode
>
())
{
const
uint32_t
idx
=
op
->
index_
;
if
(
idx
<
data_
.
size
()
&&
data_
[
idx
].
second
!=
0
)
{
return
data_
[
idx
].
first
;
}
else
{
return
value
;
}
}
else
{
return
value
;
}
}
template
<
typename
ValueType
>
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
&
op
)
const
{
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
&
op
)
const
{
return
map_
.
count
(
op
);
return
map_
.
count
(
op
);
}
}
...
@@ -505,12 +538,19 @@ template <typename ValueType>
...
@@ -505,12 +538,19 @@ template <typename ValueType>
inline
ValueType
OpMap
<
ValueType
>::
operator
[](
const
Op
&
op
)
const
{
inline
ValueType
OpMap
<
ValueType
>::
operator
[](
const
Op
&
op
)
const
{
return
map_
[
op
];
return
map_
[
op
];
}
}
template
<
typename
ValueType
>
template
<
typename
ValueType
>
inline
ValueType
OpMap
<
ValueType
>::
get
(
const
Op
&
op
,
inline
ValueType
OpMap
<
ValueType
>::
get
(
const
Op
&
op
,
ValueType
def_value
)
const
{
ValueType
def_value
)
const
{
return
map_
.
get
<
ValueType
>
(
op
,
def_value
);
return
map_
.
get
<
ValueType
>
(
op
,
def_value
);
}
}
template
<
typename
ValueType
>
inline
ValueType
OpMap
<
ValueType
>::
get
(
const
Expr
&
expr
,
ValueType
def_value
)
const
{
return
map_
.
get
<
ValueType
>
(
expr
,
def_value
);
}
/*!
/*!
* \brief Check that an expression is a "primtive operator".
* \brief Check that an expression is a "primtive operator".
*
*
...
...
include/tvm/relay/op_attr_types.h
View file @
59c70a0e
...
@@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
...
@@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc<
Schedule
(
const
Attrs
&
attrs
,
Schedule
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
outs
,
const
Array
<
Tensor
>&
outs
,
const
Target
&
target
)
>
;
const
Target
&
target
)
>
;
/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
* We can make use of the op and type information.
* \param new_args The new arguments (some of them could be TempExpr).
* \param ctx Optional context information about ref_call.
* \return The rewriten result call, can also return nullptr,
* which indicate the rewriter should use the default fallback
* rule that realizes all its input and compose the call.
*
* \note When we register the function, we can register
* a different signature with ctx to be a specific node type.
*/
using
FForwardRewrite
=
runtime
::
TypedPackedFunc
<
Expr
(
const
Call
&
ref_call
,
const
Array
<
Expr
>&
new_args
,
const
NodeRef
&
ctx
)
>
;
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
include/tvm/relay/pass.h
View file @
59c70a0e
...
@@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
...
@@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr);
*/
*/
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
);
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
);
/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order.
* \param expr The expression.
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \return The rewritten expression.
*/
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
);
/*! \brief A hashing structure in the style of std::hash. */
/*! \brief A hashing structure in the style of std::hash. */
struct
StructuralHash
{
struct
StructuralHash
{
...
...
include/tvm/runtime/packed_func.h
View file @
59c70a0e
...
@@ -73,6 +73,8 @@ class PackedFunc {
...
@@ -73,6 +73,8 @@ class PackedFunc {
using
FType
=
std
::
function
<
void
(
TVMArgs
args
,
TVMRetValue
*
rv
)
>
;
using
FType
=
std
::
function
<
void
(
TVMArgs
args
,
TVMRetValue
*
rv
)
>
;
/*! \brief default constructor */
/*! \brief default constructor */
PackedFunc
()
{}
PackedFunc
()
{}
/*! \brief constructor from null */
PackedFunc
(
std
::
nullptr_t
null
)
{}
// NOLINT(*)
/*!
/*!
* \brief constructing a packed function from a std::function.
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
* \param body the internal container of packed function.
...
...
src/relay/pass/fold_scale_axis.cc
View file @
59c70a0e
This diff is collapsed.
Click to expand it.
src/relay/pass/forward_rewrite.cc
0 → 100644
View file @
59c70a0e
/*!
* Copyright (c) 2018 by Contributors
*
* \file forward_rewrite.cc
* \brief Apply rewriting rules in a forward fashion.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
namespace
tvm
{
namespace
relay
{
// Realizer class that realizes the expression
// Note that we can take benefit of its internal memo
// so that calling realize repeatively won't hurt perf.
class
TempRealizer
:
private
ExprMutator
{
public
:
Expr
Realize
(
Expr
expr
)
{
return
VisitExpr
(
expr
);
}
private
:
Expr
VisitExpr
(
const
Expr
&
expr
)
final
{
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
else
{
Expr
res
;
if
(
const
auto
*
temp
=
expr
.
as_derived
<
TempExprNode
>
())
{
res
=
temp
->
Realize
();
}
else
{
res
=
ExprFunctor
::
VisitExpr
(
expr
);
}
memo_
[
res
]
=
res
;
return
res
;
}
}
};
class
ForwardRewriter
:
private
ExprMutator
{
public
:
ForwardRewriter
(
const
OpMap
<
FForwardRewrite
>&
rewrite_map
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
)
:
rewrite_map_
(
rewrite_map
),
fcontext_
(
fcontext
)
{
}
// Transform expression.
Expr
Rewrite
(
Expr
expr
)
{
return
this
->
VisitExpr
(
expr
);
}
private
:
// The rewrite rule.
const
OpMap
<
FForwardRewrite
>&
rewrite_map_
;
// The context.
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext_
{
nullptr
};
// internal realizer
TempRealizer
realizer_
;
Expr
VisitExpr
(
const
Expr
&
expr
)
final
{
// by default always realize.
return
realizer_
.
Realize
(
ExprMutator
::
VisitExpr
(
expr
));
}
// Visit and allow non-realized version.
Expr
GetTempExpr
(
const
Expr
&
expr
)
{
return
ExprMutator
::
VisitExpr
(
expr
);
}
// Automatic fold TupleGetItem.
Expr
VisitExpr_
(
const
TupleGetItemNode
*
op
)
final
{
Expr
tuple
=
this
->
GetTempExpr
(
op
->
tuple
);
if
(
const
auto
*
ptuple
=
tuple
.
as
<
TupleNode
>
())
{
return
ptuple
->
fields
[
op
->
index
];
}
else
{
if
(
tuple
.
same_as
(
op
->
tuple
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
TupleGetItemNode
::
make
(
tuple
,
op
->
index
);
}
}
}
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
const
Call
&
ref_call
=
GetRef
<
Call
>
(
call_node
);
PackedFunc
frewrite
=
rewrite_map_
.
get
(
call_node
->
op
,
nullptr
);
auto
new_op
=
this
->
Mutate
(
call_node
->
op
);
bool
unchanged
=
call_node
->
op
.
same_as
(
new_op
);
Array
<
Expr
>
call_args
;
for
(
auto
arg
:
call_node
->
args
)
{
Expr
new_arg
=
this
->
GetTempExpr
(
arg
);
if
(
frewrite
==
nullptr
)
{
new_arg
=
realizer_
.
Realize
(
new_arg
);
}
unchanged
&=
new_arg
.
same_as
(
arg
);
call_args
.
push_back
(
new_arg
);
}
// try to rewrite.
if
(
frewrite
!=
nullptr
)
{
Expr
res
=
frewrite
(
ref_call
,
call_args
,
fcontext_
!=
nullptr
?
fcontext_
(
ref_call
)
:
NodeRef
(
nullptr
));
if
(
res
.
defined
())
return
res
;
// abort, use old rule
for
(
size_t
i
=
0
;
i
<
call_args
.
size
();
++
i
)
{
Expr
arg
=
call_args
[
i
];
Expr
new_arg
=
realizer_
.
Realize
(
arg
);
if
(
!
arg
.
same_as
(
new_arg
))
{
call_args
.
Set
(
i
,
new_arg
);
unchanged
=
false
;
}
}
}
if
(
unchanged
)
return
ref_call
;
return
CallNode
::
make
(
new_op
,
call_args
,
call_node
->
attrs
,
call_node
->
type_args
);
}
};
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
)
{
auto
rewrite_map
=
Op
::
GetAttr
<
FForwardRewrite
>
(
rewrite_map_name
);
return
ForwardRewriter
(
rewrite_map
,
fcontext
).
Rewrite
(
expr
);
}
}
// namespace relay
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment