Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6514849f
Commit
6514849f
authored
Oct 18, 2018
by
Siva
Committed by
Tianqi Chen
Oct 17, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM/TOPI][OP] Split : default axis to 0 and allow negative values - nump… (#1883)
parent
7631873b
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
116 deletions
+33
-116
nnvm/include/nnvm/top/tensor.h
+1
-1
nnvm/src/top/tensor/transform.cc
+18
-9
nnvm/tests/python/unittest/test_infer_shape.py
+4
-0
topi/python/topi/transform.py
+9
-105
topi/tests/python/test_topi_transform.py
+1
-1
No files found.
nnvm/include/nnvm/top/tensor.h
View file @
6514849f
...
@@ -43,7 +43,7 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
...
@@ -43,7 +43,7 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
DMLC_DECLARE_PARAMETER
(
SplitParam
)
{
DMLC_DECLARE_PARAMETER
(
SplitParam
)
{
DMLC_DECLARE_FIELD
(
indices_or_sections
)
DMLC_DECLARE_FIELD
(
indices_or_sections
)
.
describe
(
"Number of outputs to be splitted"
);
.
describe
(
"Number of outputs to be splitted"
);
DMLC_DECLARE_FIELD
(
axis
).
set_
lower_bound
(
0
).
set_
default
(
1
)
DMLC_DECLARE_FIELD
(
axis
).
set_default
(
1
)
.
describe
(
"the axis to be splitted."
);
.
describe
(
"the axis to be splitted."
);
}
}
};
};
...
...
nnvm/src/top/tensor/transform.cc
View file @
6514849f
...
@@ -344,14 +344,23 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
...
@@ -344,14 +344,23 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
const
TShape
&
dshape
=
(
*
in_shape
)[
0
];
const
TShape
&
dshape
=
(
*
in_shape
)[
0
];
if
(
dshape
.
ndim
()
==
0
)
return
false
;
if
(
dshape
.
ndim
()
==
0
)
return
false
;
auto
axis
=
param
.
axis
;
if
(
axis
<
0
)
{
axis
+=
dshape
.
ndim
();
}
CHECK_LT
(
axis
,
dshape
.
ndim
())
<<
"axis should be within input dimension range but got "
<<
axis
;
CHECK_GT
(
axis
,
-
1
)
<<
"axis should be within input dimension range but got "
<<
axis
;
if
(
param
.
equal_split
)
{
if
(
param
.
equal_split
)
{
int
num_outputs
=
param
.
indices_or_sections
[
0
];
int
num_outputs
=
param
.
indices_or_sections
[
0
];
CHECK_EQ
(
out_shape
->
size
(),
static_cast
<
size_t
>
(
num_outputs
));
CHECK_EQ
(
out_shape
->
size
(),
static_cast
<
size_t
>
(
num_outputs
));
CHECK_LT
(
param
.
axis
,
dshape
.
ndim
());
TShape
oshape
=
dshape
;
TShape
oshape
=
dshape
;
CHECK_EQ
(
oshape
[
param
.
axis
]
%
num_outputs
,
0
)
CHECK_EQ
(
oshape
[
axis
]
%
num_outputs
,
0
)
<<
"indices_or_sections need to be able to divide input.shape[axis]"
;
<<
"indices_or_sections need to be able to divide input.shape[axis] got sections "
oshape
[
param
.
axis
]
/=
num_outputs
;
<<
num_outputs
<<
" and dimension "
<<
oshape
[
axis
];
oshape
[
axis
]
/=
num_outputs
;
for
(
size_t
i
=
0
;
i
<
out_shape
->
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out_shape
->
size
();
++
i
)
{
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
i
,
oshape
);
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
i
,
oshape
);
...
@@ -359,19 +368,19 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
...
@@ -359,19 +368,19 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
}
else
{
}
else
{
dim_t
num_outputs
=
param
.
indices_or_sections
.
ndim
()
+
1
;
dim_t
num_outputs
=
param
.
indices_or_sections
.
ndim
()
+
1
;
CHECK_EQ
(
out_shape
->
size
(),
static_cast
<
size_t
>
(
num_outputs
));
CHECK_EQ
(
out_shape
->
size
(),
static_cast
<
size_t
>
(
num_outputs
));
CHECK_LT
(
param
.
axis
,
dshape
.
ndim
());
TShape
oshape
=
dshape
;
TShape
oshape
=
dshape
;
dim_t
begin
=
0
;
dim_t
begin
=
0
;
for
(
dim_t
i
=
0
;
i
<
num_outputs
-
1
;
++
i
)
{
for
(
dim_t
i
=
0
;
i
<
num_outputs
-
1
;
++
i
)
{
CHECK_GT
(
param
.
indices_or_sections
[
i
],
begin
)
CHECK_GT
(
param
.
indices_or_sections
[
i
],
begin
)
<<
"indices_or_sections need to be a sorted ascending list"
;
<<
"indices_or_sections need to be a sorted ascending list got "
oshape
[
param
.
axis
]
=
param
.
indices_or_sections
[
i
]
-
begin
;
<<
param
.
indices_or_sections
;
oshape
[
axis
]
=
param
.
indices_or_sections
[
i
]
-
begin
;
begin
=
param
.
indices_or_sections
[
i
];
begin
=
param
.
indices_or_sections
[
i
];
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
i
,
oshape
);
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
i
,
oshape
);
}
}
CHECK_LT
(
begin
,
dshape
[
param
.
axis
])
CHECK_LT
(
begin
,
dshape
[
axis
])
<<
"The sum of sections must match the input.shape[axis]"
;
<<
"The sum of sections must match the input.shape[axis]"
;
oshape
[
param
.
axis
]
=
dshape
[
param
.
axis
]
-
begin
;
oshape
[
axis
]
=
dshape
[
axis
]
-
begin
;
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
num_outputs
-
1
,
oshape
);
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
num_outputs
-
1
,
oshape
);
}
}
return
true
;
return
true
;
...
...
nnvm/tests/python/unittest/test_infer_shape.py
View file @
6514849f
...
@@ -84,6 +84,10 @@ def test_split():
...
@@ -84,6 +84,10 @@ def test_split():
sdict
=
infer_shape
(
z
)
sdict
=
infer_shape
(
z
)
assert
(
sdict
[
"y"
][
0
]
==
[
10
,
10
])
assert
(
sdict
[
"y"
][
0
]
==
[
10
,
10
])
assert
(
sdict
[
"y"
][
1
]
==
[
10
,
10
])
assert
(
sdict
[
"y"
][
1
]
==
[
10
,
10
])
z
=
sym
.
split
(
x1
,
indices_or_sections
=
[
6
],
axis
=-
1
,
name
=
"y"
)
sdict
=
infer_shape
(
z
)
assert
(
sdict
[
"y"
][
0
]
==
[
10
,
6
])
assert
(
sdict
[
"y"
][
1
]
==
[
10
,
14
])
def
test_batchnorm
():
def
test_batchnorm
():
...
...
topi/python/topi/transform.py
View file @
6514849f
...
@@ -4,7 +4,6 @@ from __future__ import absolute_import as _abs
...
@@ -4,7 +4,6 @@ from __future__ import absolute_import as _abs
import
tvm
import
tvm
import
topi
import
topi
from
.
import
tag
from
.
import
tag
from
.util
import
ravel_index
,
unravel_index
,
get_const_int
,
get_const_tuple
from
.
import
cpp
from
.
import
cpp
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
...
@@ -23,12 +22,7 @@ def expand_dims(a, axis, num_newaxis=1):
...
@@ -23,12 +22,7 @@ def expand_dims(a, axis, num_newaxis=1):
-------
-------
ret : tvm.Tensor
ret : tvm.Tensor
"""
"""
axis
=
len
(
a
.
shape
)
+
axis
+
1
if
axis
<
0
else
axis
return
cpp
.
expand_dims
(
a
,
axis
,
num_newaxis
)
new_shape
=
a
.
shape
[:
axis
]
+
([
1
]
*
num_newaxis
)
+
a
.
shape
[
axis
:]
def
_compute
(
*
indices
):
idx
=
indices
[:
axis
]
+
indices
[
axis
+
num_newaxis
:]
return
a
(
*
idx
)
return
tvm
.
compute
(
new_shape
,
_compute
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
@tvm.tag_scope
(
tag
=
tag
.
BROADCAST
)
...
@@ -101,15 +95,8 @@ def transpose(a, axes=None):
...
@@ -101,15 +95,8 @@ def transpose(a, axes=None):
-------
-------
ret : tvm.Tensor
ret : tvm.Tensor
"""
"""
ndim
=
len
(
a
.
shape
)
return
cpp
.
transpose
(
a
,
axes
)
axes
=
axes
if
axes
else
tuple
(
reversed
(
range
(
ndim
)))
new_shape
=
[
a
.
shape
[
x
]
for
x
in
axes
]
def
_compute
(
*
indices
):
idx
=
[
1
]
*
len
(
axes
)
for
i
,
k
in
enumerate
(
axes
):
idx
[
k
]
=
indices
[
i
]
return
a
(
*
idx
)
return
tvm
.
compute
(
new_shape
,
_compute
)
def
flip
(
a
,
axis
=
0
):
def
flip
(
a
,
axis
=
0
):
"""Flip/reverse elements of an array in a particular axis.
"""Flip/reverse elements of an array in a particular axis.
...
@@ -153,6 +140,7 @@ def strided_slice(a, begin, end, strides=None):
...
@@ -153,6 +140,7 @@ def strided_slice(a, begin, end, strides=None):
"""
"""
return
cpp
.
strided_slice
(
a
,
begin
,
end
,
strides
)
return
cpp
.
strided_slice
(
a
,
begin
,
end
,
strides
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
reshape
(
a
,
newshape
):
def
reshape
(
a
,
newshape
):
"""Reshape the array
"""Reshape the array
...
@@ -168,10 +156,7 @@ def reshape(a, newshape):
...
@@ -168,10 +156,7 @@ def reshape(a, newshape):
-------
-------
ret : tvm.Tensor
ret : tvm.Tensor
"""
"""
ndim
=
len
(
a
.
shape
)
return
cpp
.
reshape
(
a
,
newshape
)
a_shape
=
[
a
.
shape
[
i
]
for
i
in
range
(
ndim
)]
return
tvm
.
compute
(
newshape
,
lambda
*
indices
:
a
(
*
unravel_index
(
ravel_index
(
indices
,
newshape
),
a_shape
)))
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
...
@@ -190,41 +175,7 @@ def squeeze(a, axis=None):
...
@@ -190,41 +175,7 @@ def squeeze(a, axis=None):
-------
-------
squeezed : tvm.Tensor
squeezed : tvm.Tensor
"""
"""
a_ndim
=
len
(
a
.
shape
)
return
cpp
.
squeeze
(
a
,
axis
)
a_shape
=
get_const_tuple
(
a
.
shape
)
if
axis
is
None
:
axis
=
[]
for
i
,
ele
in
enumerate
(
a_shape
):
if
ele
==
1
:
axis
.
append
(
i
)
else
:
if
isinstance
(
axis
,
int
):
axis
=
axis
+
a_ndim
if
axis
<
0
else
axis
assert
a_shape
[
axis
]
==
1
axis
=
[
axis
]
else
:
axis
=
[
ele
+
a_ndim
if
ele
<
0
else
ele
for
ele
in
axis
]
for
ele
in
axis
:
assert
a_shape
[
ele
]
==
1
out_shape
=
[]
search_axis
=
set
(
axis
)
for
i
,
a_dim
in
enumerate
(
a_shape
):
if
i
not
in
search_axis
:
out_shape
.
append
(
a_dim
)
if
not
out_shape
:
out_shape
.
append
(
1
)
def
_compute
(
*
indices
):
real_indices
=
[]
flag
=
0
for
i
in
range
(
a_ndim
):
if
i
not
in
search_axis
:
real_indices
.
append
(
indices
[
i
-
flag
])
else
:
real_indices
.
append
(
0
)
flag
+=
1
return
a
(
*
real_indices
)
return
tvm
.
compute
(
out_shape
,
_compute
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
...
@@ -243,25 +194,7 @@ def concatenate(a_tuple, axis=0):
...
@@ -243,25 +194,7 @@ def concatenate(a_tuple, axis=0):
-------
-------
ret : tvm.Tensor
ret : tvm.Tensor
"""
"""
assert
isinstance
(
a_tuple
,
(
list
,
tuple
))
return
cpp
.
concatenate
(
a_tuple
,
axis
)
if
axis
<
0
:
axis
+=
len
(
a_tuple
[
0
]
.
shape
)
assert
axis
<
len
(
a_tuple
[
0
]
.
shape
)
axis_sizes
=
[
a_tuple
[
i
]
.
shape
[
axis
]
for
i
in
range
(
len
(
a_tuple
))]
out_shape
=
[
a_tuple
[
0
]
.
shape
[
i
]
for
i
in
range
(
0
,
axis
)]
+
[
sum
(
axis_sizes
)]
\
+
[
a_tuple
[
0
]
.
shape
[
i
]
for
i
in
range
(
axis
+
1
,
len
(
a_tuple
[
0
]
.
shape
))]
out_shape
[
axis
]
=
tvm
.
ir_pass
.
Simplify
(
out_shape
[
axis
])
def
_compute
(
*
indices
):
ret
=
a_tuple
[
0
](
*
indices
)
ind
=
indices
[
axis
]
for
i
in
range
(
len
(
a_tuple
)
-
1
):
ind
-=
axis_sizes
[
i
]
ret
=
tvm
.
select
(
ind
>=
0
,
a_tuple
[
i
+
1
](
*
(
indices
[
0
:
axis
]
+
(
ind
,)
+
indices
[
axis
+
1
:])),
ret
)
return
ret
return
tvm
.
compute
(
out_shape
,
_compute
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
...
@@ -280,37 +213,7 @@ def split(ary, indices_or_sections, axis=0):
...
@@ -280,37 +213,7 @@ def split(ary, indices_or_sections, axis=0):
-------
-------
ret : tuple of tvm.Tensor
ret : tuple of tvm.Tensor
"""
"""
def
_compute
(
begin
,
*
indices
):
return
cpp
.
split
(
ary
,
indices_or_sections
,
axis
)
real_indices
=
indices
[:
axis
]
+
(
indices
[
axis
]
+
begin
,
)
+
indices
[
axis
+
1
:]
return
ary
(
*
real_indices
)
if
axis
<
0
:
axis
+=
len
(
ary
.
shape
)
src_axis_size
=
get_const_int
(
ary
.
shape
[
axis
])
if
isinstance
(
indices_or_sections
,
int
):
assert
indices_or_sections
>
0
assert
src_axis_size
%
indices_or_sections
==
0
seg_size
=
src_axis_size
//
indices_or_sections
begin_ids
=
[
seg_size
*
i
for
i
in
range
(
indices_or_sections
)]
elif
isinstance
(
indices_or_sections
,
(
tuple
,
list
)):
assert
tuple
(
indices_or_sections
)
==
tuple
(
sorted
(
indices_or_sections
)),
\
"Should be sorted, recieved
%
s"
%
str
(
indices_or_sections
)
begin_ids
=
[
0
]
+
list
(
indices_or_sections
)
else
:
raise
NotImplementedError
()
out_shapes
=
[]
for
i
in
range
(
len
(
begin_ids
)):
if
i
==
len
(
begin_ids
)
-
1
:
out_axis_size
=
src_axis_size
-
begin_ids
[
i
]
else
:
out_axis_size
=
begin_ids
[
i
+
1
]
-
begin_ids
[
i
]
out_shapes
.
append
([
ary
.
shape
[
i
]
for
i
in
range
(
axis
)]
+
[
out_axis_size
]
+
\
[
ary
.
shape
[
i
]
for
i
in
range
(
axis
+
1
,
len
(
ary
.
shape
))])
# pylint: disable=cell-var-from-loop
return
[
tvm
.
compute
(
out_shape
,
lambda
*
indices
:
_compute
(
begin_id
,
*
indices
),
name
=
"s
%
d"
%
i
)
for
i
,
(
out_shape
,
begin_id
)
in
enumerate
(
zip
(
out_shapes
,
begin_ids
))]
# pylint: enable=cell-var-from-loop
def
take
(
a
,
indices
,
axis
=
None
):
def
take
(
a
,
indices
,
axis
=
None
):
...
@@ -336,6 +239,7 @@ def take(a, indices, axis=None):
...
@@ -336,6 +239,7 @@ def take(a, indices, axis=None):
return
cpp
.
take
(
a
,
indices
)
return
cpp
.
take
(
a
,
indices
)
return
cpp
.
take
(
a
,
indices
,
int
(
axis
))
return
cpp
.
take
(
a
,
indices
,
int
(
axis
))
def
matmul
(
a
,
b
,
transp_a
=
False
,
transp_b
=
False
):
def
matmul
(
a
,
b
,
transp_a
=
False
,
transp_b
=
False
):
"""
"""
Creates an operation that calculates a matrix multiplication (row-major notation):
Creates an operation that calculates a matrix multiplication (row-major notation):
...
...
topi/tests/python/test_topi_transform.py
View file @
6514849f
...
@@ -139,7 +139,7 @@ def verify_split(src_shape, indices_or_sections, axis):
...
@@ -139,7 +139,7 @@ def verify_split(src_shape, indices_or_sections, axis):
with
tvm
.
target
.
create
(
device
):
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
tensor_l
)
s
=
topi
.
generic
.
schedule_injective
(
tensor_l
)
foo
=
tvm
.
build
(
s
,
[
A
]
+
tensor_l
,
device
,
name
=
"split"
)
foo
=
tvm
.
build
(
s
,
[
A
]
+
list
(
tensor_l
)
,
device
,
name
=
"split"
)
data_npy
=
np
.
random
.
normal
(
size
=
src_shape
)
.
astype
(
A
.
dtype
)
data_npy
=
np
.
random
.
normal
(
size
=
src_shape
)
.
astype
(
A
.
dtype
)
out_npys
=
np
.
split
(
data_npy
,
indices_or_sections
,
axis
=
axis
)
out_npys
=
np
.
split
(
data_npy
,
indices_or_sections
,
axis
=
axis
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment