Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1071e242
Commit
1071e242
authored
Dec 26, 2019
by
Animesh Jain
Committed by
Yizhi Liu
Dec 26, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][AlterLayout] Broadcast with scalar shape (#4577)
parent
73dda6be
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
0 deletions
+92
-0
src/relay/pass/pattern_util.h
+22
-0
src/relay/pass/transform_layout.h
+5
-0
tests/python/relay/test_pass_alter_op_layout.py
+65
-0
No files found.
src/relay/pass/pattern_util.h
View file @
1071e242
...
@@ -200,6 +200,28 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
...
@@ -200,6 +200,28 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
}
}
/*!
/*!
* \brief Is single value tensor (scalar).
* \param expr The expr.
* \return True if single value tensor.
*/
inline
bool
IsScalar
(
const
Expr
&
expr
)
{
if
(
auto
tensor_type
=
expr
->
checked_type
().
as
<
TensorTypeNode
>
())
{
for
(
auto
dim_index_expr
:
tensor_type
->
shape
)
{
if
(
auto
dim_index
=
dim_index_expr
.
as
<
IntImm
>
())
{
if
(
dim_index
->
value
!=
1
)
{
return
false
;
}
}
else
{
return
false
;
}
}
}
else
{
return
false
;
}
return
true
;
}
/*!
* \brief Create a Constant with a scalar
* \brief Create a Constant with a scalar
*
*
* \param dtype The data type.
* \param dtype The data type.
...
...
src/relay/pass/transform_layout.h
View file @
1071e242
...
@@ -119,6 +119,11 @@ class TransformMemorizer : public NodeRef {
...
@@ -119,6 +119,11 @@ class TransformMemorizer : public NodeRef {
Expr
input_expr
=
raw
;
Expr
input_expr
=
raw
;
Layout
new_src_layout
=
src_layout
;
Layout
new_src_layout
=
src_layout
;
if
(
src_layout
.
ndim_primal
()
<
dst_layout
.
ndim_primal
())
{
if
(
src_layout
.
ndim_primal
()
<
dst_layout
.
ndim_primal
())
{
// If scalar, then no need of layout transformation as scalar can be broadcasted easily even
// if the other operand has a transformed layout.
if
(
IsScalar
(
input_expr
))
{
return
raw
;
}
int
num_new_axis
=
dst_layout
.
ndim_primal
()
-
src_layout
.
ndim_primal
();
int
num_new_axis
=
dst_layout
.
ndim_primal
()
-
src_layout
.
ndim_primal
();
new_src_layout
=
src_layout
.
ExpandPrimal
(
dst_layout
);
new_src_layout
=
src_layout
.
ExpandPrimal
(
dst_layout
);
input_expr
=
MakeExpandDims
(
input_expr
,
0
,
num_new_axis
);
input_expr
=
MakeExpandDims
(
input_expr
,
0
,
num_new_axis
);
...
...
tests/python/relay/test_pass_alter_op_layout.py
View file @
1071e242
...
@@ -318,6 +318,70 @@ def test_alter_layout_broadcast_op():
...
@@ -318,6 +318,70 @@ def test_alter_layout_broadcast_op():
assert
analysis
.
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
assert
analysis
.
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_broadcast_scalar_op
():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
"""
def
before
():
x
=
relay
.
var
(
"x"
,
shape
=
(
1
,
500
,
500
,
64
))
kernel
=
relay
.
var
(
'kernel'
,
shape
=
(
3
,
3
,
64
,
64
),
dtype
=
'float32'
)
bias
=
relay
.
var
(
"bias"
,
shape
=
(
64
,))
multiplier1
=
relay
.
var
(
'multiplier1'
,
shape
=
(
1
,
),
dtype
=
'float32'
)
multiplier2
=
relay
.
var
(
'multiplier2'
,
shape
=
(
1
,
1
),
dtype
=
'float32'
)
y
=
relay
.
nn
.
conv2d
(
x
,
kernel
,
data_layout
=
'NHWC'
,
kernel_layout
=
"HWIO"
,
kernel_size
=
(
3
,
3
))
y
=
relay
.
add
(
bias
,
y
)
y
=
relay
.
nn
.
relu
(
y
)
y
=
relay
.
multiply
(
multiplier1
,
y
)
y
=
relay
.
multiply
(
y
,
multiplier2
)
y
=
relay
.
Function
(
analysis
.
free_vars
(
y
),
y
)
return
y
def
alter_conv2d
(
attrs
,
inputs
,
tinfos
):
data
,
weight
=
inputs
new_attrs
=
dict
(
attrs
)
new_attrs
[
'data_layout'
]
=
'NCHW16c'
return
relay
.
nn
.
conv2d
(
data
,
weight
,
**
new_attrs
)
def
expected
():
x
=
relay
.
var
(
"x"
,
shape
=
(
1
,
500
,
500
,
64
))
kernel
=
relay
.
var
(
'kernel'
,
shape
=
(
3
,
3
,
64
,
64
),
dtype
=
'float32'
)
bias
=
relay
.
var
(
"bias"
,
shape
=
(
64
,))
multiplier1
=
relay
.
var
(
'multiplier1'
,
shape
=
(
1
,
),
dtype
=
'float32'
)
multiplier2
=
relay
.
var
(
'multiplier2'
,
shape
=
(
1
,
1
),
dtype
=
'float32'
)
b
=
relay
.
expand_dims
(
bias
,
axis
=
0
,
num_newaxis
=
3
)
b
=
relay
.
layout_transform
(
b
,
"NHWC"
,
"NCHW16c"
)
y
=
relay
.
layout_transform
(
x
,
"NHWC"
,
"NCHW16c"
)
y
=
relay
.
nn
.
conv2d
(
y
,
kernel
,
data_layout
=
'NCHW16c'
,
kernel_layout
=
"HWIO"
,
kernel_size
=
(
3
,
3
))
y
=
relay
.
add
(
b
,
y
)
y
=
relay
.
nn
.
relu
(
y
)
y
=
relay
.
multiply
(
multiplier1
,
y
)
y
=
relay
.
multiply
(
y
,
multiplier2
)
y
=
relay
.
layout_transform
(
y
,
"NCHW16c"
,
"NHWC"
)
y
=
relay
.
Function
(
analysis
.
free_vars
(
y
),
y
)
return
y
with
TempOpAttr
(
"nn.conv2d"
,
"FTVMAlterOpLayout"
,
alter_conv2d
):
a
=
before
()
a
=
run_opt_pass
(
a
,
[
transform
.
CanonicalizeOps
(),
transform
.
AlterOpLayout
()])
b
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
assert
analysis
.
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_scalar
():
def
test_alter_layout_scalar
():
"""Test alternating the layout of a conv2d.
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
The layout of broadcast operators and the weight should be changed accordingly.
...
@@ -980,6 +1044,7 @@ if __name__ == "__main__":
...
@@ -980,6 +1044,7 @@ if __name__ == "__main__":
test_alter_layout_dual_path
()
test_alter_layout_dual_path
()
test_alter_layout_resnet
()
test_alter_layout_resnet
()
test_alter_layout_broadcast_op
()
test_alter_layout_broadcast_op
()
test_alter_layout_broadcast_scalar_op
()
test_alter_layout_scalar
()
test_alter_layout_scalar
()
test_alter_layout_concatenate
()
test_alter_layout_concatenate
()
test_alter_layout_nchw_upsamping_op
()
test_alter_layout_nchw_upsamping_op
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment