Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9d002e8e
Commit
9d002e8e
authored
Apr 28, 2019
by
Yizhi Liu
Committed by
Tianqi Chen
Apr 28, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Lang] Fix undef BijectiveLayout and add scalar layout support (#3105)
parent
73f87ae0
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
36 additions
and
17 deletions
+36
-17
include/tvm/data_layout.h
+4
-2
src/lang/data_layout.cc
+8
-1
src/relay/pass/alter_op_layout.h
+8
-4
tests/python/relay/test_pass_alter_op_layout.py
+10
-10
tests/python/unittest/test_lang_data_layout.py
+6
-0
No files found.
include/tvm/data_layout.h
View file @
9d002e8e
...
@@ -94,12 +94,13 @@ class Layout;
...
@@ -94,12 +94,13 @@ class Layout;
// Internal node container Buffer
// Internal node container Buffer
class
LayoutNode
:
public
Node
{
class
LayoutNode
:
public
Node
{
public
:
public
:
/*! \brief string representation of layout */
/*! \brief string representation of layout
, "" for scalar.
*/
std
::
string
name
;
std
::
string
name
;
/*! \brief specify each axis of the layout,
/*! \brief specify each axis of the layout,
* in which the variable name is the name of the axis.
* in which the variable name is the name of the axis.
* The IterVar's extent indicates the size of the axis,
* The IterVar's extent indicates the size of the axis,
* it is a variable for a primal axis, but a constant for a subordinate axis.
* it is a variable for a primal axis, but a constant for a subordinate axis.
* Empty for scalar's layout.
*/
*/
Array
<
IterVar
>
axes
;
Array
<
IterVar
>
axes
;
...
@@ -122,6 +123,7 @@ class LayoutNode : public Node {
...
@@ -122,6 +123,7 @@ class LayoutNode : public Node {
* For example, NCHW16c can describe a 5-D tensor of
* For example, NCHW16c can describe a 5-D tensor of
* [batch_size, channel, height, width, channel_block].
* [batch_size, channel, height, width, channel_block].
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
* Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
* Layout for scalar is defined, while both its name and axes have size 0.
*/
*/
class
Layout
:
public
NodeRef
{
class
Layout
:
public
NodeRef
{
public
:
public
:
...
@@ -175,7 +177,7 @@ class Layout : public NodeRef {
...
@@ -175,7 +177,7 @@ class Layout : public NodeRef {
* that starts at dimension \p pos and spans \p len dimensions
* that starts at dimension \p pos and spans \p len dimensions
* (or until the end of the layout, whichever comes first).
* (or until the end of the layout, whichever comes first).
* \param pos The start position.
* \param pos The start position.
* \param len The length of the sub-layout.
* \param len The length of the sub-layout.
if 0, return layout of scalar
* \return A newly constructed Layout object.
* \return A newly constructed Layout object.
*/
*/
Layout
SubLayout
(
size_t
pos
,
size_t
len
)
const
;
Layout
SubLayout
(
size_t
pos
,
size_t
len
)
const
;
...
...
src/lang/data_layout.cc
View file @
9d002e8e
...
@@ -88,12 +88,14 @@ Layout::Layout(const Array<IterVar>& axes) {
...
@@ -88,12 +88,14 @@ Layout::Layout(const Array<IterVar>& axes) {
}
}
Layout
::
Layout
(
const
std
::
string
&
name
)
{
// NOLINT(*)
Layout
::
Layout
(
const
std
::
string
&
name
)
{
// NOLINT(*)
if
(
name
.
empty
()
||
name
==
"__undef__"
)
return
;
if
(
name
==
"__undef__"
)
return
;
node_
=
make_node
<
LayoutNode
>
();
node_
=
make_node
<
LayoutNode
>
();
LayoutNode
*
node
=
operator
->
();
LayoutNode
*
node
=
operator
->
();
node
->
name
=
name
;
node
->
name
=
name
;
if
(
name
.
empty
())
return
;
// scalar
// parse layout string
// parse layout string
int32_t
factor
=
0
;
int32_t
factor
=
0
;
for
(
char
c
:
name
)
{
for
(
char
c
:
name
)
{
...
@@ -146,6 +148,7 @@ Layout LayoutNode::make(const std::string& layout) {
...
@@ -146,6 +148,7 @@ Layout LayoutNode::make(const std::string& layout) {
Layout
Layout
::
SubLayout
(
size_t
pos
,
size_t
len
)
const
{
Layout
Layout
::
SubLayout
(
size_t
pos
,
size_t
len
)
const
{
if
(
!
defined
()
||
pos
>
ndim
())
return
Layout
::
Undef
();
if
(
!
defined
()
||
pos
>
ndim
())
return
Layout
::
Undef
();
if
(
len
==
0
)
return
Layout
(
Array
<
IterVar
>
());
if
(
pos
+
len
>
ndim
())
len
=
ndim
()
-
pos
;
if
(
pos
+
len
>
ndim
())
len
=
ndim
()
-
pos
;
Array
<
IterVar
>
new_layout
;
Array
<
IterVar
>
new_layout
;
const
auto
axes
=
operator
->
()
->
axes
;
const
auto
axes
=
operator
->
()
->
axes
;
...
@@ -195,6 +198,10 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
...
@@ -195,6 +198,10 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const {
inline
bool
GetStoreRule
(
Array
<
Expr
>*
rule
,
inline
bool
GetStoreRule
(
Array
<
Expr
>*
rule
,
const
Layout
&
src_layout
,
const
Layout
&
src_layout
,
const
Layout
&
dst_layout
)
{
const
Layout
&
dst_layout
)
{
if
(
!
src_layout
.
defined
()
||
src_layout
.
name
().
empty
()
||
!
dst_layout
.
defined
()
||
dst_layout
.
name
().
empty
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
dst_layout
.
ndim
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dst_layout
.
ndim
();
++
i
)
{
const
auto
&
store_axis
=
dst_layout
[
i
];
const
auto
&
store_axis
=
dst_layout
[
i
];
const
IterVar
&
store_axis_impl
=
dst_layout
->
axes
[
i
];
const
IterVar
&
store_axis_impl
=
dst_layout
->
axes
[
i
];
...
...
src/relay/pass/alter_op_layout.h
View file @
9d002e8e
...
@@ -97,15 +97,19 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
...
@@ -97,15 +97,19 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs,
if
(
old_in_shapes
[
defined_idx
].
size
()
>=
old_in_shapes
[
undef_idx
].
size
())
{
if
(
old_in_shapes
[
defined_idx
].
size
()
>=
old_in_shapes
[
undef_idx
].
size
())
{
layouts
.
Set
(
undef_idx
,
layouts
.
Set
(
undef_idx
,
layouts
[
defined_idx
].
SubLayout
(
layouts
[
defined_idx
].
SubLayout
(
old_in_shapes
[
defined_idx
].
size
()
-
old_in_shapes
[
undef_idx
].
size
(),
old_in_shapes
[
defined_idx
].
size
()
-
old_in_shapes
[
undef_idx
].
size
(),
old_in_shapes
[
undef_idx
].
size
()));
old_in_shapes
[
undef_idx
].
size
()));
return
Array
<
Array
<
Layout
>
>
{
layouts
,
{
layouts
[
defined_idx
]}};
return
Array
<
Array
<
Layout
>
>
{
layouts
,
{
layouts
[
defined_idx
]}};
}
else
{
}
else
{
// only know the tensor with smaller dimensions,
// only know the tensor with smaller dimensions,
// so we cannot infer the final broadcasted output.
// so we cannot infer the final broadcasted output.
// fails in this case.
// fails in this case.
return
Array
<
Array
<
Layout
>
>
{{
Layout
::
Undef
()},
{
Layout
::
Undef
()}};
return
Array
<
Array
<
Layout
>
>
{{
Layout
::
Undef
()},
{
Layout
::
Undef
()}};
}
}
}
else
if
(
layouts
[
0
].
defined
()
&&
layouts
[
1
].
defined
()
&&
(
layouts
[
0
].
ndim
()
==
0
||
layouts
[
1
].
ndim
()
==
0
))
{
int
scalar
=
layouts
[
0
].
ndim
()
==
0
?
0
:
1
;
return
Array
<
Array
<
Layout
>
>
{
layouts
,
{
layouts
[
1
-
scalar
]}};
}
else
{
}
else
{
// try to broadcast the tensors to the larger dimension
// try to broadcast the tensors to the larger dimension
int
large_idx
=
layouts
[
0
].
ndim_primal
()
>=
layouts
[
1
].
ndim_primal
()
?
0
:
1
;
int
large_idx
=
layouts
[
0
].
ndim_primal
()
>=
layouts
[
1
].
ndim_primal
()
?
0
:
1
;
...
...
tests/python/relay/test_pass_alter_op_layout.py
View file @
9d002e8e
...
@@ -57,7 +57,7 @@ def test_alter_op():
...
@@ -57,7 +57,7 @@ def test_alter_op():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_return_none
():
def
test_alter_return_none
():
...
@@ -81,7 +81,7 @@ def test_alter_return_none():
...
@@ -81,7 +81,7 @@ def test_alter_return_none():
b
=
before
()
b
=
before
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
assert
(
called
[
0
])
assert
(
called
[
0
])
...
@@ -147,7 +147,7 @@ def test_alter_layout():
...
@@ -147,7 +147,7 @@ def test_alter_layout():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_dual_path
():
def
test_alter_layout_dual_path
():
...
@@ -213,7 +213,7 @@ def test_alter_layout_dual_path():
...
@@ -213,7 +213,7 @@ def test_alter_layout_dual_path():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_resnet
():
def
test_alter_layout_resnet
():
"""Test alternating the layout of a residual block
"""Test alternating the layout of a residual block
...
@@ -273,7 +273,7 @@ def test_alter_layout_resnet():
...
@@ -273,7 +273,7 @@ def test_alter_layout_resnet():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_broadcast_op
():
def
test_alter_layout_broadcast_op
():
...
@@ -323,7 +323,7 @@ def test_alter_layout_broadcast_op():
...
@@ -323,7 +323,7 @@ def test_alter_layout_broadcast_op():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_scalar
():
def
test_alter_layout_scalar
():
"""Test alternating the layout of a conv2d.
"""Test alternating the layout of a conv2d.
...
@@ -370,7 +370,7 @@ def test_alter_layout_scalar():
...
@@ -370,7 +370,7 @@ def test_alter_layout_scalar():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_concatenate
():
def
test_alter_layout_concatenate
():
""" """
""" """
...
@@ -425,7 +425,7 @@ def test_alter_layout_concatenate():
...
@@ -425,7 +425,7 @@ def test_alter_layout_concatenate():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_nchw_upsamping_op
():
def
test_alter_layout_nchw_upsamping_op
():
...
@@ -469,7 +469,7 @@ def test_alter_layout_nchw_upsamping_op():
...
@@ -469,7 +469,7 @@ def test_alter_layout_nchw_upsamping_op():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
def
test_alter_layout_strided_slice
():
def
test_alter_layout_strided_slice
():
...
@@ -511,7 +511,7 @@ def test_alter_layout_strided_slice():
...
@@ -511,7 +511,7 @@ def test_alter_layout_strided_slice():
b
=
expected
()
b
=
expected
()
b
=
infer_type
(
b
)
b
=
infer_type
(
b
)
assert
(
alpha_equal
(
a
,
b
)
)
assert
alpha_equal
(
a
,
b
),
"Actual =
\n
"
+
str
(
a
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/python/unittest/test_lang_data_layout.py
View file @
9d002e8e
...
@@ -52,6 +52,12 @@ def test_layout():
...
@@ -52,6 +52,12 @@ def test_layout():
def
test_bilayout_convertible
():
def
test_bilayout_convertible
():
# not convertible
# not convertible
assert
tvm
.
bijective_layout
(
"NCHW"
,
"ABCD"
)
is
None
assert
tvm
.
bijective_layout
(
"NCHW"
,
"ABCD"
)
is
None
assert
tvm
.
bijective_layout
(
"__undef__"
,
"NCHW"
)
is
None
assert
tvm
.
bijective_layout
(
"NCHW"
,
"__undef__"
)
is
None
assert
tvm
.
bijective_layout
(
"__undef__"
,
"__undef__"
)
is
None
assert
tvm
.
bijective_layout
(
""
,
"NCHW"
)
is
None
assert
tvm
.
bijective_layout
(
"NCHW"
,
""
)
is
None
assert
tvm
.
bijective_layout
(
""
,
""
)
is
None
# convertible
# convertible
assert
tvm
.
bijective_layout
(
"NCHW"
,
"NCHW16c"
)
is
not
None
assert
tvm
.
bijective_layout
(
"NCHW"
,
"NCHW16c"
)
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment