Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4369b7f6
Unverified
Commit
4369b7f6
authored
Nov 13, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][PASS] General OpFusion. (#2090)
parent
e470f8ea
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
301 additions
and
87 deletions
+301
-87
include/tvm/relay/expr.h
+10
-0
include/tvm/runtime/packed_func.h
+2
-0
python/tvm/relay/base.py
+7
-2
python/tvm/relay/build_module.py
+3
-1
python/tvm/relay/expr.py
+1
-2
python/tvm/relay/ir_pass.py
+5
-2
src/common/arena.h
+57
-1
src/relay/backend/compile_engine.cc
+23
-0
src/relay/ir/text_printer.cc
+19
-7
src/relay/pass/fold_scale_axis.cc
+1
-17
src/relay/pass/fuse_ops.cc
+0
-0
src/relay/pass/pass_util.h
+27
-0
src/relay/pass/type_infer.cc
+22
-1
src/relay/pass/type_solver.cc
+3
-3
src/relay/pass/type_solver.h
+3
-48
src/relay/pass/util.cc
+18
-0
tests/python/relay/test_ir_text_printer.py
+1
-0
tests/python/relay/test_pass_fold_scale_axis.py
+8
-0
tests/python/relay/test_pass_fuse_ops.py
+91
-3
No files found.
include/tvm/relay/expr.h
View file @
4369b7f6
...
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
...
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
return
node
;
return
node
;
}
}
/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
=
nullptr
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_EXPR_H_
#endif // TVM_RELAY_EXPR_H_
include/tvm/runtime/packed_func.h
View file @
4369b7f6
...
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
...
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
using
TSelf
=
TypedPackedFunc
<
R
(
Args
...)
>
;
using
TSelf
=
TypedPackedFunc
<
R
(
Args
...)
>
;
/*! \brief default constructor */
/*! \brief default constructor */
TypedPackedFunc
()
{}
TypedPackedFunc
()
{}
/*! \brief constructor from null */
TypedPackedFunc
(
std
::
nullptr_t
null
)
{}
// NOLINT(*)
/*!
/*!
* \brief construct by wrap a PackedFunc
* \brief construct by wrap a PackedFunc
*
*
...
...
python/tvm/relay/base.py
View file @
4369b7f6
...
@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
...
@@ -22,15 +22,20 @@ def register_relay_node(type_key=None):
class
RelayNode
(
NodeBase
):
class
RelayNode
(
NodeBase
):
def
astext
(
self
):
"""Base class of all relay node."""
def
astext
(
self
,
annotate
=
None
):
"""Get the text format of the expression.
"""Get the text format of the expression.
Returns
Returns
-------
-------
text : str
text : str
The text format of the expression.
The text format of the expression.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
"""
"""
return
_expr
.
_text_print
(
self
)
return
_expr
.
RelayPrint
(
self
,
annotate
)
@register_relay_node
@register_relay_node
...
...
python/tvm/relay/build_module.py
View file @
4369b7f6
...
@@ -173,11 +173,13 @@ def build(func,
...
@@ -173,11 +173,13 @@ def build(func,
else
:
else
:
tophub_context
=
autotvm
.
util
.
EmptyContext
()
tophub_context
=
autotvm
.
util
.
EmptyContext
()
cfg
=
BuildConfig
.
current
with
tophub_context
:
with
tophub_context
:
func
=
optimize
(
func
)
func
=
optimize
(
func
)
# Fuse ops before running code gen
# Fuse ops before running code gen
func
=
ir_pass
.
infer_type
(
func
)
func
=
ir_pass
.
infer_type
(
func
)
func
=
ir_pass
.
fuse_ops
(
func
)
func
=
ir_pass
.
fuse_ops
(
func
,
cfg
.
opt_level
)
# Graph code generation
# Graph code generation
func
=
ir_pass
.
infer_type
(
func
)
func
=
ir_pass
.
infer_type
(
func
)
graph_gen
=
_graph_gen
.
GraphRuntimeCodegen
(
mod
=
None
,
target
=
target
)
graph_gen
=
_graph_gen
.
GraphRuntimeCodegen
(
mod
=
None
,
target
=
target
)
...
...
python/tvm/relay/expr.py
View file @
4369b7f6
...
@@ -6,7 +6,6 @@ from numbers import Number as _Number
...
@@ -6,7 +6,6 @@ from numbers import Number as _Number
import
numpy
as
_np
import
numpy
as
_np
from
.base
import
RelayNode
,
register_relay_node
from
.base
import
RelayNode
,
register_relay_node
from
.
import
_make
from
.
import
_make
from
.
import
_expr
from
.
import
ty
as
_ty
from
.
import
ty
as
_ty
from
.._ffi
import
base
as
_base
from
.._ffi
import
base
as
_base
from
..
import
nd
as
_nd
from
..
import
nd
as
_nd
...
@@ -477,7 +476,7 @@ class TupleWrapper(object):
...
@@ -477,7 +476,7 @@ class TupleWrapper(object):
text : str
text : str
The text format of the tuple expression.
The text format of the tuple expression.
"""
"""
return
_expr
.
_text_print
(
self
.
tuple_value
)
return
self
.
tuple_value
.
astext
(
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
if
index
>=
len
(
self
):
if
index
>=
len
(
self
):
...
...
python/tvm/relay/ir_pass.py
View file @
4369b7f6
...
@@ -259,7 +259,7 @@ def structural_hash(value):
...
@@ -259,7 +259,7 @@ def structural_hash(value):
raise
TypeError
(
msg
)
raise
TypeError
(
msg
)
def
fuse_ops
(
expr
):
def
fuse_ops
(
expr
,
opt_level
=
1
):
"""Fuse operators in expr together.
"""Fuse operators in expr together.
Parameters
Parameters
...
@@ -267,9 +267,12 @@ def fuse_ops(expr):
...
@@ -267,9 +267,12 @@ def fuse_ops(expr):
expr : tvm.relay.Expr
expr : tvm.relay.Expr
The input expression.
The input expression.
opt_level : int
The level of fuse optimization.
Returns
Returns
-------
-------
transformed_expr : tvm.relay.Expr
transformed_expr : tvm.relay.Expr
Transformed expression, containing fused result.
Transformed expression, containing fused result.
"""
"""
return
_ir_pass
.
FuseOps
(
expr
)
return
_ir_pass
.
FuseOps
(
expr
,
opt_level
)
src/common/arena.h
View file @
4369b7f6
...
@@ -38,11 +38,29 @@ class Arena {
...
@@ -38,11 +38,29 @@ class Arena {
/*!
/*!
* \brief Allocate a space from Arena for type T
* \brief Allocate a space from Arena for type T
* \param T the data type to be allocated
* \param T the data type to be allocated
* \note The space of T is not initialized.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
T
*
Alloc
()
{
T
*
allocate_
()
{
return
static_cast
<
T
*>
(
Alloc
(
sizeof
(
T
),
alignof
(
T
)));
return
static_cast
<
T
*>
(
Alloc
(
sizeof
(
T
),
alignof
(
T
)));
}
}
/*!
* \brief Create a new instance of type T.
* \param args The constructor argument.
* \tparam T the type to be created.
* \tparam Args Arguments to the constructor.
*
* \return The allocated object.
* \note The type T must be simple type, or only contain
* memory allocated from the same arena.
* Otherwise the destructor needs to be called explicitly.
*/
template
<
typename
T
,
typename
...
Args
>
T
*
make
(
Args
&&
...
args
)
{
T
*
ptr
=
allocate_
<
T
>
();
new
(
ptr
)
T
(
std
::
forward
<
Args
>
(
args
)...);
return
ptr
;
}
private
:
private
:
// page size 16 KB
// page size 16 KB
...
@@ -87,6 +105,44 @@ class Arena {
...
@@ -87,6 +105,44 @@ class Arena {
}
}
};
};
/*!
* \brief Link list node
* \tparam T the content data type
*/
template
<
typename
T
>
struct
LinkNode
{
/*! \brief The content value */
T
value
;
/*! \brief pointer to the next location */
LinkNode
<
T
>*
next
{
nullptr
};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
* \note This is a simple data structure that can be used together with the arena.
* \sa LinkNode
*/
template
<
typename
T
>
struct
LinkedList
{
/*! \brief Head pointer */
LinkNode
<
T
>*
head
{
nullptr
};
/*! \brief Tail pointer */
LinkNode
<
T
>*
tail
{
nullptr
};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void
Push
(
LinkNode
<
T
>*
node
)
{
node
->
next
=
nullptr
;
if
(
this
->
tail
!=
nullptr
)
{
this
->
tail
->
next
=
node
;
this
->
tail
=
node
;
}
else
{
head
=
tail
=
node
;
}
}
};
}
// namespace common
}
// namespace common
}
// namespace tvm
}
// namespace tvm
#endif // TVM_COMMON_ARENA_H_
#endif // TVM_COMMON_ARENA_H_
src/relay/backend/compile_engine.cc
View file @
4369b7f6
...
@@ -109,6 +109,29 @@ class ScheduleGetter :
...
@@ -109,6 +109,29 @@ class ScheduleGetter :
return
{};
return
{};
}
}
Array
<
Tensor
>
VisitExpr_
(
const
ConstantNode
*
op
)
final
{
CHECK
(
op
->
is_scalar
());
void
*
data
=
op
->
data
->
data
;
DataType
dtype
=
TVMType2Type
(
op
->
data
->
dtype
);
Tensor
value
=
tvm
::
compute
({},
[
&
](
const
Array
<
tvm
::
Var
>&
)
{
if
(
dtype
==
Int
(
32
))
{
return
make_const
(
dtype
,
static_cast
<
const
int32_t
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Int
(
64
))
{
return
make_const
(
dtype
,
static_cast
<
const
int64_t
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Float
(
32
))
{
return
make_const
(
dtype
,
static_cast
<
const
float
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Float
(
64
))
{
return
make_const
(
dtype
,
static_cast
<
const
double
*>
(
data
)[
0
]);
}
else
if
(
dtype
==
Bool
())
{
return
make_const
(
dtype
,
static_cast
<
const
uint8_t
*>
(
data
)[
0
]);
}
else
{
LOG
(
FATAL
)
<<
"not handled"
;
return
tvm
::
Expr
();
}
});
return
{
value
};
}
Array
<
Tensor
>
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
Array
<
Tensor
>
VisitExpr_
(
const
CallNode
*
call_node
)
final
{
static
auto
fcompute
=
static
auto
fcompute
=
Op
::
GetAttr
<
FTVMCompute
>
(
"FTVMCompute"
);
Op
::
GetAttr
<
FTVMCompute
>
(
"FTVMCompute"
);
...
...
src/relay/ir/text_printer.cc
View file @
4369b7f6
...
@@ -125,6 +125,8 @@ class TextPrinter :
...
@@ -125,6 +125,8 @@ class TextPrinter :
public
TypeFunctor
<
void
(
const
Type
&
,
std
::
ostream
&
os
)
>
,
// NOLINT(*)
public
TypeFunctor
<
void
(
const
Type
&
,
std
::
ostream
&
os
)
>
,
// NOLINT(*)
public
AttrFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
)
>
{
// NOLINT(*)
public
AttrFunctor
<
void
(
const
NodeRef
&
,
std
::
ostream
&
os
)
>
{
// NOLINT(*)
public:
public:
explicit
TextPrinter
(
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
:
annotate_
(
annotate
)
{}
/*!
/*!
* \brief Print a node to string.
* \brief Print a node to string.
* \param node.
* \param node.
...
@@ -279,11 +281,11 @@ class TextPrinter :
...
@@ -279,11 +281,11 @@ class TextPrinter :
TextValue
VisitExpr_
(
const
CallNode
*
op
)
final
{
TextValue
VisitExpr_
(
const
CallNode
*
op
)
final
{
// possibly through meta-data
// possibly through meta-data
TextValue
call_op
=
GetValue
(
op
->
op
);
std
::
vector
<
TextValue
>
args
;
std
::
vector
<
TextValue
>
args
;
for
(
Expr
arg
:
op
->
args
)
{
for
(
Expr
arg
:
op
->
args
)
{
args
.
emplace_back
(
GetValue
(
arg
));
args
.
emplace_back
(
GetValue
(
arg
));
}
}
TextValue
call_op
=
GetValue
(
op
->
op
);
TextValue
id
=
this
->
AllocTempVar
();
TextValue
id
=
this
->
AllocTempVar
();
this
->
PrintIndent
();
this
->
PrintIndent
();
...
@@ -532,7 +534,9 @@ class TextPrinter :
...
@@ -532,7 +534,9 @@ class TextPrinter :
*/
*/
void
PrintOptionalInfo
(
const
Expr
&
expr
)
{
void
PrintOptionalInfo
(
const
Expr
&
expr
)
{
// additional information in comment.
// additional information in comment.
if
(
expr
->
checked_type_
.
defined
())
{
if
(
annotate_
!=
nullptr
)
{
stream_
<<
" # "
<<
annotate_
(
expr
);
}
else
if
(
expr
->
checked_type_
.
defined
())
{
stream_
<<
" # ty="
;
stream_
<<
" # ty="
;
this
->
PrintType
(
expr
->
checked_type
(),
stream_
);
this
->
PrintType
(
expr
->
checked_type
(),
stream_
);
}
}
...
@@ -678,7 +682,10 @@ class TextPrinter :
...
@@ -678,7 +682,10 @@ class TextPrinter :
name
=
"%"
+
name
;
name
=
"%"
+
name
;
}
}
TextValue
val
(
GetUniqueName
(
name
));
TextValue
val
(
GetUniqueName
(
name
));
CHECK
(
!
memo_
.
count
(
var
))
<<
"Duplicated variable "
<<
var
;
// still print if ir is malformed, but show the error.
if
(
memo_
.
count
(
var
))
{
memo_
[
var
]
=
TextValue
(
val
.
name
+
"-malformed-ir"
);
}
memo_
[
var
]
=
val
;
memo_
[
var
]
=
val
;
return
val
;
return
val
;
}
}
...
@@ -686,6 +693,8 @@ class TextPrinter :
...
@@ -686,6 +693,8 @@ class TextPrinter :
private
:
private
:
class
AttrPrinter
;
class
AttrPrinter
;
friend
class
AttrPrinter
;
friend
class
AttrPrinter
;
/*! \brief additional comment function */
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate_
;
/*! \brief meta data context */
/*! \brief meta data context */
TextMetaDataContext
meta_
;
TextMetaDataContext
meta_
;
/*! \brief Check whether scope is still valid */
/*! \brief Check whether scope is still valid */
...
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
...
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
os
<<
", "
<<
meta_
.
GetMetaNode
(
attrs
);
os
<<
", "
<<
meta_
.
GetMetaNode
(
attrs
);
}
}
std
::
string
RelayPrint
(
const
NodeRef
&
node
)
{
std
::
string
RelayPrint
(
const
NodeRef
&
node
,
return
TextPrinter
().
Print
(
node
);
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
annotate
)
{
return
TextPrinter
(
annotate
).
Print
(
node
);
}
}
TVM_REGISTER_API
(
"relay._expr._text_print"
)
TVM_REGISTER_API
(
"relay._expr.RelayPrint"
)
.
set_body_typed
<
std
::
string
(
const
NodeRef
&
)
>
(
RelayPrint
);
.
set_body_typed
<
std
::
string
(
const
NodeRef
&
,
runtime
::
TypedPackedFunc
<
std
::
string
(
Expr
)
>
)
>
(
RelayPrint
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
src/relay/pass/fold_scale_axis.cc
View file @
4369b7f6
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include "pattern_util.h"
#include "pattern_util.h"
#include "pass_util.h"
#include "../op/nn/layout.h"
#include "../op/nn/layout.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
...
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
//----------------------------------------------
//----------------------------------------------
// Generic Visitors for FScaleAxisBackward
// Generic Visitors for FScaleAxisBackward
//----------------------------------------------
//----------------------------------------------
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
)
{
class
ExprRefCounter
:
private
ExprVisitor
{
public
:
std
::
unordered_map
<
const
Node
*
,
size_t
>
Get
(
const
Expr
&
body
)
{
this
->
VisitExpr
(
body
);
return
std
::
move
(
this
->
visit_counter_
);
}
};
return
ExprRefCounter
().
Get
(
body
);
}
class
BackwardPrep
:
private
ExprVisitor
{
class
BackwardPrep
:
private
ExprVisitor
{
public
:
public
:
...
...
src/relay/pass/fuse_ops.cc
View file @
4369b7f6
This diff is collapsed.
Click to expand it.
src/relay/pass/pass_util.h
0 → 100644
View file @
4369b7f6
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/pass_util.h
* \brief Utilities for writing
*/
#ifndef TVM_RELAY_PASS_PASS_UTIL_H_
#define TVM_RELAY_PASS_PASS_UTIL_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/transform.h>
namespace
tvm
{
namespace
relay
{
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
);
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_PASS_PASS_UTIL_H_
src/relay/pass/type_infer.cc
View file @
4369b7f6
...
@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
VarNode
*
new_var
=
(
VarNode
*
new_var
=
(
std
::
is_base_of
<
VarNode
,
T
>::
value
?
std
::
is_base_of
<
VarNode
,
T
>::
value
?
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
FunctionNode
*
new_fn
=
(
std
::
is_base_of
<
FunctionNode
,
T
>::
value
?
static_cast
<
FunctionNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
// check if we need update the new_e
// check if we need update the new_e
bool
need_update_type
=
!
checked_type
.
same_as
(
new_e
->
checked_type_
);
bool
need_update_type
=
!
checked_type
.
same_as
(
new_e
->
checked_type_
);
...
@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
update_missing_type_annotation_
&&
update_missing_type_annotation_
&&
!
new_var
->
type_annotation
.
defined
());
!
new_var
->
type_annotation
.
defined
());
if
(
!
need_update_type
&&
!
need_update_var
&&
!
need_update_call
)
return
new_e
;
bool
need_update_fn
=
(
std
::
is_base_of
<
FunctionNode
,
T
>::
value
&&
update_missing_type_annotation_
&&
!
new_fn
->
ret_type
.
defined
());
if
(
!
need_update_type
&&
!
need_update_var
&&
!
need_update_call
&&
!
need_update_fn
)
{
return
new_e
;
}
if
(
!
new_e
.
node_
.
unique
())
{
if
(
!
new_e
.
node_
.
unique
())
{
// Copy on write optimization
// Copy on write optimization
...
@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
new_var
=
(
new_var
=
(
std
::
is_base_of
<
VarNode
,
T
>::
value
?
std
::
is_base_of
<
VarNode
,
T
>::
value
?
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
static_cast
<
VarNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
new_fn
=
(
std
::
is_base_of
<
FunctionNode
,
T
>::
value
?
static_cast
<
FunctionNode
*>
(
new_e
.
node_
.
get
())
:
nullptr
);
}
}
// attach the information.
// attach the information.
...
@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
if
(
need_update_var
)
{
if
(
need_update_var
)
{
new_var
->
type_annotation
=
checked_type
;
new_var
->
type_annotation
=
checked_type
;
}
}
if
(
need_update_fn
)
{
auto
*
fn_type
=
checked_type
.
as
<
FuncTypeNode
>
();
CHECK
(
fn_type
!=
nullptr
);
new_fn
->
ret_type
=
fn_type
->
ret_type
;
}
return
new_e
;
return
new_e
;
}
}
...
...
src/relay/pass/type_solver.cc
View file @
4369b7f6
...
@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
...
@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
)
{
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
)
{
if
(
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
if
(
auto
*
op
=
constraint
.
as
<
TypeRelationNode
>
())
{
// create a new relation node.
// create a new relation node.
RelationNode
*
rnode
=
make
<
RelationNode
>
();
RelationNode
*
rnode
=
arena_
.
make
<
RelationNode
>
();
rnode
->
rel
=
GetRef
<
TypeRelation
>
(
op
);
rnode
->
rel
=
GetRef
<
TypeRelation
>
(
op
);
rel_nodes_
.
push_back
(
rnode
);
rel_nodes_
.
push_back
(
rnode
);
// populate the type information.
// populate the type information.
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
// insert link to the type list
// insert link to the type list
LinkNode
<
TypeNode
*>*
tlink
=
make
<
LinkNode
<
TypeNode
*>
>
();
LinkNode
<
TypeNode
*>*
tlink
=
arena_
.
make
<
LinkNode
<
TypeNode
*>
>
();
TypeNode
*
tnode
=
GetTypeNode
(
op
->
args
[
i
]);
TypeNode
*
tnode
=
GetTypeNode
(
op
->
args
[
i
]);
tlink
->
value
=
tnode
;
tlink
->
value
=
tnode
;
rnode
->
type_list
.
Push
(
tlink
);
rnode
->
type_list
.
Push
(
tlink
);
// insert type->relation node
// insert type->relation node
LinkNode
<
RelationNode
*>*
rlink
=
make
<
LinkNode
<
RelationNode
*>
>
();
LinkNode
<
RelationNode
*>*
rlink
=
arena_
.
make
<
LinkNode
<
RelationNode
*>
>
();
rlink
->
value
=
rnode
;
rlink
->
value
=
rnode
;
tnode
->
rel_list
.
Push
(
rlink
);
tnode
->
rel_list
.
Push
(
rlink
);
}
}
...
...
src/relay/pass/type_solver.h
View file @
4369b7f6
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
using
common
::
LinkNode
;
using
common
::
LinkedList
;
/*!
/*!
* \brief Interface of type solver used in type inference.
* \brief Interface of type solver used in type inference.
*
*
...
@@ -70,41 +72,6 @@ class TypeSolver {
...
@@ -70,41 +72,6 @@ class TypeSolver {
// All the object in the structure is managed by a arena allocator
// All the object in the structure is managed by a arena allocator
// which releases the memory upon distruction of the type solver.
// which releases the memory upon distruction of the type solver.
/*!
/*!
* \brief Link list node
* \tparam T the content data type
*/
template
<
typename
T
>
struct
LinkNode
{
/*! \brief The content value */
T
value
;
/*! \brief pointer to the next location */
LinkNode
<
T
>*
next
{
nullptr
};
};
/*!
* \brief LinkedList structure
* \tparam T the content data type
*/
template
<
typename
T
>
struct
LinkedList
{
/*! \brief Head pointer */
LinkNode
<
T
>*
head
{
nullptr
};
/*! \brief Tail pointer */
LinkNode
<
T
>*
tail
{
nullptr
};
/*!
* \brief Push a new node to the end of the linked list.
* \param node The node to be pushed.
*/
void
Push
(
LinkNode
<
T
>*
node
)
{
node
->
next
=
nullptr
;
if
(
this
->
tail
!=
nullptr
)
{
this
->
tail
->
next
=
node
;
this
->
tail
=
node
;
}
else
{
head
=
tail
=
node
;
}
}
};
/*!
* \brief type node struct
* \brief type node struct
* TypeNode implements a union-find data structure(via parent)
* TypeNode implements a union-find data structure(via parent)
* that can unifies the same types to the name resolved_type.
* that can unifies the same types to the name resolved_type.
...
@@ -165,18 +132,6 @@ class TypeSolver {
...
@@ -165,18 +132,6 @@ class TypeSolver {
/*! \brief Reporter that reports back to self */
/*! \brief Reporter that reports back to self */
TypeReporter
reporter_
;
TypeReporter
reporter_
;
/*!
/*!
* \brief Create function to create a new node ptr via arena
* \tparam The type parameter
* \return The node pointer.
*/
template
<
typename
T
>
T
*
make
()
{
T
*
ptr
=
arena_
.
Alloc
<
T
>
();
// call constructor
new
(
ptr
)
T
();
return
ptr
;
}
/*!
* \brief GetTypeNode that is corresponds to t.
* \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one.
* if it do not exist, create a new one.
* \return The type node.
* \return The type node.
...
@@ -186,7 +141,7 @@ class TypeSolver {
...
@@ -186,7 +141,7 @@ class TypeSolver {
if
(
it
!=
tmap_
.
end
())
{
if
(
it
!=
tmap_
.
end
())
{
return
it
->
second
->
FindRoot
();
return
it
->
second
->
FindRoot
();
}
else
{
}
else
{
TypeNode
*
n
=
make
<
TypeNode
>
();
TypeNode
*
n
=
arena_
.
make
<
TypeNode
>
();
type_nodes_
.
push_back
(
n
);
type_nodes_
.
push_back
(
n
);
n
->
resolved_type
=
t
;
n
->
resolved_type
=
t
;
tmap_
[
t
]
=
n
;
tmap_
[
t
]
=
n
;
...
...
src/relay/pass/util.cc
View file @
4369b7f6
...
@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
...
@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
}
}
});
});
/*!
* \brief Get reference counter of each internal ExprNode in body.
* \param body The body expression.
* \return The reference count mapping.
*/
std
::
unordered_map
<
const
Node
*
,
size_t
>
GetExprRefCount
(
const
Expr
&
body
)
{
class
ExprRefCounter
:
private
ExprVisitor
{
public
:
std
::
unordered_map
<
const
Node
*
,
size_t
>
Get
(
const
Expr
&
body
)
{
this
->
VisitExpr
(
body
);
return
std
::
move
(
this
->
visit_counter_
);
}
};
return
ExprRefCounter
().
Get
(
body
);
}
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_ir_text_printer.py
View file @
4369b7f6
...
@@ -33,6 +33,7 @@ def test_env():
...
@@ -33,6 +33,7 @@ def test_env():
text
=
env
.
astext
()
text
=
env
.
astext
()
assert
"def @myf"
in
text
assert
"def @myf"
in
text
assert
"
%1
= add(
%0
,
%0
) # ty=float32"
in
text
assert
"
%1
= add(
%0
,
%0
) # ty=float32"
in
text
show
(
env
.
astext
(
annotate
=
lambda
x
:
str
(
x
.
checked_type
.
dtype
)))
show
(
text
)
show
(
text
)
...
...
tests/python/relay/test_pass_fold_scale_axis.py
View file @
4369b7f6
...
@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
...
@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
y1_folded
=
relay
.
ir_pass
.
forward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
2
)
check
((
2
,
4
,
10
,
10
),
2
)
...
@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
...
@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
type_dict
=
{
x
.
name_hint
:
x
.
checked_type
for
x
in
y1
.
params
}
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
in_bias
,
in_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
3
),
3
)
check
((
2
,
4
,
10
,
3
),
3
)
...
@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
...
@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
8
)
check
((
2
,
4
,
10
,
10
),
8
)
...
@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
...
@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
weight
=
relay
.
var
(
"weight"
,
type_dict
[
"weight"
])
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_folded
=
relay
.
ir_pass
.
backward_fold_scale_axis
(
y1
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_expected
=
expected
(
x
,
weight
,
out_bias
,
out_scale
,
channels
)
y1_folded
=
relay
.
ir_pass
.
infer_type
(
y1_folded
)
y1_expected
=
relay
.
ir_pass
.
infer_type
(
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
assert
relay
.
ir_pass
.
alpha_equal
(
y1_folded
,
y1_expected
)
check
((
2
,
4
,
10
,
10
),
8
)
check
((
2
,
4
,
10
,
10
),
8
)
...
...
tests/python/relay/test_pass_fuse_ops.py
View file @
4369b7f6
...
@@ -3,15 +3,103 @@ from tvm import relay
...
@@ -3,15 +3,103 @@ from tvm import relay
def
test_fuse_simple
():
def
test_fuse_simple
():
"""Simple testcase."""
"""Simple testcase."""
def
before
():
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
add
(
x
,
x
)
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
)
)
z
=
relay
.
exp
(
y
)
z
=
relay
.
exp
(
y
)
return
relay
.
Function
([
x
],
z
)
def
expected
():
x
=
relay
.
var
(
"p"
,
shape
=
(
10
,
20
))
y
=
relay
.
add
(
x
,
relay
.
const
(
1
,
"float32"
))
z
=
relay
.
exp
(
y
)
f1
=
relay
.
Function
([
x
],
z
)
x
=
relay
.
var
(
"x"
,
shape
=
(
10
,
20
))
y
=
relay
.
Call
(
f1
,
[
x
])
return
relay
.
Function
([
x
],
y
)
z
=
before
()
z
=
relay
.
ir_pass
.
infer_type
(
z
)
z
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
z
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
z
,
opt_level
=
2
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
zz
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
zz
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
zz
.
astext
()
after
=
relay
.
ir_pass
.
infer_type
(
expected
())
assert
relay
.
ir_pass
.
alpha_equal
(
zz
,
after
)
def
test_conv2d_fuse
():
"""Test fusion case of conv2d"""
def
before
(
dshape
):
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
nn
.
conv2d
(
x
,
relay
.
var
(
"w1"
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
16
)
# this is the next dominator.
y1
=
relay
.
add
(
relay
.
const
(
1
,
"float32"
),
y
)
y
=
relay
.
add
(
y
,
y1
)
# second path
z2
=
relay
.
nn
.
conv2d
(
y
,
relay
.
var
(
"w2"
),
kernel_size
=
(
1
,
1
),
padding
=
(
0
,
0
),
channels
=
16
)
z3
=
relay
.
nn
.
conv2d
(
y
,
relay
.
var
(
"w3"
),
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
16
)
# add can only be fused to z1
z
=
relay
.
add
(
z2
,
z3
)
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
z
),
z
)
def
expected
(
dshape
):
# segment 1
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
y
=
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
16
)
y1
=
relay
.
add
(
relay
.
const
(
1
,
"float32"
),
y
)
y
=
relay
.
add
(
y
,
y1
)
f1
=
relay
.
Function
([
x
,
w
],
y
)
# segment 2
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
z2
=
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
3
,
3
),
padding
=
(
1
,
1
),
channels
=
16
)
f2
=
relay
.
Function
([
x
,
w
],
z2
)
# segment 3
x
=
relay
.
var
(
"p0"
,
shape
=
dshape
)
w
=
relay
.
var
(
"p1"
)
offset
=
relay
.
var
(
"p2"
,
shape
=
dshape
)
z3
=
relay
.
nn
.
conv2d
(
x
,
w
,
kernel_size
=
(
1
,
1
),
padding
=
(
0
,
0
),
channels
=
16
)
z3
=
relay
.
add
(
z3
,
offset
)
f3
=
relay
.
Function
([
x
,
w
,
offset
],
z3
)
# compose
x
=
relay
.
var
(
"x"
,
shape
=
dshape
)
y
=
relay
.
Call
(
f1
,
[
x
,
relay
.
var
(
"w1"
)])
z2
=
relay
.
Call
(
f2
,
[
y
,
relay
.
var
(
"w3"
)])
z3
=
relay
.
Call
(
f3
,
[
y
,
relay
.
var
(
"w2"
),
z2
])
z
=
z3
return
relay
.
Function
(
relay
.
ir_pass
.
free_vars
(
z
),
z
)
dshape
=
(
1
,
16
,
64
,
64
)
z
=
before
(
dshape
)
z
=
relay
.
ir_pass
.
infer_type
(
z
)
zz
=
relay
.
ir_pass
.
fuse_ops
(
z
,
opt_level
=
2
)
zz
=
relay
.
ir_pass
.
infer_type
(
zz
)
after
=
relay
.
ir_pass
.
infer_type
(
expected
(
dshape
))
assert
relay
.
ir_pass
.
alpha_equal
(
zz
,
after
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_fuse_simple
()
test_fuse_simple
()
test_conv2d_fuse
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment