Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6514849f
Commit
6514849f
authored
Oct 18, 2018
by
Siva
Committed by
Tianqi Chen
Oct 17, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM/TOPI][OP] Split : default axis to 0 and allow negative values - nump… (#1883)
parent
7631873b
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
116 deletions
+33
-116
nnvm/include/nnvm/top/tensor.h
+1
-1
nnvm/src/top/tensor/transform.cc
+18
-9
nnvm/tests/python/unittest/test_infer_shape.py
+4
-0
topi/python/topi/transform.py
+9
-105
topi/tests/python/test_topi_transform.py
+1
-1
No files found.
nnvm/include/nnvm/top/tensor.h
View file @
6514849f
...
...
@@ -43,7 +43,7 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
DMLC_DECLARE_PARAMETER
(
SplitParam
)
{
DMLC_DECLARE_FIELD
(
indices_or_sections
)
.
describe
(
"Number of outputs to be splitted"
);
DMLC_DECLARE_FIELD
(
axis
).
set_
lower_bound
(
0
).
set_
default
(
1
)
DMLC_DECLARE_FIELD
(
axis
).
set_default
(
1
)
.
describe
(
"the axis to be splitted."
);
}
};
...
...
nnvm/src/top/tensor/transform.cc
View file @
6514849f
...
...
@@ -344,14 +344,23 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
const
TShape
&
dshape
=
(
*
in_shape
)[
0
];
if
(
dshape
.
ndim
()
==
0
)
return
false
;
auto
axis
=
param
.
axis
;
if
(
axis
<
0
)
{
axis
+=
dshape
.
ndim
();
}
CHECK_LT
(
axis
,
dshape
.
ndim
())
<<
"axis should be within input dimension range but got "
<<
axis
;
CHECK_GT
(
axis
,
-
1
)
<<
"axis should be within input dimension range but got "
<<
axis
;
if
(
param
.
equal_split
)
{
int
num_outputs
=
param
.
indices_or_sections
[
0
];
CHECK_EQ
(
out_shape
->
size
(),
static_cast
<
size_t
>
(
num_outputs
));
CHECK_LT
(
param
.
axis
,
dshape
.
ndim
());
TShape
oshape
=
dshape
;
CHECK_EQ
(
oshape
[
param
.
axis
]
%
num_outputs
,
0
)
<<
"indices_or_sections need to be able to divide input.shape[axis]"
;
oshape
[
param
.
axis
]
/=
num_outputs
;
CHECK_EQ
(
oshape
[
axis
]
%
num_outputs
,
0
)
<<
"indices_or_sections need to be able to divide input.shape[axis] got sections "
<<
num_outputs
<<
" and dimension "
<<
oshape
[
axis
];
oshape
[
axis
]
/=
num_outputs
;
for
(
size_t
i
=
0
;
i
<
out_shape
->
size
();
++
i
)
{
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
i
,
oshape
);
...
...
@@ -359,19 +368,19 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
}
else
{
dim_t
num_outputs
=
param
.
indices_or_sections
.
ndim
()
+
1
;
CHECK_EQ
(
out_shape
->
size
(),
static_cast
<
size_t
>
(
num_outputs
));
CHECK_LT
(
param
.
axis
,
dshape
.
ndim
());
TShape
oshape
=
dshape
;
dim_t
begin
=
0
;
for
(
dim_t
i
=
0
;
i
<
num_outputs
-
1
;
++
i
)
{
CHECK_GT
(
param
.
indices_or_sections
[
i
],
begin
)
<<
"indices_or_sections need to be a sorted ascending list"
;
oshape
[
param
.
axis
]
=
param
.
indices_or_sections
[
i
]
-
begin
;
<<
"indices_or_sections need to be a sorted ascending list got "
<<
param
.
indices_or_sections
;
oshape
[
axis
]
=
param
.
indices_or_sections
[
i
]
-
begin
;
begin
=
param
.
indices_or_sections
[
i
];
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
i
,
oshape
);
}
CHECK_LT
(
begin
,
dshape
[
param
.
axis
])
CHECK_LT
(
begin
,
dshape
[
axis
])
<<
"The sum of sections must match the input.shape[axis]"
;
oshape
[
param
.
axis
]
=
dshape
[
param
.
axis
]
-
begin
;
oshape
[
axis
]
=
dshape
[
axis
]
-
begin
;
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
num_outputs
-
1
,
oshape
);
}
return
true
;
...
...
nnvm/tests/python/unittest/test_infer_shape.py
View file @
6514849f
...
...
@@ -84,6 +84,10 @@ def test_split():
sdict
=
infer_shape
(
z
)
assert
(
sdict
[
"y"
][
0
]
==
[
10
,
10
])
assert
(
sdict
[
"y"
][
1
]
==
[
10
,
10
])
z
=
sym
.
split
(
x1
,
indices_or_sections
=
[
6
],
axis
=-
1
,
name
=
"y"
)
sdict
=
infer_shape
(
z
)
assert
(
sdict
[
"y"
][
0
]
==
[
10
,
6
])
assert
(
sdict
[
"y"
][
1
]
==
[
10
,
14
])
def
test_batchnorm
():
...
...
topi/python/topi/transform.py
View file @
6514849f
...
...
@@ -4,7 +4,6 @@ from __future__ import absolute_import as _abs
import
tvm
import
topi
from
.
import
tag
from
.util
import
ravel_index
,
unravel_index
,
get_const_int
,
get_const_tuple
from
.
import
cpp
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
...
...
@@ -23,12 +22,7 @@ def expand_dims(a, axis, num_newaxis=1):
-------
ret : tvm.Tensor
"""
axis
=
len
(
a
.
shape
)
+
axis
+
1
if
axis
<
0
else
axis
new_shape
=
a
.
shape
[:
axis
]
+
([
1
]
*
num_newaxis
)
+
a
.
shape
[
axis
:]
def
_compute
(
*
indices
):
idx
=
indices
[:
axis
]
+
indices
[
axis
+
num_newaxis
:]
return
a
(
*
idx
)
return
tvm
.
compute
(
new_shape
,
_compute
)
return
cpp
.
expand_dims
(
a
,
axis
,
num_newaxis
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
...
...
@@ -101,15 +95,8 @@ def transpose(a, axes=None):
-------
ret : tvm.Tensor
"""
ndim
=
len
(
a
.
shape
)
axes
=
axes
if
axes
else
tuple
(
reversed
(
range
(
ndim
)))
new_shape
=
[
a
.
shape
[
x
]
for
x
in
axes
]
def
_compute
(
*
indices
):
idx
=
[
1
]
*
len
(
axes
)
for
i
,
k
in
enumerate
(
axes
):
idx
[
k
]
=
indices
[
i
]
return
a
(
*
idx
)
return
tvm
.
compute
(
new_shape
,
_compute
)
return
cpp
.
transpose
(
a
,
axes
)
def
flip
(
a
,
axis
=
0
):
"""Flip/reverse elements of an array in a particular axis.
...
...
@@ -153,6 +140,7 @@ def strided_slice(a, begin, end, strides=None):
"""
return
cpp
.
strided_slice
(
a
,
begin
,
end
,
strides
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
reshape
(
a
,
newshape
):
"""Reshape the array
...
...
@@ -168,10 +156,7 @@ def reshape(a, newshape):
-------
ret : tvm.Tensor
"""
ndim
=
len
(
a
.
shape
)
a_shape
=
[
a
.
shape
[
i
]
for
i
in
range
(
ndim
)]
return
tvm
.
compute
(
newshape
,
lambda
*
indices
:
a
(
*
unravel_index
(
ravel_index
(
indices
,
newshape
),
a_shape
)))
return
cpp
.
reshape
(
a
,
newshape
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
...
...
@@ -190,41 +175,7 @@ def squeeze(a, axis=None):
-------
squeezed : tvm.Tensor
"""
a_ndim
=
len
(
a
.
shape
)
a_shape
=
get_const_tuple
(
a
.
shape
)
if
axis
is
None
:
axis
=
[]
for
i
,
ele
in
enumerate
(
a_shape
):
if
ele
==
1
:
axis
.
append
(
i
)
else
:
if
isinstance
(
axis
,
int
):
axis
=
axis
+
a_ndim
if
axis
<
0
else
axis
assert
a_shape
[
axis
]
==
1
axis
=
[
axis
]
else
:
axis
=
[
ele
+
a_ndim
if
ele
<
0
else
ele
for
ele
in
axis
]
for
ele
in
axis
:
assert
a_shape
[
ele
]
==
1
out_shape
=
[]
search_axis
=
set
(
axis
)
for
i
,
a_dim
in
enumerate
(
a_shape
):
if
i
not
in
search_axis
:
out_shape
.
append
(
a_dim
)
if
not
out_shape
:
out_shape
.
append
(
1
)
def
_compute
(
*
indices
):
real_indices
=
[]
flag
=
0
for
i
in
range
(
a_ndim
):
if
i
not
in
search_axis
:
real_indices
.
append
(
indices
[
i
-
flag
])
else
:
real_indices
.
append
(
0
)
flag
+=
1
return
a
(
*
real_indices
)
return
tvm
.
compute
(
out_shape
,
_compute
)
return
cpp
.
squeeze
(
a
,
axis
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
...
...
@@ -243,25 +194,7 @@ def concatenate(a_tuple, axis=0):
-------
ret : tvm.Tensor
"""
assert
isinstance
(
a_tuple
,
(
list
,
tuple
))
if
axis
<
0
:
axis
+=
len
(
a_tuple
[
0
]
.
shape
)
assert
axis
<
len
(
a_tuple
[
0
]
.
shape
)
axis_sizes
=
[
a_tuple
[
i
]
.
shape
[
axis
]
for
i
in
range
(
len
(
a_tuple
))]
out_shape
=
[
a_tuple
[
0
]
.
shape
[
i
]
for
i
in
range
(
0
,
axis
)]
+
[
sum
(
axis_sizes
)]
\
+
[
a_tuple
[
0
]
.
shape
[
i
]
for
i
in
range
(
axis
+
1
,
len
(
a_tuple
[
0
]
.
shape
))]
out_shape
[
axis
]
=
tvm
.
ir_pass
.
Simplify
(
out_shape
[
axis
])
def
_compute
(
*
indices
):
ret
=
a_tuple
[
0
](
*
indices
)
ind
=
indices
[
axis
]
for
i
in
range
(
len
(
a_tuple
)
-
1
):
ind
-=
axis_sizes
[
i
]
ret
=
tvm
.
select
(
ind
>=
0
,
a_tuple
[
i
+
1
](
*
(
indices
[
0
:
axis
]
+
(
ind
,)
+
indices
[
axis
+
1
:])),
ret
)
return
ret
return
tvm
.
compute
(
out_shape
,
_compute
)
return
cpp
.
concatenate
(
a_tuple
,
axis
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
...
...
@@ -280,37 +213,7 @@ def split(ary, indices_or_sections, axis=0):
-------
ret : tuple of tvm.Tensor
"""
def
_compute
(
begin
,
*
indices
):
real_indices
=
indices
[:
axis
]
+
(
indices
[
axis
]
+
begin
,
)
+
indices
[
axis
+
1
:]
return
ary
(
*
real_indices
)
if
axis
<
0
:
axis
+=
len
(
ary
.
shape
)
src_axis_size
=
get_const_int
(
ary
.
shape
[
axis
])
if
isinstance
(
indices_or_sections
,
int
):
assert
indices_or_sections
>
0
assert
src_axis_size
%
indices_or_sections
==
0
seg_size
=
src_axis_size
//
indices_or_sections
begin_ids
=
[
seg_size
*
i
for
i
in
range
(
indices_or_sections
)]
elif
isinstance
(
indices_or_sections
,
(
tuple
,
list
)):
assert
tuple
(
indices_or_sections
)
==
tuple
(
sorted
(
indices_or_sections
)),
\
"Should be sorted, recieved
%
s"
%
str
(
indices_or_sections
)
begin_ids
=
[
0
]
+
list
(
indices_or_sections
)
else
:
raise
NotImplementedError
()
out_shapes
=
[]
for
i
in
range
(
len
(
begin_ids
)):
if
i
==
len
(
begin_ids
)
-
1
:
out_axis_size
=
src_axis_size
-
begin_ids
[
i
]
else
:
out_axis_size
=
begin_ids
[
i
+
1
]
-
begin_ids
[
i
]
out_shapes
.
append
([
ary
.
shape
[
i
]
for
i
in
range
(
axis
)]
+
[
out_axis_size
]
+
\
[
ary
.
shape
[
i
]
for
i
in
range
(
axis
+
1
,
len
(
ary
.
shape
))])
# pylint: disable=cell-var-from-loop
return
[
tvm
.
compute
(
out_shape
,
lambda
*
indices
:
_compute
(
begin_id
,
*
indices
),
name
=
"s
%
d"
%
i
)
for
i
,
(
out_shape
,
begin_id
)
in
enumerate
(
zip
(
out_shapes
,
begin_ids
))]
# pylint: enable=cell-var-from-loop
return
cpp
.
split
(
ary
,
indices_or_sections
,
axis
)
def
take
(
a
,
indices
,
axis
=
None
):
...
...
@@ -336,6 +239,7 @@ def take(a, indices, axis=None):
return
cpp
.
take
(
a
,
indices
)
return
cpp
.
take
(
a
,
indices
,
int
(
axis
))
def
matmul
(
a
,
b
,
transp_a
=
False
,
transp_b
=
False
):
"""
Creates an operation that calculates a matrix multiplication (row-major notation):
...
...
topi/tests/python/test_topi_transform.py
View file @
6514849f
...
...
@@ -139,7 +139,7 @@ def verify_split(src_shape, indices_or_sections, axis):
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
tensor_l
)
foo
=
tvm
.
build
(
s
,
[
A
]
+
tensor_l
,
device
,
name
=
"split"
)
foo
=
tvm
.
build
(
s
,
[
A
]
+
list
(
tensor_l
)
,
device
,
name
=
"split"
)
data_npy
=
np
.
random
.
normal
(
size
=
src_shape
)
.
astype
(
A
.
dtype
)
out_npys
=
np
.
split
(
data_npy
,
indices_or_sections
,
axis
=
axis
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment