Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
a3530f8f
Unverified
Commit
a3530f8f
authored
Nov 26, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY] Add multiref trigger to ForwardRewrite (#2168)
parent
0a1f3d41
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
6 deletions
+57
-6
include/tvm/relay/pass.h
+4
-1
python/tvm/relay/expr.py
+17
-0
src/relay/ir/expr.cc
+8
-0
src/relay/pass/forward_rewrite.cc
+28
-5
No files found.
include/tvm/relay/pass.h
View file @
a3530f8f
...
...
@@ -164,11 +164,14 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level);
* \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite
* rule function.
* \param fcontext Additional callback to provide context argument for each call node.
* \param fmulti_ref_trigger Transformation function to be called when
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_attr_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
);
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
=
nullptr
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
=
nullptr
);
/*! \brief A hashing structure in the style of std::hash. */
struct
StructuralHash
{
...
...
python/tvm/relay/expr.py
View file @
a3530f8f
...
...
@@ -320,6 +320,23 @@ class TupleGetItem(Expr):
_make
.
TupleGetItem
,
tuple_value
,
index
)
class
TempExpr
(
Expr
):
"""Baseclass of all TempExpr.
TempExprs are pass specific expression that can be
useful to define intermediate result in the
rewriting pass such as layout or type transformation.
"""
def
realize
(
self
):
"""Convert the expression to a normal(non-temp) Expr.
Returns
-------
The corresponding normal expression.
"""
return
_expr
.
TempExprRealize
(
self
)
class
ExprFunctor
(
object
):
"""
An abstract visitor defined over Expr.
...
...
src/relay/ir/expr.cc
View file @
a3530f8f
...
...
@@ -258,5 +258,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"TupleGetItemNode("
<<
node
->
tuple
<<
", "
<<
node
->
index
<<
")"
;
});
TVM_REGISTER_API
(
"relay._expr.TempExprRealize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TempExpr
temp
=
args
[
0
];
*
ret
=
temp
->
Realize
();
});
}
// namespace relay
}
// namespace tvm
src/relay/pass/forward_rewrite.cc
View file @
a3530f8f
...
...
@@ -7,6 +7,7 @@
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include "pass_util.h"
namespace
tvm
{
namespace
relay
{
...
...
@@ -42,13 +43,18 @@ class TempRealizer : private ExprMutator {
class
ForwardRewriter
:
private
ExprMutator
{
public
:
ForwardRewriter
(
const
OpMap
<
FForwardRewrite
>&
rewrite_map
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
)
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
)
:
rewrite_map_
(
rewrite_map
),
fcontext_
(
fcontext
)
{
fcontext_
(
fcontext
),
fmulti_ref_trigger_
(
fmulti_ref_trigger
)
{
}
// Transform expression.
Expr
Rewrite
(
Expr
expr
)
{
if
(
fmulti_ref_trigger_
!=
nullptr
)
{
ref_counter_
=
GetExprRefCount
(
expr
);
}
return
this
->
VisitExpr
(
expr
);
}
...
...
@@ -57,6 +63,10 @@ class ForwardRewriter : private ExprMutator {
const
OpMap
<
FForwardRewrite
>&
rewrite_map_
;
// The context.
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext_
{
nullptr
};
// The multiple reference trigger
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger_
{
nullptr
};
// Internal ref counter
std
::
unordered_map
<
const
Node
*
,
size_t
>
ref_counter_
;
// internal realizer
TempRealizer
realizer_
;
...
...
@@ -67,7 +77,17 @@ class ForwardRewriter : private ExprMutator {
// Visit and allow non-realized version.
Expr
GetTempExpr
(
const
Expr
&
expr
)
{
return
ExprMutator
::
VisitExpr
(
expr
);
if
(
fmulti_ref_trigger_
!=
nullptr
)
{
Expr
ret
=
ExprMutator
::
VisitExpr
(
expr
);
auto
it
=
ref_counter_
.
find
(
expr
.
get
());
CHECK
(
it
!=
ref_counter_
.
end
());
if
(
it
->
second
>
1
)
{
ret
=
fmulti_ref_trigger_
(
ret
);
}
return
ret
;
}
else
{
return
ExprMutator
::
VisitExpr
(
expr
);
}
}
// Automatic fold TupleGetItem.
...
...
@@ -124,9 +144,12 @@ class ForwardRewriter : private ExprMutator {
Expr
ForwardRewrite
(
const
Expr
&
expr
,
const
std
::
string
&
rewrite_map_name
,
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
)
{
std
::
function
<
NodeRef
(
const
Call
&
)
>
fcontext
,
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref_trigger
)
{
auto
rewrite_map
=
Op
::
GetAttr
<
FForwardRewrite
>
(
rewrite_map_name
);
return
ForwardRewriter
(
rewrite_map
,
fcontext
).
Rewrite
(
expr
);
return
ForwardRewriter
(
rewrite_map
,
fcontext
,
fmulti_ref_trigger
).
Rewrite
(
expr
);
}
}
// namespace relay
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment