Commit 73dda6be
[Relay] Convert Layout Pass. (#4335)
Authored Dec 26, 2019 by Animesh Jain; committed by Yizhi Liu, Dec 26, 2019
Parent: 641024f5

Showing 20 changed files with 1094 additions and 323 deletions (+1094 / -323)
Changed files:
  include/tvm/relay/op_attr_types.h                   +17   -0
  include/tvm/relay/transform.h                       +20   -0
  python/tvm/relay/op/nn/_nn.py                       +41   -0
  python/tvm/relay/op/op.py                           +17   -0
  python/tvm/relay/transform.py                       +28   -0
  src/relay/op/annotation/annotation.cc               +1    -1
  src/relay/op/device_copy.cc                         +1    -1
  src/relay/op/memory/memory.cc                       +1    -1
  src/relay/op/nn/bitserial.cc                        +1    -1
  src/relay/op/nn/convolution.cc                      +1    -1
  src/relay/op/nn/nn.cc                               +1    -1
  src/relay/op/nn/pooling.cc                          +1    -1
  src/relay/op/nn/sparse.cc                           +1    -1
  src/relay/op/op_common.h                            +1    -1
  src/relay/op/tensor/transform.cc                    +1    -1
  src/relay/pass/alter_op_layout.cc                   +53   -308
  src/relay/pass/convert_layout.cc                    +146  -0
  src/relay/pass/infer_layout_util.h                  +40   -5
  src/relay/pass/transform_layout.h                   +362  -0
  tests/python/relay/test_pass_convert_op_layout.py   +360  -0
include/tvm/relay/op_attr_types.h

@@ -29,6 +29,7 @@
 #include <tvm/build_module.h>
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr.h>
+#include <string>

 namespace tvm {
 namespace relay {

@@ -133,6 +134,22 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
       const Array<Tensor>& tinfos)>;

+/*!
+ * \brief Convert the layout of operators or replace the
+ *  operator with other expressions. This function will be invoked
+ *  in ConvertLayout pass.
+ * \param attrs The attribute of the original node.
+ * \param inputs The input symbols of the original node.
+ * \param tinfos An array of placeholders, use for getting the inferred shape
+ *               and dtype of the inputs.
+ * \param desired_layout The desired layout.
+ * \return new_expr The modified expression.
+ */
+using FTVMConvertOpLayout = runtime::TypedPackedFunc<
+    Expr(const Attrs& attrs,
+         const Array<Expr>& args,
+         const Array<Tensor>& tinfos,
+         const std::string& desired_layout)>;
+
 /*!
  * \brief Legalizes an expression with another expression. This function will be
  *  invoked in Legalize pass. It is a target-dependent pass.
  * \param attrs The attribute of the original node.
include/tvm/relay/transform.h

@@ -533,6 +533,26 @@ TVM_DLL Pass CanonicalizeOps();
 TVM_DLL Pass AlterOpLayout();

+/*!
+ * \brief Given a dest layout, this pass transforms the expr such that most of the ops' input data
+ * layout is changed to the dest layout. In the ideal situation, there are only 2 layout transforms,
+ * one at the start and one at the end.
+ *
+ * This pass is not a part of relay.build and is expected to be called between the framework-relay
+ * parser and the relay.build call. This is very helpful for hardware backends that support/prefer
+ * only one type of data layout.
+ *
+ * RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
+ *
+ * This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new
+ * layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout
+ * using the InferCorrectLayout infrastructure.
+ *
+ * \param desired_layout The desired layout.
+ * \return The pass.
+ */
+TVM_DLL Pass ConvertLayout(const std::string& desired_layout);
+
 /*!
  * \brief Legalizes an expr with another expression.
  * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function.
  *  One can collect and isolate similar type of legalize transformations using this param. For
python/tvm/relay/op/nn/_nn.py

@@ -251,6 +251,47 @@ def legalize_conv2d(attrs, inputs, types):
     """
     return topi.nn.conv2d_legalize(attrs, inputs, types)

+@reg.register_convert_op_layout("nn.conv2d")
+def convert_conv2d(attrs, inputs, tinfos, desired_layout):
+    """Convert Layout pass registration for conv2d op.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layout : str
+        The desired layout
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    from tvm import relay
+    data_layout = attrs['data_layout']
+    kernel_layout = attrs['kernel_layout']
+    data, weight = inputs
+    assert desired_layout == 'NCHW', \
+        "Currently only transformation to NCHW layout is supported."
+    if desired_layout == 'NCHW':
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = desired_layout
+        new_attrs['kernel_layout'] = 'OIHW'
+        if data_layout == 'NHWC' and kernel_layout == 'HWIO':
+            # Convert (NHWC, HWIO) to (NCHW, OIHW)
+            return relay.nn.conv2d(data, weight, **new_attrs)
+        if data_layout == 'NHWC' and kernel_layout == 'HWOI':
+            # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
+            return relay.nn.conv2d(data, weight, **new_attrs)
+    return None
+
 reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
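To see this conv2d converter in action, the pass can be run over a single NHWC/HWIO convolution, much like the tests added at the end of this commit. The snippet below is a minimal sketch only (variable names are arbitrary) and mirrors the run_opt_pass helper from the new test file:

    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout='NHWC', kernel_layout='HWIO')
    func = relay.Function([x, w], y)

    # ConvertLayout('NCHW') invokes convert_conv2d above, which rewrites the conv2d to
    # NCHW/OIHW; the pass inserts the needed layout_transform ops around it.
    mod = relay.Module.from_expr(func)
    with transform.PassContext(opt_level=3):
        mod = transform.Sequential([transform.ConvertLayout('NCHW')])(mod)
    print(mod["main"])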
python/tvm/relay/op/op.py

@@ -196,6 +196,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
     return register(op_name, "FTVMAlterOpLayout", alter_layout, level)

+def register_convert_op_layout(op_name, convert_layout=None, level=10):
+    """Register convert op layout function for an op
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+        The function for changing the layout or replacing the operator
+
+    level : int
+        The priority level
+    """
+    return register(op_name, "FTVMConvertOpLayout", convert_layout, level)
+
 def register_legalize(op_name, legal_op=None, level=10):
     """Register legal transformation function for an op
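The new helper mirrors register_alter_op_layout: it attaches the given Python function to the op's "FTVMConvertOpLayout" attribute, which the ConvertLayout pass then queries per call node. A hedged sketch of registering a converter for another operator follows; the choice of nn.conv2d_transpose and the trivial body are illustrative only, since in this commit only nn.conv2d ships with a real converter:

    from tvm.relay.op import op as reg

    @reg.register_convert_op_layout("nn.conv2d_transpose")
    def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layout):
        # Returning None tells the pass to leave the call unchanged and let the
        # InferCorrectLayout infrastructure adapt the op to its input layouts instead.
        return None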
python/tvm/relay/transform.py

@@ -460,6 +460,34 @@ def AlterOpLayout():
     return _transform.AlterOpLayout()

+def ConvertLayout(desired_layout):
+    """Given a dest layout, this pass transforms the expr such that most of the ops' input data
+    layout is changed to the dest layout. In the ideal situation, there are only 2 layout transforms,
+    one at the start and one at the end.
+
+    This pass is not a part of relay.build and is expected to be called between the framework-relay
+    parser and the relay.build call. This is very helpful for hardware backends that support/prefer
+    only one type of data layout.
+
+    RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
+
+    This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define
+    new layouts for conv2d ops for now. Most of the other operators try to adapt to their input
+    layout using the InferCorrectLayout infrastructure.
+
+    Parameters
+    ----------
+    desired_layout : str
+        The desired layout for the transformed expr.
+
+    Returns
+    -------
+    pass: FunctionPass
+        The pass.
+    """
+    return _transform.ConvertLayout(desired_layout)
+
 def Legalize(legalize_map_attr_name="FTVMLegalize"):
     """Legalizes an expression with another expression.

     This pass can be used to replace an expr with another expr for target
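For reference, a hedged sketch of where ConvertLayout is meant to sit in a pipeline: between the frontend parser and relay.build. The frontend call here is only a placeholder (any relay.frontend importer yielding an NHWC module would do), and the explicitly listed InferType/SimplifyInference/CanonicalizeOps passes match the required passes that the C++ registration later in this commit declares:

    from tvm import relay
    from tvm.relay import transform

    # mod, params come from any frontend importer, e.g.
    # mod, params = relay.frontend.from_tflite(...)   # placeholder, not part of this commit

    def convert_to_nchw(mod):
        seq = transform.Sequential([transform.InferType(),
                                    transform.SimplifyInference(),
                                    transform.CanonicalizeOps(),
                                    transform.ConvertLayout('NCHW')])
        with transform.PassContext(opt_level=3):
            return seq(mod)

    # mod = convert_to_nchw(mod)
    # graph, lib, params = relay.build(mod, target='llvm', params=params)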
src/relay/op/annotation/annotation.cc

@@ -30,7 +30,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <topi/elemwise.h>
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"
 #include "../type_relations.h"

 namespace tvm {
src/relay/op/device_copy.cc

@@ -33,7 +33,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include "type_relations.h"
-#include "../pass/alter_op_layout.h"
+#include "../pass/infer_layout_util.h"

 namespace tvm {
 namespace relay {
src/relay/op/memory/memory.cc

@@ -29,7 +29,7 @@
 #include <tvm/relay/attrs/memory.h>
 #include "../op_common.h"
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"
 #include "../type_relations.h"

 namespace tvm {
src/relay/op/nn/bitserial.cc

@@ -26,7 +26,7 @@
 #include <tvm/relay/attrs/bitserial.h>
 #include <tvm/relay/op.h>
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"

 namespace tvm {
 namespace relay {
src/relay/op/nn/convolution.cc

@@ -27,7 +27,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <vector>
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"
 #include "../op_common.h"
 #include "convolution.h"
src/relay/op/nn/nn.cc

@@ -33,7 +33,7 @@
 #include <vector>
 #include <string>
 #include "../type_relations.h"
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"
 #include "../op_common.h"
 #include "nn.h"
src/relay/op/nn/pooling.cc

@@ -27,7 +27,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <topi/nn/pooling.h>
 #include <vector>
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"

 namespace tvm {
 namespace relay {
src/relay/op/nn/sparse.cc

@@ -27,7 +27,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <vector>
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"

 namespace tvm {
 namespace relay {
src/relay/op/op_common.h

@@ -32,7 +32,7 @@
 #include <string>
 #include <unordered_map>
 #include "type_relations.h"
-#include "../pass/alter_op_layout.h"
+#include "../pass/infer_layout_util.h"

 namespace tvm {
 namespace relay {
src/relay/op/tensor/transform.cc

@@ -36,7 +36,7 @@
 #include <vector>
 #include "../op_common.h"
 #include "../../../arithmetic/compute_expr.h"
-#include "../../pass/alter_op_layout.h"
+#include "../../pass/infer_layout_util.h"
 #include "../../pass/pattern_util.h"
 #include "transform.h"
src/relay/pass/alter_op_layout.cc

@@ -36,7 +36,7 @@
 #include <utility>
 #include <unordered_map>
-#include "alter_op_layout.h"
+#include "transform_layout.h"
 #include "pattern_util.h"

 namespace tvm {

@@ -44,328 +44,73 @@ namespace relay {
 namespace alter_op_layout {

(The generic layout-transform machinery that used to live in this file is removed here and moved into the
new headers shown later in this commit: TransformLayout, TransformMemorizerNode, TransformMemorizer and
LayoutAlternatedExprNode go to transform_layout.h; CallInfer becomes InferCorrectLayouts in
infer_layout_util.h; CallAlter becomes the CallWithNewLayouts hook; and AlterOpLayoutRewrite becomes the
LayoutRewriter template. The pass keeps only the alter-specific pieces:)

+/*!
+ * \brief Container to instantiate a Node for alter op layouts.
+ */
+class AlterTransformMemorizerNode : public TransformMemorizerNode {
+ public:
+  static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode";
+};
+
+/*!
+ * \brief Container that provides the transformation function for alter layout.
+ */
+class AlterTransformMemorizer : public TransformMemorizer {
+ public:
+  AlterTransformMemorizer() {}
+  explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
+
+  AlterTransformMemorizerNode* operator->() {
+    return static_cast<AlterTransformMemorizerNode*>(get_mutable());
+  }
+
+  /*!
+   * \brief Defines the call transformation for the AlterOpLayout pass. The new layouts are defined
+   * for different targets using a packed func.
+   * \param ref_call The original call.
+   * \param new_args The traversed/recursed args to the call.
+   * \return The new Call after calling the packed func.
+   */
+  Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
+    static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
+    Op op = Downcast<Op>(ref_call->op);
+
+    Expr new_e;
+    bool modified = false;
+    if (falter_layout.count(op)) {
+      tvm::Array<tvm::Tensor> tinfos;
+      for (auto expr : ref_call->args) {
+        auto ttype = expr->type_as<TensorTypeNode>();
+        tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
+      }
+      Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos);
+      if (altered_value.defined()) {
+        new_e = altered_value;
+        modified = true;
+      }
+    }
+    if (!modified) {
+      new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
+    }
+
+    const CallNode* new_call = new_e.as<CallNode>();
+    CHECK(new_call) << "Can only replace the original operator with another call node";
+    return GetRef<Call>(new_call);
+  }
+
+  using ContainerType = AlterTransformMemorizerNode;
+};
+
+/*!
+ * Limitations:
+ * 1. The altered op should have the same number of arguments as the previous one.
+ * 2. Do not support nested tuple arguments.
+ */
+Expr AlterOpLayout(const Expr& expr) {
+  AlterTransformMemorizer alterMemorizer(make_node<AlterTransformMemorizerNode>());
+  auto fcontext = [&](const Call& call) -> NodeRef { return alterMemorizer; };
+
+  return ForwardRewrite(expr, LayoutRewriter<AlterTransformMemorizer>, fcontext);
+}

 }  // namespace alter_op_layout
src/relay/pass/convert_layout.cc (new file, mode 100644)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file convert_op_layout.cc
* \brief Alternate the layouts of operators or replace primitive operators with
other expressions. This pass can be used for computing convolution in
custom layouts or other general weight pre-transformation.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/operation.h>
#include <tuple>
#include <vector>
#include <functional>
#include <string>
#include <utility>
#include <unordered_map>
#include "transform_layout.h"
#include "pattern_util.h"
namespace tvm {
namespace relay {

namespace convert_op_layout {

/*!
 * \brief Container for the transformations for ConvertLayout.
 */
class ConvertTransformMemorizerNode : public TransformMemorizerNode {
 public:
  /*!
   * \brief Initializes the desired_layout.
   * \param desired_layout The desired layout.
   */
  explicit ConvertTransformMemorizerNode(const std::string& desired_layout)
      : desired_layout_(desired_layout) {}

  /*! \brief The desired layout for the Convert Layout pass */
  std::string desired_layout_;
};

/*!
 * \brief Container that provides the transformation function for convert layout.
 */
class ConvertTransformMemorizer : public TransformMemorizer {
 public:
  ConvertTransformMemorizer() {}
  explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

  ConvertTransformMemorizerNode* operator->() {
    return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
  }

  /*!
   * \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the
   * desired layout as specified by the user.
   * \param ref_call The original call.
   * \param new_args The traversed/recursed args to the call.
   * \return The new Call after calling the packed func.
   */
  Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
    static auto fconvert_layout = Op::GetAttr<FTVMConvertOpLayout>("FTVMConvertOpLayout");
    Op op = Downcast<Op>(ref_call->op);

    Expr new_e;
    bool modified = false;
    if (fconvert_layout.count(op)) {
      tvm::Array<tvm::Tensor> tinfos;
      for (auto expr : ref_call->args) {
        auto ttype = expr->type_as<TensorTypeNode>();
        tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype));
      }
      Expr altered_value =
          fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_);
      if (altered_value.defined()) {
        new_e = altered_value;
        modified = true;
      }
    }
    if (!modified) {
      new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
    }

    const CallNode* new_call = new_e.as<CallNode>();
    CHECK(new_call) << "Can only replace the original operator with another call node";
    return GetRef<Call>(new_call);
  }

  using ContainerType = ConvertTransformMemorizerNode;
};

/*!
 * Limitations:
 * 1. The altered op should have the same number of arguments as the previous one.
 * 2. Do not support nested tuple arguments.
 */
Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) {
  ConvertTransformMemorizer transformMemorizer(
      make_node<ConvertTransformMemorizerNode>(desired_layout));
  auto fcontext = [&](const Call& call) -> NodeRef { return transformMemorizer; };

  return ForwardRewrite(expr, LayoutRewriter<ConvertTransformMemorizer>, fcontext);
}

}  // namespace convert_op_layout

namespace transform {

Pass ConvertLayout(const std::string& desired_layout) {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function f, Module m, PassContext pc) {
        return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout));
      };
  return CreateFunctionPass(
      pass_func, 3, "ConvertLayout",
      {ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"),
       ir::StringImm::make("CanonicalizeOps")});
}

TVM_REGISTER_API("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);

}  // namespace transform

}  // namespace relay
}  // namespace tvm
src/relay/pass/alter_op_layout.h → src/relay/pass/infer_layout_util.h (renamed)

@@ -18,18 +18,20 @@
  */

 /*!
- * \file alter_op_layout.h
- * \brief Alternate the layouts of operators or replace primitive operators with
+ * \file infer_layout_util.h
+ * \brief Utility functions to alter the layouts of operators or replace primitive operators with
          other expressions. This pass can be used for computing convolution in
          custom layouts or other general weight pre-transformation.
  */

-#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
-#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
+#ifndef TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_
+#define TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_

 #include <tvm/data_layout.h>
 #include <tvm/relay/expr.h>
 #include <string>
 #include <tuple>
 #include "pattern_util.h"

 namespace tvm {
 namespace relay {

@@ -193,7 +195,40 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
   }
 }

+/*!
+ * Call registered FInferCorrectLayout of an op.
+ * Parameters are the same as the parameters for FInferCorrectLayout
+ * Returns inferred_input_layout, inferred_output_layout, success
+ */
+static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts(
+    const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>>& old_in_shapes) {
+  static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
+  if (!call->op.as<OpNode>()) {
+    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
+  }
+
+  Op op = Downcast<Op>(call->op);
+  if (finfer_layout.count(op)) {
+    Array<Array<Layout>> inferred_layouts;
+    inferred_layouts =
+        finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_shapes);
+    CHECK_EQ(inferred_layouts.size(), 2)
+        << "FInferCorrectLayout should return an array with size of 2";
+    for (auto x : inferred_layouts) {
+      for (auto y : x) {
+        if (!y.defined()) {  // inference fails
+          return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
+        }
+      }
+    }
+    return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true);
+  } else {
+    return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
+  }
+}

 }  // namespace relay
 }  // namespace tvm

-#endif  // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_
+#endif  // TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_
src/relay/pass/transform_layout.h (new file, mode 100644)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file transform_layout.h
* \brief Common infrastructure for transforming the layouts. This is used for AlterOpLayout and
* ConvertLayout pass. */
#ifndef TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_
#define TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_
#include <tvm/data_layout.h>
#include <tvm/relay/expr.h>
#include <string>
#include <unordered_map>
#include <tuple>
#include <vector>
#include "pattern_util.h"
#include "infer_layout_util.h"
namespace tvm {
namespace relay {

/*!
 * \brief Memorizes layout transformations to reuse.
 */
class TransformMemorizerNode : public Node {
 public:
  /*! \brief The key for the memorizer map is (Expr, src_layout, dst_layout). */
  using TransformKey = std::tuple<const Node*, std::string, std::string>;

  struct key_hash : public std::function<std::size_t(TransformKey)> {
    std::size_t operator()(const TransformKey& k) const {
      return dmlc::HashCombine<std::string>(
          dmlc::HashCombine<std::string>(std::hash<const Node*>()(std::get<0>(k)), std::get<1>(k)),
          (std::get<2>(k)));
    }
  };

  /*! \brief The memorizer map. */
  std::unordered_map<TransformKey, Expr, key_hash> memo;

  static constexpr const char* _type_key = "relay.alter_op_layout.TransformMemorizerNode";
  TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node);
};

/*!
 * \brief Container that transforms the layouts and memorizes them.
 */
class TransformMemorizer : public NodeRef {
 public:
  TransformMemorizer() {}
  explicit TransformMemorizer(ObjectPtr<Object> n) : NodeRef(n) {}

  TransformMemorizerNode* operator->() {
    return static_cast<TransformMemorizerNode*>(get_mutable());
  }

  /*
   * \brief Memorizes and transforms the layout.
   * \param expr The initial expr.
   * \param src_layout The source layout.
   * \param dst_layout The dest layout.
   * \return The new expr with the dst layout.
   */
  Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
    if (src_layout.Equals(dst_layout)) {
      return raw;
    }

    std::tuple<const Node*, std::string, std::string> key =
        std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
    auto& memo = operator->()->memo;

    auto iter = memo.find(key);
    if (iter != memo.end()) {
      return iter->second;
    } else {
      Expr transform = TransformHelper(raw, src_layout, dst_layout);
      memo[key] = transform;
      return transform;
    }
  }

  /*
   * \brief Helper to transform the layouts.
   * \param expr The initial expr.
   * \param src_layout The source layout.
   * \param dst_layout The dest layout.
   * \return The new expr with the dst layout.
   * \note It performs following 2 operations
   *       1) If src_layout ndim is smaller then dst_layout, expand_dim is inserted to match the dim
   *          size. For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC.
   *       2) Call layout transform with new src layout.
   */
  Expr TransformHelper(Expr raw, Layout src_layout, Layout dst_layout) {
    if (src_layout.Equals(dst_layout)) {
      return raw;
    }

    // 1) Check if the shape lengths are different. If yes, expand dims.
    Expr input_expr = raw;
    Layout new_src_layout = src_layout;
    if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
      int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
      new_src_layout = src_layout.ExpandPrimal(dst_layout);
      input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
      if (new_src_layout.Equals(dst_layout)) {
        return input_expr;
      }
    }

    // 2) Insert layout transform on the transformed src.
    CHECK(new_src_layout.defined() && dst_layout.defined())
        << "Cannot insert layout transform because there are undefined layouts";
    CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
        << "Cannot insert layout transform because there are inconvertible layouts: "
        << new_src_layout << " v.s. " << dst_layout;
    return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
  }

  /*!
   * \brief Defines the call transformation for derived passes. The new layouts are defined
   * for different targets using a packed func.
   * \param ref_call The original call.
   * \param new_args The traversed/recursed args to the call.
   * \return The new Call after calling the packed func.
   */
  virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) = 0;

  using ContainerType = TransformMemorizerNode;
};

/*
 * \brief TempExprNode during layout transform. Instance of this expr will be Realized to normal
 *        expr ultimately.
 * \tparam TransformMemorizerT The derived TransformMemorizer type.
 */
template <class TransformMemorizerT>
class LayoutAlternatedExprNode : public TempExprNode {
 public:
  Expr value;
  Layout old_layout;
  Layout new_layout;
  TransformMemorizerT memorizer;

  Expr Realize() const final {
    // NOTE: use a copy to discard the "const" qualifier
    TransformMemorizerT tmp_memorizer = memorizer;
    // fallback to old layout
    return tmp_memorizer.Transform(value, new_layout, old_layout);
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("value", &value);
    v->Visit("old_layout", &old_layout);
    v->Visit("new_layout", &new_layout);
  }

  static constexpr const char* _type_key = "relay.alter_op_layout.LayoutAlternatedExprNode";
  TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode);
};

/*!
 * \brief Container for the layout alternated expr.
 * \tparam TransformMemorizerT The derived TransformMemorizer type.
 */
template <class TransformMemorizerT>
class LayoutAlternatedExpr : public NodeRef {
 public:
  LayoutAlternatedExpr() {}
  explicit LayoutAlternatedExpr(ObjectPtr<Object> n) : NodeRef(n) {}

  LayoutAlternatedExprNode<TransformMemorizerT>* operator->() {
    return static_cast<LayoutAlternatedExprNode<TransformMemorizerT>*>(get_mutable());
  }

  using ContainerType = LayoutAlternatedExprNode<TransformMemorizerT>;
};

/*
 * \brief Used with ForwardRewrite to transform the expr. The input args are same as
 *        FForwardRewrite.
 * \param ref_call The reference old call type to be rewritten.
 *                 We can make use of the op and type information.
 * \param new_args The new arguments (some of them could be TempExpr).
 * \param ctx Optional context information about ref_call.
 * \tparam TransformMemorizerT The derived TransformMemorizer type.
 * \return The rewritten result call, can also return nullptr,
 *         which indicates the rewriter should use the default fallback
 *         rule that realizes all its input and composes the call.
 *
 * \note The ctx can be used to provide extra information during transformation. The ctx is
 *       templated to reuse across AlterOpLayout and ConvertLayout pass. The steps are
 *       - Extract the original layouts.
 *       - Use ctx transformation to get a Call with new layouts - CallWithNewLayouts.
 *       - Extract the new layouts from the returned Call.
 *       - Transform the original call to reuse the new layouts using TransformMemorizer.
 */
template <class TransformMemorizerT>
Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) {
  std::vector<LayoutAlternatedExpr<TransformMemorizerT>> inputs;
  std::vector<Expr> normal_new_args;
  Array<Array<IndexExpr>> input_shapes;

  // NOTE: discard the "const" qualifier
  // TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);
  // TransformMemorizerT* ctx_transformer =
  // static_cast<TransformMemorizerT*>(memorizer.operator->());
  TransformMemorizerT memorizer = Downcast<TransformMemorizerT>(ctx);

  // fill incomplete state and flatten tuple
  auto push_back_one_arg = [&inputs, memorizer](Expr arg) {
    // We always expect LayoutAlternatedExpr<TransformMemorizerT>.
    // This is used to convert the normal Expr to LayoutAlternatedExpr<TransformMemorizerT>.
    if (const LayoutAlternatedExprNode<TransformMemorizerT>* inp =
            arg.as<LayoutAlternatedExprNode<TransformMemorizerT>>()) {
      inputs.push_back(GetRef<LayoutAlternatedExpr<TransformMemorizerT>>(inp));
      return inp->value;
    } else {
      auto inode = make_node<LayoutAlternatedExprNode<TransformMemorizerT>>();
      inode->value = arg;
      inode->memorizer = memorizer;
      inputs.push_back(LayoutAlternatedExpr<TransformMemorizerT>(inode));
      return arg;
    }
  };

  for (auto new_arg : new_args) {
    // NOTE: do not support nested tuple
    if (new_arg->IsInstance<TupleNode>()) {
      Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
      std::vector<Expr> fields;
      for (auto x : tuple_new_arg->fields) {
        Expr tmp = push_back_one_arg(x);
        fields.push_back(tmp);
      }
      normal_new_args.push_back(TupleNode::make(fields));
    } else {
      Expr tmp = push_back_one_arg(new_arg);
      normal_new_args.push_back(tmp);
    }
  }

  // old_in, new_in = state[inputs]
  Array<Layout> old_in, old_out, new_in, new_out, new_in2;
  for (auto inp : inputs) {
    old_in.push_back(inp->old_layout);
    new_in.push_back(inp->new_layout);
  }

  for (auto arg : ref_call->args) {
    if (arg->IsInstance<TupleNode>()) {  // flatten tuple
      Tuple tuple_arg = Downcast<Tuple>(arg);
      for (auto x : tuple_arg->fields) {
        input_shapes.push_back(x->type_as<TensorTypeNode>()->shape);
      }
    } else {
      input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape);
    }
  }

  // old_in, old_out = op.infer(old_in)
  bool success = false;
  std::tie(old_in, old_out, success) =
      InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, input_shapes);
  if (!success) {
    return Expr(nullptr);
  }
  CHECK_EQ(old_in.size(), new_in.size());

  // if new_in == 'undef': new_in = old_in
  for (size_t i = 0; i < new_in.size(); ++i) {
    if (!new_in[i].defined()) {
      new_in.Set(i, old_in[i]);
    }
  }

  // new_op = alter(op)
  Call new_call = memorizer.CallWithNewLayouts(ref_call, normal_new_args);

  // new_in2, new_out = op.infer(new_in)
  if (new_call->op->IsInstance<OpNode>()) {
    success = false;
    std::tie(new_in2, new_out, success) =
        InferCorrectLayouts(new_call, new_in, old_in, input_shapes);
    if (!success) {
      return Expr(nullptr);
    }
  } else {
    return Expr(nullptr);
  }

  CHECK_EQ(new_out.size(), old_out.size())
      << "The number of output nodes should keep the same during alter_op_layout";
  CHECK_EQ(new_in.size(), new_in2.size())
      << "The number of input nodes should keep the same during alter_op_layout";

  // if (new_in != new_in2): insert transform (new_in -> new_in2)
  Array<Expr> transformed_args;
  size_t pt = 0;
  for (auto arg : new_call->args) {
    if (arg->IsInstance<TupleNode>()) {  // unflatten tuple
      Tuple tuple_arg = Downcast<Tuple>(arg);
      std::vector<Expr> transformed_tuple_arg;
      for (auto arg_item : tuple_arg->fields) {
        transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
        pt++;
      }
      transformed_args.push_back(TupleNode::make(transformed_tuple_arg));
    } else {
      transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
      pt++;
    }
  }
  CHECK_EQ(pt, inputs.size());

  // state[node] = (old_out, new_out)
  // (handle tuple output)
  if (ref_call->checked_type()->IsInstance<TupleTypeNode>()) {
    Expr tuple_output = CallNode::make(new_call->op, transformed_args, new_call->attrs);
    Array<Expr> fields;
    for (size_t i = 0; i < new_out.size(); ++i) {
      auto rnode = make_node<LayoutAlternatedExprNode<TransformMemorizerT>>();
      rnode->value = TupleGetItemNode::make(tuple_output, i);
      rnode->old_layout = old_out[i];
      rnode->new_layout = new_out[i];
      rnode->memorizer = memorizer;
      fields.push_back(Expr(rnode));
    }
    return TupleNode::make(fields);
  } else {
    auto rnode = make_node<LayoutAlternatedExprNode<TransformMemorizerT>>();
    CHECK_EQ(new_out.size(), 1);
    rnode->value = CallNode::make(new_call->op, transformed_args, new_call->attrs);
    rnode->old_layout = old_out[0];
    rnode->new_layout = new_out[0];
    rnode->memorizer = memorizer;
    return Expr(rnode);
  }
}

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_
tests/python/relay/test_pass_convert_op_layout.py (new file, mode 100644)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test alter op layout pass"""
import tvm
from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay import transform, analysis


def run_opt_pass(expr, passes):
    passes = passes if isinstance(passes, list) else [passes]
    mod = relay.Module.from_expr(expr)
    seq = transform.Sequential(passes)
    with transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_no_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def expected():
        return before()

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var('weight', shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var('weight', shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.layout_transform(y, 'NCHW', 'NHWC')
        y = relay.Function(relay.analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_bias_pool_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.nn.bias_add(y, bias, axis=3)
        # a useless tuple, which will be eliminated
        y = relay.Tuple([y])[0]
        y = relay.nn.relu(y)
        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout='NHWC')
        y = relay.cast(y, 'int32')
        y = relay.nn.batch_flatten(y)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        bias = relay.var("bias", shape=(64,))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        bias = relay.expand_dims(bias, axis=0, num_newaxis=3)
        bias = relay.layout_transform(bias, 'NHWC', 'NCHW')
        y = relay.add(y, bias)
        # a useless tuple, which will be eliminated
        y = relay.Tuple([y])[0]
        y = relay.nn.relu(y)
        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
        y = relay.cast(y, 'int32')
        y = relay.layout_transform(y, 'NCHW', 'NHWC')
        y = relay.nn.batch_flatten(y)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_conv_concat_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight1, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y1 = relay.nn.conv2d(y, weight2, channels=64, kernel_size=(3, 3), padding=(1, 1),
                             data_layout='NHWC', kernel_layout='HWIO')
        ret = relay.concatenate([y, y1], axis=3)
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
        y = relay.layout_transform(x, "NHWC", "NCHW")
        y = relay.nn.conv2d(y, weight1, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y1 = relay.nn.conv2d(y, weight2, channels=64, kernel_size=(3, 3), padding=(1, 1))
        ret = relay.concatenate([y, y1], axis=1)
        ret = relay.layout_transform(ret, "NCHW", "NHWC")
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_dual_path_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1),
                             data_layout='NHWC', kernel_layout='HWIO')
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.batch_flatten(y)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
        y = relay.layout_transform(x, "NHWC", "NCHW")
        y = relay.nn.conv2d(y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y1 = relay.nn.conv2d(y, weight2, channels=32, kernel_size=(3, 3), padding=(1, 1))
        y1 = relay.nn.relu(y1)
        y1 = relay.layout_transform(y1, "NCHW", "NHWC")
        y2 = relay.layout_transform(y, "NCHW", "NHWC")
        y2 = relay.nn.batch_flatten(y2)
        ret = relay.Tuple([y1, y2])
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_bn_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        gamma = relay.var("gamma")
        beta = relay.var("beta")
        mean = relay.var("mean")
        variance = relay.var("variance")
        y, _, _ = relay.nn.batch_norm(y, gamma, beta, mean, variance, axis=3)
        return relay.Function(analysis.free_vars(y), y)

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))

    # Check that there is only 1 NHWC to NCHW transform.
    has_lt = list()
    find_op = lambda x: \
        has_lt.append(isinstance(x, tvm.relay.expr.Call) and x.op.name == "layout_transform" \
                      and x.attrs.src_layout == 'NCHW' and x.attrs.dst_layout == 'NHWC')
    relay.analysis.post_order_visit(a, find_op)
    has_lt = list(filter(lambda x: x, has_lt))
    assert len(has_lt) == 1


def test_resnet_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1),
                             data_layout='NHWC', kernel_layout='HWIO')
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y, layout='NHWC')
        return relay.Function(analysis.free_vars(y), y)

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
        x = relay.layout_transform(x, "NHWC", "NCHW")
        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1))
        y2 = relay.nn.relu(y2)
        y = y + y2
        y = relay.nn.global_max_pool2d(y)
        y = relay.layout_transform(y, "NCHW", "NHWC")
        return relay.Function(analysis.free_vars(y), y)

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_scalar_convert_layout():
    def before():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight = relay.var("weight", shape=(3, 3, 64, 64))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
                            data_layout='NHWC', kernel_layout='HWIO')
        y = relay.add(y, relay.const(1, "float32"))
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def expected():
        x = relay.var("x", shape=(1, 56, 56, 64))
        w = relay.var("weight", shape=(3, 3, 64, 64))
        x = relay.layout_transform(x, 'NHWC', 'NCHW')
        w = relay.layout_transform(w, 'HWIO', 'OIHW')
        y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.add(y, relay.const(1.0, "float32"))
        y = relay.layout_transform(y, "NCHW", "NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    a = before()
    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
    test_no_convert_layout()
    test_conv_convert_layout()
    test_conv_bias_pool_convert_layout()
    test_conv_concat_convert_layout()
    test_dual_path_convert_layout()
    test_bn_convert_layout()
    test_resnet_convert_layout()
    test_scalar_convert_layout()