Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4369b7f6
Unverified
Commit
4369b7f6
authored
Nov 13, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][PASS] General OpFusion. (#2090)
parent
e470f8ea
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
1022 additions
and
109 deletions
+1022
-109
include/tvm/relay/expr.h
+10
-0
include/tvm/runtime/packed_func.h
+2
-0
python/tvm/relay/base.py
+7
-2
python/tvm/relay/build_module.py
+3
-1
python/tvm/relay/expr.py
+1
-2
python/tvm/relay/ir_pass.py
+5
-2
src/common/arena.h
+57
-1
src/relay/backend/compile_engine.cc
+23
-0
src/relay/ir/text_printer.cc
+19
-7
src/relay/pass/fold_scale_axis.cc
+1
-17
src/relay/pass/fuse_ops.cc
+721
-22
src/relay/pass/pass_util.h
+27
-0
src/relay/pass/type_infer.cc
+22
-1
src/relay/pass/type_solver.cc
+3
-3
src/relay/pass/type_solver.h
+3
-48
src/relay/pass/util.cc
+18
-0
tests/python/relay/test_ir_text_printer.py
+1
-0
tests/python/relay/test_pass_fold_scale_axis.py
+8
-0
tests/python/relay/test_pass_fuse_ops.py
+91
-3
No files found.
include/tvm/relay/expr.h
View file @
4369b7f6
...
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
...
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
return
node
;
return
node
;
}
}
/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
=
nullptr
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_EXPR_H_
#endif // TVM_RELAY_EXPR_H_
include/tvm/runtime/packed_func.h
View file @
4369b7f6
...
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
...
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
using
TSelf
=
TypedPackedFunc
<
R
(
Args
...)
>
;
using
TSelf
=
TypedPackedFunc
<
R
(
Args
...)
>
;
/*! \brief default constructor */
/*! \brief default constructor */
TypedPackedFunc
()
{}
TypedPackedFunc
()
{}
/*! \brief constructor from null */
TypedPackedFunc
(
std
::
nullptr_t
null
)
{}
// NOLINT(*)
/*!
/*!
* \brief construct by wrap a PackedFunc
* \brief construct by wrap a PackedFunc
*
*
...
...
python/tvm/relay/base.py
View file @
4369b7f6
...
@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
...
@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
class
RelayNode
(
NodeBase
):
class
RelayNode
(
NodeBase
):
def
astext
(
self
):
"""Base class of all relay node."""
def
astext
(
self
,
annotate
=
None
):
"""Get the text format of the expression.
"""Get the text format of the expression.
Returns
Returns
-------
-------
text : str
text : str
The text format of the expression.
The text format of the expression.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
"""
"""
return
_expr
.
_text_print
(
self
)
return
_expr
.
RelayPrint
(
self
,
annotate
)
@register_relay_node
@register_relay_node
...
...
python/tvm/relay/build_module.py
View file @
4369b7f6
...
@@ -173,11 +173,13 @@ def build(func,
...
@@ -173,11 +173,13 @@ def build(func,
else
:
else
:
tophub_context
=
autotvm
.
util
.
EmptyContext
()
tophub_context
=
autotvm
.
util
.
EmptyContext
()
cfg
=
BuildConfig
.
current
with
tophub_context
:
with
tophub_context
:
func
=
optimize
(
func
)
func
=
optimize
(
func
)
# Fuse ops before running code gen
# Fuse ops before running code gen
func
=
ir_pass
.
infer_type
(
func
)
func
=
ir_pass
.
infer_type
(
func
)
func
=
ir_pass
.
fuse_ops
(
func
)
func
=
ir_pass
.
fuse_ops
(
func
,
cfg
.
opt_level
)
# Graph code generation
# Graph code generation
func
=
ir_pass
.
infer_type
(
func
)
func
=
ir_pass
.
infer_type
(
func
)
graph_gen
=
_graph_gen
.
GraphRuntimeCodegen
(
mod
=
None
,
target
=
target
)
graph_gen
=
_graph_gen
.
GraphRuntimeCodegen
(
mod
=
None
,
target
=
target
)
...
...
python/tvm/relay/expr.py
View file @
4369b7f6
...
@@ -6,7 +6,6 @@ from numbers import Number as _Number
...
@@ -6,7 +6,6 @@ from numbers import Number as _Number
import
numpy
as
_np
import
numpy
as
_np
from
.base
import
RelayNode
,
register_relay_node
from
.base
import
RelayNode
,
register_relay_node
from
.
import
_make
from
.
import
_make
from
.
import
_expr
from
.
import
ty
as
_ty
from
.
import
ty
as
_ty
from
.._ffi
import
base
as
_base
from
.._ffi
import
base
as
_base
from
..
import
nd
as
_nd
from
..
import
nd
as
_nd
...
@@ -477,7 +476,7 @@ class TupleWrapper(object):
...
@@ -477,7 +476,7 @@ class TupleWrapper(object):
text : str
text : str
The text format of the tuple expression.
The text format of the tuple expression.
"""
"""
return
_expr
.
_text_print
(
self
.
tuple_value
)
return
self
.
tuple_value
.
astext
(
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
if
index
>=
len
(
self
):
if
index
>=
len
(
self
):
...
...
python/tvm/relay/ir_pass.py
View file @
4369b7f6
...
@@ -259,7 +259,7 @@ def structural_hash(value):
...
@@ -259,7 +259,7 @@ def structural_hash(value):
raise
TypeError
(
msg
)
raise
TypeError
(
msg
)
def
fuse_ops
(
expr
):
def
fuse_ops
(
expr
,
opt_level
=
1
):
"""Fuse operators in expr together.
"""Fuse operators in expr together.
Parameters
Parameters
...
@@ -267,9 +267,12 @@ def fuse_ops(expr):
...
@@ -267,9 +267,12 @@ def fuse_ops(expr):
expr : tvm.relay.Expr
expr : tvm.relay.Expr
The input expression.
The input expression.
opt_level : int
The level of fuse optimization.
Returns
Returns
-------
-------
transformed_expr : tvm.relay.Expr
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
Transformed expression, containing fused result.
"""
"""
return
_ir_pass
.
FuseOps
(
expr
)
return
_ir_pass
.
FuseOps
(
expr
,
opt_level
)
src/common/arena.h
View file @
4369b7f6
...
@@ -38,11 +38,29 @@ class Arena {
...
@@ -38,11 +38,29 @@ class Arena {
/*!
/*!
* \brief Allocate a space from Arena for type T
* \brief Allocate a space from Arena for type T
* \param T the data type to be allocated
* \param T the data type to be allocated
* \note The space of T is not initialized.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
T
*
Alloc
()
{
T
*
allocate_
()
{
return
static_cast
<
T
*>
(
Alloc
(
sizeof
(
T
),
alignof
(
T
)));
return
static_cast
<
T
*>
(
Alloc
(
sizeof
(
T
),
alignof
(
T
)));
}
}
/*!
* \brief Create a new instance of type T.
* \param args The constructor argument.
* \tparam T the type to be created.
* \tparam Args Arguments to the constructor.
*
* \return The allocated object.
* \note The type T must be simple type, or only contain
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
template
<
typename
T
,
typename
...
Args
>
T
*
make
(
Args
&&
...
args
)
{
T
*
ptr
=
allocate_
<
T
>
();
new
(
ptr
)
T
(
std
::
forward
<
Args
>
(
args
)...);
return
ptr
;
}
private
:
private
:
// page size 16 KB
// page size 16 KB
...
@@ -87,6 +105,44 @@ class Arena {
...
@@ -87,6 +105,44 @@ class Arena {
}
}
};
};
/*!
* \brief Link list node
* \tparam T the content data type
*/
template
<
typename
T
>
struct
LinkNode
{
/*! \brief The content value */
T
value
;
/*! \brief pointer to the next location */
LinkNode
<
T
>*
next
{
nullptr
};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
template
<
typename
T
>
struct
LinkedList
{
/*! \brief Head pointer */
LinkNode
<
T
>*
head
{
nullptr
};
/*! \brief Tail pointer */
LinkNode
<
T
>*
tail
{
nullptr
};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void
Push
(
LinkNode
<
T
>*
node
)
{
node
->
next
=
nullptr
;
if
(
this
->
tail
!=
nullptr
)
{
this
->
tail
->
next
=
node
;
this
->
tail
=
node
;
}
else
{
head
=
tail
=
node
;
}
}
};
}
// namespace common
}
// namespace common
}
// namespace tvm
}
// namespace tvm
#endif // TVM_COMMON_ARENA_H_
#endif // TVM_COMMON_ARENA_H_
src/relay/backend/compile_engine.cc
View file @
4369b7f6
...
@@ -109,6 +109,29 @@ class ScheduleGetter :
...
@@ -109,6 +109,29 @@ class ScheduleGetter :
return
{};
return
{};
}
}
Array
<
Tensor
>
VisitExpr_
(
const
ConstantNode
*
op
)
final
{
CHECK
(
op
->
is_scalar
());
void
*
data
=
op
->
data
->
data
;
DataType
dtype
=
TVMType2Type
(
op
->
data
->
dtype
);
Tensor
value
=
tvm
::
compute
({},
[
&
](
const
Array
<
tvm
::
Var
>&
)
{
if
(
dtype
==
Int
(
32
))
{
return
make_const
(
dtype
,
static_cast
<
const
int32_t
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Int
(
64
))
{
return
make_const
(
dtype
,
static_cast
<
const
int64_t
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Float
(
32
))
{
return
make_const
(
dtype
,
static_cast
<
const
float
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Float
(
64
))
{
return
make_const
(
dtype
,
static_cast
<
const
double
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Bool
())
{
return
make_const
(
dtype
,
static_cast
<
const
uint8_t
*>
(
data
)[
0
]);
}
else
{
LOG
(
FATAL
)
<<
"not handled"
;
return
tvm
::
Expr
();
}
});
return
{
value
};
}
Array
<
Tensor
>
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
Array
<
Tensor
>
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
static
auto
fcompute
=
static
auto
fcompute
=
Op
::
GetAttr
<
FTVMCompute
>
(
"FTVMCompute"
);
Op
::
GetAttr
<
FTVMCompute
>
(
"FTVMCompute"
);
...
...
src/relay/ir/text_printer.cc
View file @
4369b7f6
...
@@ -125,6 +125,8 @@ class TextPrinter :
...
@@ -125,6 +125,8 @@ class TextPrinter :
public
TypeFunctor
<
void
(
const
Type
&
,
std
::
ostream
&
os
)
>
,
// NOLINT(*)
public
TypeFunctor
<
void
(
const
Type
&
,
std
::
ostream
&
os
)
>
,
// NOLINT(*)
public
AttrFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
)
>
{
// NOLINT(*)
public
AttrFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
)
>
{
// NOLINT(*)
public:
public:
explicit
TextPrinter
(
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
:
annotate_
(
annotate
)
{}
/*!
/*!
* \brief Print a node to string.
* \brief Print a node to string.
* \param node.
* \param node.
...
@@ -279,11 +281,11 @@ class TextPrinter :
...
@@ -279,11 +281,11 @@ class TextPrinter :
TextValue
VisitExpr_
(
const
CallNode
*
op
)
final
{
TextValue
VisitExpr_
(
const
CallNode
*
op
)
final
{
// possibly through meta-data
// possibly through meta-data
TextValue
call_op
=
GetValue
(
op
->
op
);
std
::
vector
<
TextValue
>
args
;
std
::
vector
<
TextValue
>
args
;
for
(
Expr
arg
:
op
->
args
)
{
for
(
Expr
arg
:
op
->
args
)
{
args
.
emplace_back
(
GetValue
(
arg
));
args
.
emplace_back
(
GetValue
(
arg
));
}
}
TextValue
call_op
=
GetValue
(
op
->
op
);
TextValue
id
=
this
->
AllocTempVar
();
TextValue
id
=
this
->
AllocTempVar
();
this
->
PrintIndent
();
this
->
PrintIndent
();
...
@@ -532,7 +534,9 @@ class TextPrinter :
...
@@ -532,7 +534,9 @@ class TextPrinter :
*/
*/
void
PrintOptionalInfo
(
const
Expr
&
expr
)
{
void
PrintOptionalInfo
(
const
Expr
&
expr
)
{
// additional information in comment.
// additional information in comment.
if
(
expr
->
checked_type_
.
defined
())
{
if
(
annotate_
!=
nullptr
)
{
stream_
<<
" # "
<<
annotate_
(
expr
);
}
else
if
(
expr
->
checked_type_
.
defined
())
{
stream_
<<
" # ty="
;
stream_
<<
" # ty="
;
this
->
PrintType
(
expr
->
checked_type
(),
stream_
);
this
->
PrintType
(
expr
->
checked_type
(),
stream_
);
}
}
...
@@ -678,7 +682,10 @@ class TextPrinter :
...
@@ -678,7 +682,10 @@ class TextPrinter :
name
=
"%"
+
name
;
name
=
"%"
+
name
;
}
}
TextValue
val
(
GetUniqueName
(
name
));
TextValue
val
(
GetUniqueName
(
name
));
CHECK
(
!
memo_
.
count
(
var
))
<<
"Duplicated variable "
<<
var
;
// still print if ir is malformed, but show the error.
if
(
memo_
.
count
(
var
))
{
memo_
[
var
]
=
TextValue
(
val
.
name
+
"-malformed-ir"
);
}
memo_
[
var
]
=
val
;
memo_
[
var
]
=
val
;
return
val
;
return
val
;
}
}
...
@@ -686,6 +693,8 @@ class TextPrinter :
...
@@ -686,6 +693,8 @@ class TextPrinter :
private
:
private
:
class
AttrPrinter
;
class
AttrPrinter
;
friend
class
AttrPrinter
;
friend
class
AttrPrinter
;
/*! \brief additional comment function */
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate_
;
/*! \brief meta data context */
/*! \brief meta data context */
TextMetaDataContext
meta_
;
TextMetaDataContext
meta_
;
/*! \brief Check whether scope is still valid */
/*! \brief Check whether scope is still valid */
...
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
...
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
os
<<
", "
<<
meta_
.
GetMetaNode
(
attrs
);
os
<<
", "
<<
meta_
.
GetMetaNode
(
attrs
);
}
}
std
::
string
RelayPrint
(
const
NodeRef
&
node
)
{
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
return
TextPrinter
().
Print
(
node
);
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
return
TextPrinter
(
annotate
).
Print
(
node
);
}
}
TVM_REGISTER_API
(
"relay._expr._text_print"
)
TVM_REGISTER_API
(
"relay._expr.RelayPrint"
)
.
set_body_typed
<
std
::
string
(
const
NodeRef
&
)
>
(
RelayPrint
);
.
set_body_typed
<
std
::
string
(
const
NodeRef
&
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
)
>
(
RelayPrint
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/pass/fold_scale_axis.cc
View file @
4369b7f6
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/nn/layout.h"
#include "../op/nn/layout.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
...
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
//----------------------------------------------
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
//----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
)
{
class
ExprRefCounter
:
private
ExprVisitor
{
public
:
std
::
unordered_map
<
const
Node
*
,
size_t
>
Get
(
const
Expr
&
body
)
{
this
->
VisitExpr
(
body
);
return
std
::
move
(
this
->
visit_counter_
);
}
};
return
ExprRefCounter
().
Get
(
body
);
}
class
BackwardPrep
:
private
ExprVisitor
{
class
BackwardPrep
:
private
ExprVisitor
{
public
:
public
:
...
...
src/relay/pass/fuse_ops.cc
View file @
4369b7f6
...
@@ -9,13 +9,686 @@
...
@@ -9,13 +9,686 @@
#include <tvm/ir_operator.h>
#include <tvm/ir_operator.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include "../../common/arena.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
// Simple fuser that only makes each operator function as primitive.
/*
class
SimpleFuser
:
public
ExprMutator
{
Note on Fusing algorithm:
The main challenge of genenral fusor is to handle possible diamond shape branches,
in the following graph, conv2d can be fused to elemwise add.
conv2d
/ | \
/ | \
op op op
\ | /
\ | /
elemwise add
|
However, at the point of conv2d we do not necessarily know that all its future path
will merge at the elemwise add. The new fusor algorithm applies post-dominator analysis.
The immediate post-dominator of a node defined by the closest node where all the future path goes into.
In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows:
- Construct a DAG of dataflow graph for dominator analysis
- Construct a post-dominator tree which gives immediate post dominator of each node.
- Run fusion algorithm with the given post-dominator information.
Note that, because we run analysis on a DAG, we use a single pass post-dominator
tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
The fusion algorithm traverses from each node and checks if it can be fused to its
immediate post dominator. It has to check the following things:
- CheckPath: check all the path between a node and its immediate post-dominator
satiesfies the fuse condition.
- Note that these intermediate node can already be fused with another nodes, the algorithm
will still run correctly.
- CommitFuse: mark all the nodes between source and post-dominator as the same group.
- We use an Union-Find data structure to manage the groups.
*/
using
common
::
LinkNode
;
using
common
::
LinkedList
;
/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
*
* This data structure only captures the dataflow fragement and
* could ignore blocks like let by simply ordering each dataflow block
* and mark the output node as extern_ref;
*/
class
IndexedForwardGraph
{
public
:
struct
Node
;
/*!
* The forward edge in the dataflow graph.
*/
struct
Edge
{
/*! \brief The corresponding node */
Node
*
node
{
nullptr
};
/*! \brief The respective pattern of this op */
OpPatternKind
pattern
{
kOpaque
};
};
/*! \brief A node in the graph. */
struct
Node
{
/*! \brief weak reference to the corresponding edge. */
const
tvm
::
Node
*
ref
{
nullptr
};
/*! \brief The index of the node in topological order. */
size_t
index
{
0
};
/*! \brief Whether this node is referenced by external source */
bool
extern_ref
{
false
};
/*! \brief The general pattern in the node */
OpPatternKind
pattern
{
kOpaque
};
/*! \brief The outputs of the node. */
LinkedList
<
Edge
>
outputs
;
};
/*! \brief The node map that maps node to graph */
std
::
unordered_map
<
const
tvm
::
Node
*
,
Node
*>
node_map
;
/*! \brief All the nodes in post DFS order */
std
::
vector
<
Node
*>
post_dfs_order
;
/*! \brief Dump the graph into string. */
void
DebugDump
()
{
std
::
ostringstream
os
;
for
(
size_t
i
=
0
;
i
<
post_dfs_order
.
size
();
++
i
)
{
Node
*
node
=
post_dfs_order
[
i
];
os
<<
"node["
<<
i
<<
"], "
<<
GetRef
<
NodeRef
>
(
node
->
ref
)
<<
" outputs=["
;
for
(
auto
*
link
=
node
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
os
<<
link
->
value
.
node
->
index
<<
", "
;
}
os
<<
"]
\n
"
;
}
LOG
(
INFO
)
<<
os
.
str
();
}
/*!
* \brief create a indexed forward graph.
* \param arena The arena used for data allocation.
* \param body The body of the expression to create a graph.
*/
static
IndexedForwardGraph
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
);
private
:
class
Creator
;
};
// Creator of post dominator tree of the dataflow
class
IndexedForwardGraph
::
Creator
:
private
ExprVisitor
{
public
:
explicit
Creator
(
common
::
Arena
*
arena
)
:
arena_
(
arena
)
{}
IndexedForwardGraph
Prepare
(
const
Expr
&
body
)
{
this
->
Update
(
body
,
nullptr
,
kOpaque
);
this
->
VisitExpr
(
body
);
return
std
::
move
(
graph_
);
}
private
:
/*! \brief allocator of all the internal node object */
common
::
Arena
*
arena_
;
// The output.
IndexedForwardGraph
graph_
;
// attribute equal comparator
AttrsEqual
attr_equal_
;
// Update the message stored at the node.
void
Update
(
const
Expr
&
node
,
IndexedForwardGraph
::
Node
*
parent
,
OpPatternKind
pattern
)
{
const
tvm
::
Node
*
key
=
node
.
get
();
IndexedForwardGraph
::
Node
*
current
;
auto
it
=
graph_
.
node_map
.
find
(
key
);
if
(
it
!=
graph_
.
node_map
.
end
())
{
current
=
it
->
second
;
}
else
{
current
=
arena_
->
make
<
IndexedForwardGraph
::
Node
>
();
graph_
.
node_map
[
key
]
=
current
;
}
if
(
parent
!=
nullptr
)
{
auto
*
link
=
arena_
->
make
<
LinkNode
<
IndexedForwardGraph
::
Edge
>
>
();
link
->
value
.
node
=
parent
;
link
->
value
.
pattern
=
pattern
;
current
->
outputs
.
Push
(
link
);
}
else
{
current
->
extern_ref
=
true
;
}
}
void
AddNode
(
const
tvm
::
Node
*
key
)
{
auto
it
=
graph_
.
node_map
.
find
(
key
);
CHECK
(
it
!=
graph_
.
node_map
.
end
())
<<
"Cannot find node "
<<
GetRef
<
NodeRef
>
(
key
);
IndexedForwardGraph
::
Node
*
node
=
it
->
second
;
CHECK
(
node
->
ref
==
nullptr
);
node
->
ref
=
key
;
node
->
index
=
graph_
.
post_dfs_order
.
size
();
graph_
.
post_dfs_order
.
push_back
(
node
);
}
// Post order tree
void
VisitExpr_
(
const
FunctionNode
*
op
)
{
for
(
auto
param
:
op
->
params
)
{
this
->
Update
(
param
,
nullptr
,
kOpaque
);
}
this
->
Update
(
op
->
body
,
nullptr
,
kOpaque
);
ExprVisitor
::
VisitExpr_
(
op
);
}
void
VisitExpr_
(
const
ConstantNode
*
op
)
{
this
->
AddNode
(
op
);
Node
*
node
=
graph_
.
node_map
.
at
(
op
);
DataType
dtype
=
TVMType2Type
(
op
->
data
->
dtype
);
// This rule must be consistent with code generator.
bool
is_simple_const
=
(
dtype
==
Int
(
32
)
||
dtype
==
Int
(
64
)
||
dtype
==
Float
(
32
)
||
dtype
==
Float
(
64
)
||
dtype
==
Bool
());
if
(
op
->
is_scalar
()
&&
is_simple_const
)
{
node
->
pattern
=
kElemWise
;
}
else
{
// for now, mark non-scalar constant
// as opaque, we will not choose to fuse it.
node
->
pattern
=
kOpaque
;
}
}
void
VisitExpr_
(
const
CallNode
*
call
)
{
CHECK
(
graph_
.
node_map
.
count
(
call
));
Node
*
node
=
graph_
.
node_map
.
at
(
call
);
static
auto
fpattern
=
Op
::
GetAttr
<
TOpPattern
>
(
"TOpPattern"
);
// setup pattern.
OpPatternKind
op_pattern
=
kOpaque
;
if
(
const
OpNode
*
opnode
=
call
->
op
.
as
<
OpNode
>
())
{
op_pattern
=
static_cast
<
OpPatternKind
>
(
fpattern
[
GetRef
<
Op
>
(
opnode
)]);
}
node
->
pattern
=
op_pattern
;
const
auto
*
rtype
=
call
->
checked_type
().
as
<
TensorTypeNode
>
();
// pass the message back to all the children it references.
for
(
size_t
i
=
0
;
i
<
call
->
args
.
size
();
++
i
)
{
const
auto
*
arg_type
=
call
->
args
[
i
]
->
checked_type
().
as
<
TensorTypeNode
>
();
// specifically check if result type
OpPatternKind
edge_pattern
=
op_pattern
;
if
(
edge_pattern
==
kBroadcast
&&
arg_type
!=
nullptr
&&
rtype
!=
nullptr
&&
attr_equal_
(
rtype
->
shape
,
arg_type
->
shape
))
{
edge_pattern
=
kElemWise
;
}
this
->
Update
(
call
->
args
[
i
],
node
,
edge_pattern
);
}
ExprVisitor
::
VisitExpr_
(
call
);
this
->
AddNode
(
call
);
}
void
VisitExpr_
(
const
TupleNode
*
op
)
{
for
(
const
Expr
&
field
:
op
->
fields
)
{
this
->
Update
(
field
,
nullptr
,
kOpaque
);
}
ExprVisitor
::
VisitExpr_
(
op
);
this
->
AddNode
(
op
);
}
void
VisitExpr_
(
const
TupleGetItemNode
*
op
)
{
CHECK
(
graph_
.
node_map
.
count
(
op
));
Node
*
node
=
graph_
.
node_map
.
at
(
op
);
this
->
Update
(
op
->
tuple
,
node
,
kOpaque
);
ExprVisitor
::
VisitExpr_
(
op
);
this
->
AddNode
(
op
);
}
void
VisitExpr_
(
const
VarNode
*
op
)
{
this
->
AddNode
(
op
);
}
void
VisitExpr_
(
const
LetNode
*
op
)
{
// do not fuse through let.
this
->
Update
(
op
->
var
,
nullptr
,
kOpaque
);
this
->
Update
(
op
->
value
,
nullptr
,
kOpaque
);
this
->
Update
(
op
->
body
,
nullptr
,
kOpaque
);
ExprVisitor
::
VisitExpr_
(
op
);
this
->
AddNode
(
op
);
}
void
VisitExpr_
(
const
IfNode
*
op
)
{
// do not fuse through if.
this
->
Update
(
op
->
cond
,
nullptr
,
kOpaque
);
this
->
Update
(
op
->
true_branch
,
nullptr
,
kOpaque
);
this
->
Update
(
op
->
false_branch
,
nullptr
,
kOpaque
);
ExprVisitor
::
VisitExpr_
(
op
);
this
->
AddNode
(
op
);
}
};
IndexedForwardGraph
IndexedForwardGraph
::
Create
(
common
::
Arena
*
arena
,
const
Expr
&
body
)
{
return
Creator
(
arena
).
Prepare
(
body
);
}
/*!
* \brief Dominator tree that represent domination or
* post domination relation of the node.
*/
class
DominatorTree
{
public
:
public
:
/*!
* \brief A node in the dominator tree.
*/
struct
Node
{
/*! \brief The node in the tree */
IndexedForwardGraph
::
Node
*
gnode
{
nullptr
};
/*! \brief parent of the tree */
Node
*
parent
{
nullptr
};
/*! \brief current depth*/
int
depth
{
0
};
/*! \brief aggregated pattern to parent */
OpPatternKind
pattern
{
kOpaque
};
};
// index -> node.
std
::
vector
<
Node
*>
nodes
;
/*!
* \brief compute a post dominator relation for a given dataflow graph.
* \param arena The arena used for node allocation.
* \param graph The graph to be analyze.
* \return The dominator tree of the graph.
* \note This algorithm makes use of the fact that graph is DAG,
* and runs a single pass algorithm via LCA.
*/
static
DominatorTree
PostDom
(
common
::
Arena
*
arena
,
const
IndexedForwardGraph
&
graph
);
private
:
// Combine pattern together.
static
OpPatternKind
CombinePattern
(
OpPatternKind
lhs
,
OpPatternKind
rhs
)
{
if
(
lhs
>
rhs
)
return
lhs
;
return
rhs
;
}
/*!
* \brief Find the least common acenstor of the two nodes.
* \param lhs The left node.
* \param rhs The right node.
* \param edge_pattern
* The combined edge pattern across all the parents.
* \return The least common acenstor of thw two.
*/
static
Node
*
LeastCommonAcenstor
(
Node
*
lhs
,
Node
*
rhs
,
OpPatternKind
*
edge_pattern
)
{
while
(
lhs
!=
rhs
)
{
if
(
lhs
==
nullptr
)
return
nullptr
;
if
(
rhs
==
nullptr
)
return
nullptr
;
if
(
lhs
->
depth
<
rhs
->
depth
)
{
edge_pattern
[
0
]
=
CombinePattern
(
edge_pattern
[
0
],
rhs
->
pattern
);
rhs
=
rhs
->
parent
;
}
else
if
(
rhs
->
depth
<
lhs
->
depth
)
{
edge_pattern
[
0
]
=
CombinePattern
(
edge_pattern
[
0
],
lhs
->
pattern
);
lhs
=
lhs
->
parent
;
}
else
{
lhs
=
lhs
->
parent
;
rhs
=
rhs
->
parent
;
edge_pattern
[
0
]
=
CombinePattern
(
edge_pattern
[
0
],
lhs
->
pattern
);
edge_pattern
[
0
]
=
CombinePattern
(
edge_pattern
[
0
],
rhs
->
pattern
);
}
}
return
lhs
;
}
};
DominatorTree
DominatorTree
::
PostDom
(
common
::
Arena
*
arena
,
const
IndexedForwardGraph
&
graph
)
{
DominatorTree
tree
;
tree
.
nodes
.
resize
(
graph
.
post_dfs_order
.
size
(),
nullptr
);
// reverse topo order
for
(
size_t
i
=
graph
.
post_dfs_order
.
size
();
i
!=
0
;
--
i
)
{
size_t
index
=
i
-
1
;
Node
*
tnode
=
arena
->
make
<
Node
>
();
auto
*
gnode
=
graph
.
post_dfs_order
[
index
];
tnode
->
gnode
=
gnode
;
if
(
gnode
->
extern_ref
)
{
tnode
->
depth
=
1
;
tnode
->
parent
=
nullptr
;
tnode
->
pattern
=
kOpaque
;
}
else
{
// find the LCAs of all outputs.
OpPatternKind
pattern
=
kElemWise
;
Node
*
parent
=
nullptr
;
for
(
auto
link
=
gnode
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
size_t
oindex
=
link
->
value
.
node
->
index
;
CHECK_LT
(
oindex
,
tree
.
nodes
.
size
());
Node
*
onode
=
tree
.
nodes
[
oindex
];
CHECK
(
onode
!=
nullptr
);
if
(
parent
!=
nullptr
)
{
parent
=
LeastCommonAcenstor
(
parent
,
onode
,
&
pattern
);
}
else
{
parent
=
onode
;
}
pattern
=
CombinePattern
(
pattern
,
link
->
value
.
pattern
);
}
CHECK
(
parent
!=
nullptr
);
tnode
->
depth
=
parent
->
depth
+
1
;
tnode
->
parent
=
parent
;
tnode
->
pattern
=
pattern
;
}
tree
.
nodes
[
index
]
=
tnode
;
}
return
tree
;
}
/*!
* \brief A partition of the graph marked by union find data structure.
*/
class
GraphPartitioner
{
public
:
explicit
GraphPartitioner
(
common
::
Arena
*
arena
,
int
opt_level
)
:
arena_
(
arena
),
opt_level_
(
opt_level
)
{}
/*!
* \brief Group as a union find data structure.
*/
struct
Group
{
/*! \brief The parent in the union find data structure. */
Group
*
parent
{
nullptr
};
/*! \brief The pattern of the group */
OpPatternKind
pattern
;
/*! \brief reference to the root node. */
const
tvm
::
Node
*
root_ref
{
nullptr
};
/*!
* \brief Reference to the master node,
* this field is not nullptr only if pattern is kOutEWiseFusable.
*/
const
tvm
::
Node
*
master_ref
{
nullptr
};
/*!
* \brief Find the group root, perform path compression
* \return The root type node.
*/
Group
*
FindRoot
()
{
// fast path
if
(
this
->
parent
==
nullptr
)
return
this
;
// slow path with path compression.
Group
*
root
=
this
;
while
(
root
->
parent
!=
nullptr
)
{
root
=
root
->
parent
;
}
for
(
Group
*
p
=
this
;
p
!=
root
;)
{
Group
*
parent
=
p
->
parent
;
p
->
parent
=
root
;
p
=
parent
;
}
return
root
;
}
};
/*!
* \brief Partition a graph.
* \return group assignments of each node.
*/
std
::
vector
<
Group
*>
Partition
(
const
IndexedForwardGraph
&
graph
);
private
:
/*! \brief The internal arena for temporary space. */
common
::
Arena
*
arena_
;
/*! \brief optimization level for fuse operation. */
int
opt_level_
;
/*! \brief The internal groups. */
std
::
vector
<
Group
*>
groups_
;
/*! \brief internal field used for deduplication */
std
::
unordered_set
<
IndexedForwardGraph
::
Node
*>
visited_
;
// Internal implelementation of CheckPath
template
<
typename
F
>
bool
CheckPath_
(
IndexedForwardGraph
::
Node
*
src
,
IndexedForwardGraph
::
Node
*
sink
,
F
fcond
)
{
if
(
visited_
.
count
(
src
))
return
true
;
visited_
.
insert
(
src
);
Group
*
gnode
=
groups_
[
src
->
index
];
CHECK
(
gnode
!=
nullptr
);
gnode
=
gnode
->
FindRoot
();
if
(
!
fcond
(
gnode
->
pattern
,
src
==
sink
))
return
false
;
if
(
src
==
sink
)
return
true
;
for
(
auto
link
=
src
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
if
(
!
CheckPath_
(
link
->
value
.
node
,
sink
,
fcond
))
return
false
;
}
return
true
;
}
/*!
* \brief Check all the node between src and sink satisfies fcond.
*
* src and sink are not checked.
*
* \param src The source node.
* \param sink The termination node.
* \param fcond The condition to be checked.
* \tparam F the condition function.
* \note sink must be a post-dominator of src.
*/
template
<
typename
F
>
bool
CheckPath
(
IndexedForwardGraph
::
Node
*
src
,
IndexedForwardGraph
::
Node
*
sink
,
F
fcond
)
{
CHECK
(
!
src
->
extern_ref
);
visited_
.
clear
();
CHECK
(
src
!=
sink
);
for
(
auto
link
=
src
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
if
(
!
CheckPath_
(
link
->
value
.
node
,
sink
,
fcond
))
return
false
;
}
return
true
;
}
// Combine two patterns together.
static
OpPatternKind
CombinePattern
(
OpPatternKind
lhs
,
OpPatternKind
rhs
)
{
if
(
lhs
>
kBroadcast
&&
rhs
>
kBroadcast
)
{
LOG
(
FATAL
)
<<
"Cannot merge two complex group together"
;
}
if
(
lhs
>
rhs
)
return
lhs
;
return
rhs
;
}
/*!
* \brief Merge the child group to the parent.
* \param child The child group.
* \param parent The parent group.
*/
void
MergeFromTo
(
Group
*
child
,
Group
*
parent
)
{
child
=
child
->
FindRoot
();
parent
=
parent
->
FindRoot
();
if
(
child
==
parent
)
return
;
child
->
parent
=
parent
;
// update master ref and pattern
if
(
child
->
master_ref
!=
nullptr
)
{
CHECK
(
parent
->
master_ref
==
nullptr
);
parent
->
master_ref
=
child
->
master_ref
;
parent
->
pattern
=
CombinePattern
(
child
->
pattern
,
parent
->
pattern
);
}
}
// Internal implelementation of CommitFuse
void
CommitFuse_
(
IndexedForwardGraph
::
Node
*
src
,
IndexedForwardGraph
::
Node
*
sink
,
Group
*
target
)
{
if
(
src
==
sink
)
return
;
if
(
visited_
.
count
(
src
))
return
;
visited_
.
insert
(
src
);
Group
*
gnode
=
groups_
[
src
->
index
];
CHECK
(
gnode
!=
nullptr
);
// merge the current group to the parent if possible.
MergeFromTo
(
gnode
,
target
);
for
(
auto
link
=
src
->
outputs
.
head
;
link
!=
nullptr
;
link
=
link
->
next
)
{
CommitFuse_
(
link
->
value
.
node
,
sink
,
target
);;
}
}
/*!
* \brief Commit fusion operation.
* \param src The source node.
* \param sink The termination node.
* \tparam group the group to be committed.
* \note sink must be a post-dominator of src.
*/
void
CommitFuse
(
IndexedForwardGraph
::
Node
*
src
,
IndexedForwardGraph
::
Node
*
sink
)
{
Group
*
target
=
groups_
[
sink
->
index
];
visited_
.
clear
();
CHECK
(
src
!=
sink
);
CommitFuse_
(
src
,
sink
,
target
);
}
// Initialize the groups: one singleton group per graph node,
// indexed in post-DFS order.
void InitGroups(const IndexedForwardGraph& graph) {
  groups_.resize(graph.post_dfs_order.size());
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    const auto* graph_node = graph.post_dfs_order[nid];
    auto* group_node = arena_->make<Group>();
    group_node->pattern = graph_node->pattern;
    group_node->root_ref = graph_node->ref;
    // set master ref if necessary.
    // An OutEWiseFusable op (e.g. conv2d) becomes its group's master.
    if (group_node->pattern == kOutEWiseFusable) {
      group_node->master_ref = graph_node->ref;
    }
    groups_[nid] = group_node;
  }
}
// execute the fusion algorithm.
// One pass over the nodes in post-DFS order, attempting to fuse each
// node into its immediate post-dominator. Runs twice: phase 0 handles
// OutEWiseFusable (conv-like) roots, phase 1 handles injective chains.
void RunFuse(const IndexedForwardGraph& graph,
             const DominatorTree& post_dom_tree,
             int phase) {
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    // the group of current node has been specified already.
    auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    CHECK(group_node != nullptr);
    // no actions for opaque nodes
    if (group_node->pattern == kOpaque) continue;
    // no actions needed if the current node have no dominator
    if (dom_node->parent == nullptr) continue;
    CHECK(!graph_node->extern_ref);
    // Skip if current node is already fused to the parent.
    size_t dom_parent_gindex = dom_node->parent->gnode->index;
    if (groups_[dom_parent_gindex] != nullptr &&
        group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
      continue;
    }
    // Try to fuse current node to its post-dominator.
    if (group_node->pattern == kOutEWiseFusable) {
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      // NOTE: dom_node->parent was already checked non-null above; the
      // repeated null check here is redundant but harmless.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        CHECK(dom_node->parent->gnode != nullptr);
        // The fuse can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          return kind <= kBroadcast;
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    } else if (group_node->pattern <= kBroadcast) {
      // The fuse can be executed if all the intermediate ops are still broadcast.
      auto fcond = [](OpPatternKind kind, bool is_sink) {
        if (!is_sink) {
          // Interior nodes on the path must stay elemwise/broadcast.
          return kind <= kBroadcast;
        } else {
          // The sink itself may additionally be a reduction or an
          // out-elemwise-fusable op.
          return (kind <= kBroadcast ||
                  kind == kCommReduce ||
                  kind == kOutEWiseFusable);
        }
      };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    } else if (group_node->pattern == kInjective) {
      // defer injective fusion to second phase.
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all path are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) {
        return kind <= kInjective;
      };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    } else {
      // do nothing.
      CHECK(group_node->pattern == kCommReduce);
    }
  }
}
};
// Partition the dataflow graph into fusion groups.
// Returns one Group* per node, indexed in post-DFS order; ownership of
// the Group objects stays with the arena.
std::vector<GraphPartitioner::Group*>
GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  // opt_level 0 disables fusion entirely: every node keeps its own group.
  if (opt_level_ == 0) return std::move(groups_);
  // get post dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run fusion algorithm.
  for (int phase = 0; phase < 2; ++phase) {
    this->RunFuse(graph, post_dom_tree, phase);
  }
  return std::move(groups_);
}
class
FuseMutator
:
private
ExprMutator
{
public
:
// Run the transform
// Partitions the expression into fusion groups, records the node->group
// assignment in gmap_, then rewrites the expression via Mutate so each
// group becomes a "Primitive" function call.
Expr Transform(const Expr& body, int fuse_opt_level) {
  // setup the group map.
  auto graph = IndexedForwardGraph::Create(&arena_, body);
  auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(
      graph);
  for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
    CHECK(graph.post_dfs_order[nid]->ref != nullptr);
    gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
  }
  // The following line can be used for debug.
  // this->DebugDumpGroup(body);
  return this->Mutate(body);
}
private
:
/*! \brief Temporary information from each group. */
struct GroupInfo {
 public:
  // The parameters of the function.
  Array<Var> params;
  // The arguments to call the functions.
  // arguments[i] is the outer expression bound to params[i].
  Array<Expr> arguments;
  // Get a new parameter or allocate an old one
  Var GetOrAllocParam(const Expr& expr, const Type& type) {
    // run linear scan as most fused groups contain only a few inputs.
    for (size_t i = 0; i < arguments.size(); ++i) {
      if (expr.same_as(arguments[i])) return params[i];
    }
    // create a new parameter.
    // Parameters are named p0, p1, ... in allocation order.
    std::ostringstream os;
    os << "p" << params.size();
    auto var = VarNode::make(os.str(), type);
    params.push_back(var);
    arguments.push_back(expr);
    return var;
  }
};
/*! \brief Internal arena. */
common::Arena arena_;
/*! \brief The group assignment map. */
std::unordered_map<const Node*, GraphPartitioner::Group*> gmap_;
/*! \brief Internal group information map. */
std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
// Skip primitive function.
// Skip primitive function.
Expr
VisitExpr_
(
const
FunctionNode
*
fn_node
)
{
Expr
VisitExpr_
(
const
FunctionNode
*
fn_node
)
{
NodeRef
res
=
FunctionGetAttr
(
GetRef
<
Function
>
(
fn_node
),
"Primitive"
);
NodeRef
res
=
FunctionGetAttr
(
GetRef
<
Function
>
(
fn_node
),
"Primitive"
);
...
@@ -26,48 +699,74 @@ class SimpleFuser : public ExprMutator {
...
@@ -26,48 +699,74 @@ class SimpleFuser : public ExprMutator {
return
ExprMutator
::
VisitExpr_
(
fn_node
);
return
ExprMutator
::
VisitExpr_
(
fn_node
);
}
}
}
}
// Transform calls.
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
Expr
VisitExpr_
(
const
CallNode
*
call
)
{
if
(
call
->
op
.
as
<
OpNode
>
())
{
if
(
call
->
op
.
as
<
OpNode
>
())
{
// Placeholder fusion algorithm which abstracts
// If it is a primitive op call
// single definitions into functions only.
// then we must have a group assignment for it already.
Array
<
Var
>
params
;
CHECK
(
gmap_
.
count
(
call
));
Array
<
Expr
>
inner_args
;
auto
*
ret_group
=
gmap_
.
at
(
call
)
->
FindRoot
();
Array
<
Expr
>
args
;
Array
<
Expr
>
new_args
;
int
param_number
=
0
;
for
(
auto
arg
:
call
->
args
)
{
for
(
auto
arg
:
call
->
args
)
{
std
::
ostringstream
os
;
os
<<
"p"
<<
param_number
++
;
auto
type
=
arg
->
checked_type
();
auto
type
=
arg
->
checked_type
();
auto
var
=
VarNode
::
make
(
os
.
str
(),
type
);
CHECK
(
gmap_
.
count
(
arg
.
get
()))
params
.
push_back
(
var
);
<<
"cannot find group of "
<<
arg
;
inner_args
.
push_back
(
var
);
auto
*
arg_group
=
gmap_
.
at
(
arg
.
get
())
->
FindRoot
();
args
.
push_back
(
this
->
Mutate
(
arg
));
Expr
new_arg
=
this
->
Mutate
(
arg
);
if
(
ret_group
!=
arg_group
)
{
Var
param
=
ginfo_
[
ret_group
].
GetOrAllocParam
(
new_arg
,
type
);
new_args
.
push_back
(
param
);
}
else
{
new_args
.
push_back
(
new_arg
);
}
}
auto
body
=
CallNode
::
make
(
call
->
op
,
inner_args
,
call
->
attrs
);
}
auto
new_call
=
CallNode
::
make
(
call
->
op
,
new_args
,
call
->
attrs
,
call
->
type_args
);
if
(
ret_group
->
root_ref
==
call
)
{
// This is the root of the group
// create the new call node.
const
GroupInfo
&
ginfo
=
ginfo_
[
ret_group
];
auto
func
=
FunctionNode
::
make
(
auto
func
=
FunctionNode
::
make
(
params
,
body
,
call
->
checked_type
(),
{});
ginfo
.
params
,
new_call
,
call
->
checked_type
(),
{});
func
=
FunctionSetAttr
(
func
,
"Primitive"
,
tvm
::
Integer
(
1
));
func
=
FunctionSetAttr
(
func
,
"Primitive"
,
tvm
::
Integer
(
1
));
return
CallNode
::
make
(
func
,
args
,
Attrs
());
return
CallNode
::
make
(
func
,
ginfo
.
arguments
,
Attrs
());
}
else
{
// This is an intermediate node of a fused function
// simply return the new call.
return
new_call
;
}
}
else
{
}
else
{
return
ExprMutator
::
VisitExpr_
(
call
);
return
ExprMutator
::
VisitExpr_
(
call
);
}
}
}
}
// Debug function, dump the group assignment in text.
// Prints the expression with each sub-expression annotated by the
// address of its (root) fusion group.
void DebugDumpGroup(const Expr& body) {
  std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string {
      auto it = gmap_.find(expr.get());
      // Nodes outside the group map get no annotation.
      if (it == gmap_.end()) return "";
      std::ostringstream os;
      auto* group = it->second->FindRoot();
      os << "group=" << group;
      return os.str();
    });
  LOG(INFO) << "Dump of group info:\n" << text;
}
};
};
Expr
FuseOps
(
const
Expr
&
expr
)
{
Expr
FuseOps
(
const
Expr
&
expr
,
int
fuse_opt_level
)
{
// First we convert all chains of fusable ops into
// First we convert all chains of fusable ops into
// abstracted functions which we mark as primitive
// abstracted functions which we mark as primitive
// then we convert these primitive functions into
// then we convert these primitive functions into
// new operators.
// new operators.
return
SimpleFuser
().
Mutate
(
expr
);
return
FuseMutator
().
Transform
(
expr
,
fuse_opt_level
);
}
}
TVM_REGISTER_API
(
"relay._ir_pass.FuseOps"
)
TVM_REGISTER_API
(
"relay._ir_pass.FuseOps"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
FuseOps
(
args
[
0
]);
*
ret
=
FuseOps
(
args
[
0
]
,
args
[
1
]
);
});
});
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/pass/pass_util.h
0 → 100644
View file @
4369b7f6
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pass_util.h
 * \brief Utilities for writing passes.
*/
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace
tvm
{
namespace
relay
{
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
);
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
src/relay/pass/type_infer.cc
View file @
4369b7f6
...
@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
VarNode
*
new_var
=
(
VarNode
*
new_var
=
(
std
::
is_base_of
<
VarNode
,
T
>::
value
?
std
::
is_base_of
<
VarNode
,
T
>::
value
?
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
FunctionNode
*
new_fn
=
(
std
::
is_base_of
<
FunctionNode
,
T
>::
value
?
static_cast
<
FunctionNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
// check if we need update the new_e
// check if we need update the new_e
bool
need_update_type
=
!
checked_type
.
same_as
(
new_e
->
checked_type_
);
bool
need_update_type
=
!
checked_type
.
same_as
(
new_e
->
checked_type_
);
...
@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
update_missing_type_annotation_
&&
update_missing_type_annotation_
&&
!
new_var
->
type_annotation
.
defined
());
!
new_var
->
type_annotation
.
defined
());
if
(
!
need_update_type
&&
!
need_update_var
&&
!
need_update_call
)
return
new_e
;
bool
need_update_fn
=
(
std
::
is_base_of
<
FunctionNode
,
T
>::
value
&&
update_missing_type_annotation_
&&
!
new_fn
->
ret_type
.
defined
());
if
(
!
need_update_type
&&
!
need_update_var
&&
!
need_update_call
&&
!
need_update_fn
)
{
return
new_e
;
}
if
(
!
new_e
.
node_
.
unique
())
{
if
(
!
new_e
.
node_
.
unique
())
{
// Copy on write optimization
// Copy on write optimization
...
@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
new_var
=
(
new_var
=
(
std
::
is_base_of
<
VarNode
,
T
>::
value
?
std
::
is_base_of
<
VarNode
,
T
>::
value
?
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
new_fn
=
(
std
::
is_base_of
<
FunctionNode
,
T
>::
value
?
static_cast
<
FunctionNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
}
}
// attach the information.
// attach the information.
...
@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
if
(
need_update_var
)
{
if
(
need_update_var
)
{
new_var
->
type_annotation
=
checked_type
;
new_var
->
type_annotation
=
checked_type
;
}
}
if
(
need_update_fn
)
{
auto
*
fn_type
=
checked_type
.
as
<
FuncTypeNode
>
();
CHECK
(
fn_type
!=
nullptr
);
new_fn
->
ret_type
=
fn_type
->
ret_type
;
}
return
new_e
;
return
new_e
;
}
}
...
...
src/relay/pass/type_solver.cc
View file @
4369b7f6
...
@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
...
@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
)
{
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
)
{
if
(
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
if
(
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
// create a new relation node.
// create a new relation node.
RelationNode
*
rnode
=
make
<
RelationNode
>
();
RelationNode
*
rnode
=
arena_
.
make
<
RelationNode
>
();
rnode
->
rel
=
GetRef
<
TypeRelation
>
(
op
);
rnode
->
rel
=
GetRef
<
TypeRelation
>
(
op
);
rel_nodes_
.
push_back
(
rnode
);
rel_nodes_
.
push_back
(
rnode
);
// populate the type information.
// populate the type information.
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
// insert link to the type list
// insert link to the type list
LinkNode
<
TypeNode
*>*
tlink
=
make
<
LinkNode
<
TypeNode
*>
>
();
LinkNode
<
TypeNode
*>*
tlink
=
arena_
.
make
<
LinkNode
<
TypeNode
*>
>
();
TypeNode
*
tnode
=
GetTypeNode
(
op
->
args
[
i
]);
TypeNode
*
tnode
=
GetTypeNode
(
op
->
args
[
i
]);
tlink
->
value
=
tnode
;
tlink
->
value
=
tnode
;
rnode
->
type_list
.
Push
(
tlink
);
rnode
->
type_list
.
Push
(
tlink
);
// insert type->relation node
// insert type->relation node
LinkNode
<
RelationNode
*>*
rlink
=
make
<
LinkNode
<
RelationNode
*>
>
();
LinkNode
<
RelationNode
*>*
rlink
=
arena_
.
make
<
LinkNode
<
RelationNode
*>
>
();
rlink
->
value
=
rnode
;
rlink
->
value
=
rnode
;
tnode
->
rel_list
.
Push
(
rlink
);
tnode
->
rel_list
.
Push
(
rlink
);
}
}
...
...
src/relay/pass/type_solver.h
View file @
4369b7f6
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
using
common
::
LinkNode
;
using
common
::
LinkedList
;
/*!
/*!
* \brief Interface of type solver used in type inference.
* \brief Interface of type solver used in type inference.
*
*
...
@@ -70,41 +72,6 @@ class TypeSolver {
...
@@ -70,41 +72,6 @@ class TypeSolver {
// All the object in the structure is managed by a arena allocator
// All the object in the structure is managed by a arena allocator
// which releases the memory upon destruction of the type solver.
// which releases the memory upon destruction of the type solver.
/*!
/*!
* \brief Link list node
* \tparam T the content data type
*/
template
<
typename
T
>
struct
LinkNode
{
/*! \brief The content value */
T
value
;
/*! \brief pointer to the next location */
LinkNode
<
T
>*
next
{
nullptr
};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
*/
template
<
typename
T
>
struct
LinkedList
{
/*! \brief Head pointer */
LinkNode
<
T
>*
head
{
nullptr
};
/*! \brief Tail pointer */
LinkNode
<
T
>*
tail
{
nullptr
};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void
Push
(
LinkNode
<
T
>*
node
)
{
node
->
next
=
nullptr
;
if
(
this
->
tail
!=
nullptr
)
{
this
->
tail
->
next
=
node
;
this
->
tail
=
node
;
}
else
{
head
=
tail
=
node
;
}
}
};
/*!
* \brief type node struct
* \brief type node struct
* TypeNode implements a union-find data structure(via parent)
* TypeNode implements a union-find data structure(via parent)
* that can unifies the same types to the name resolved_type.
* that can unifies the same types to the name resolved_type.
...
@@ -165,18 +132,6 @@ class TypeSolver {
...
@@ -165,18 +132,6 @@ class TypeSolver {
/*! \brief Reporter that reports back to self */
/*! \brief Reporter that reports back to self */
TypeReporter
reporter_
;
TypeReporter
reporter_
;
/*!
/*!
* \brief Create function to create a new node ptr via arena
* \tparam The type parameter
* \return The node pointer.
*/
template
<
typename
T
>
T
*
make
()
{
T
*
ptr
=
arena_
.
Alloc
<
T
>
();
// call constructor
new
(
ptr
)
T
();
return
ptr
;
}
/*!
* \brief GetTypeNode that is corresponds to t.
* \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one.
* if it do not exist, create a new one.
* \return The type node.
* \return The type node.
...
@@ -186,7 +141,7 @@ class TypeSolver {
...
@@ -186,7 +141,7 @@ class TypeSolver {
if
(
it
!=
tmap_
.
end
())
{
if
(
it
!=
tmap_
.
end
())
{
return
it
->
second
->
FindRoot
();
return
it
->
second
->
FindRoot
();
}
else
{
}
else
{
TypeNode
*
n
=
make
<
TypeNode
>
();
TypeNode
*
n
=
arena_
.
make
<
TypeNode
>
();
type_nodes_
.
push_back
(
n
);
type_nodes_
.
push_back
(
n
);
n
->
resolved_type
=
t
;
n
->
resolved_type
=
t
;
tmap_
[
t
]
=
n
;
tmap_
[
t
]
=
n
;
...
...
src/relay/pass/util.cc
View file @
4369b7f6
...
@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
...
@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
}
}
});
});
/*!
 * \brief Get reference counter of each internal ExprNode in body.
 * \param body The body expression.
 * \return The reference count mapping.
 */
std::unordered_map<const Node*, size_t>
GetExprRefCount(const Expr& body) {
  // ExprVisitor already counts how many times each node is reached
  // during traversal; reuse that counter as the reference count.
  class ExprRefCounter : private ExprVisitor {
   public:
    std::unordered_map<const Node*, size_t>
    Get(const Expr& body) {
      this->VisitExpr(body);
      return std::move(this->visit_counter_);
    }
  };
  return ExprRefCounter().Get(body);
}
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_ir_text_printer.py
View file @
4369b7f6
...
@@ -33,6 +33,7 @@ def test_env():
...
@@ -33,6 +33,7 @@ def test_env():
text
=
env
.
astext
()
text
=
env
.
astext
()
assert
"def @myf"
in
text
assert
"def @myf"
in
text
assert
"
%1
= add(
%0
,
%0
) # ty=float32"
in
text
assert
"
%1
= add(
%0
,
%0
) # ty=float32"
in
text
show
(
env
.
astext
(
annotate
=
lambda
x
:
str
(
x
.
checked_type
.
dtype
)))
show
(
text
)
show
(
text
)
...
...
tests/python/relay/test_pass_fold_scale_axis.py
View file @
4369b7f6
...
@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
...
@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
2
)
check
((
2
,
4
,
10
,
10
),
2
)
...
@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
...
@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
3
),
3
)
check
((
2
,
4
,
10
,
3
),
3
)
...
@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
...
@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
8
)
check
((
2
,
4
,
10
,
10
),
8
)
...
@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
...
@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
8
)
check
((
2
,
4
,
10
,
10
),
8
)
...
...
tests/python/relay/test_pass_fuse_ops.py
View file @
4369b7f6
...
@@ -3,15 +3,103 @@ from tvm import relay
...
@@ -3,15 +3,103 @@ from tvm import relay
def
test_fuse_simple
():
def
test_fuse_simple
():
"""Simple testcase."""
"""Simple testcase."""
def
before
():
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
add
(
x
,
x
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
)
)
z
=
relay
.
exp
(
y
)
z
=
relay
.
exp
(
y
)
return
relay
.
Function
([
x
],
z
)
def
expected
():
x
=
relay
.
var
(
"p"
,
shape
=
(
10
,
20
))
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
z
=
relay
.
exp
(
y
)
f1
=
relay
.
Function
([
x
],
z
)
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
Call
(
f1
,
[
x
])
return
relay
.
Function
([
x
],
y
)
z
=
before
()
z
=
relay
.
ir_pass
.
infer_type
(
z
)
z
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
z
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
z
,
opt_level
=
2
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
zz
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
zz
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
zz
.
astext
()
after
=
relay
.
ir_pass
.
infer_type
(
expected
())
assert
relay
.
ir_pass
.
alpha_equal
(
zz
,
after
)
def test_conv2d_fuse():
    """Test fusion case of conv2d"""
    def before(dshape):
        # conv2d -> elemwise adds -> two parallel conv2d -> add
        x = relay.var("x", shape=dshape)
        y = relay.nn.conv2d(x, relay.var("w1"),
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"),
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"),
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             channels=16)
        # add can only be fused to z1
        z = relay.add(z2, z3)
        return relay.Function(relay.ir_pass.free_vars(z), z)

    def expected(dshape):
        # The fused form: three primitive segments composed by calls.
        # segment 1
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        # segment 2
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w,
                             kernel_size=(3, 3),
                             padding=(1, 1),
                             channels=16)
        f2 = relay.Function([x, w], z2)
        # segment 3
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w,
                             kernel_size=(1, 1),
                             padding=(0, 0),
                             channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        # compose
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.ir_pass.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    z = relay.ir_pass.infer_type(z)
    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
    zz = relay.ir_pass.infer_type(zz)
    after = relay.ir_pass.infer_type(expected(dshape))
    assert relay.ir_pass.alpha_equal(zz, after)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_fuse_simple
()
test_fuse_simple
()
test_conv2d_fuse
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment