Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
40ac2064
Commit
40ac2064
authored
Jul 04, 2018
by
Dayananda V
Committed by
Tianqi Chen
Jul 04, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add take frontend (#1307)
parent
4503f77b
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
169 additions
and
0 deletions
+169
-0
nnvm/include/nnvm/top/tensor.h
+10
-0
nnvm/python/nnvm/top/transform.py
+4
-0
nnvm/src/top/tensor/transform.cc
+120
-0
nnvm/tests/python/compiler/test_top_level1.py
+35
-0
No files found.
nnvm/include/nnvm/top/tensor.h
View file @
40ac2064
...
...
@@ -48,6 +48,16 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
}
};
struct
TakeParam
:
public
dmlc
::
Parameter
<
TakeParam
>
{
dmlc
::
optional
<
int
>
axis
;
DMLC_DECLARE_PARAMETER
(
TakeParam
)
{
DMLC_DECLARE_FIELD
(
axis
).
set_default
(
dmlc
::
optional
<
int
>
())
.
describe
(
"the axis over which to select values."
);
}
};
struct
StridedSliceParam
:
public
dmlc
::
Parameter
<
StridedSliceParam
>
{
// numpy convention, only support indices, not support list.
Tuple
<
int64_t
>
begin
;
...
...
nnvm/python/nnvm/top/transform.py
View file @
40ac2064
...
...
@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective)
reg
.
register_pattern
(
"split"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"split"
,
_fschedule_injective
)
# take
reg
.
register_pattern
(
"take"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"take"
,
_fschedule_injective
)
# strided_slice
reg
.
register_pattern
(
"strided_slice"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"strided_slice"
,
_fschedule_injective
)
...
...
nnvm/src/top/tensor/transform.cc
View file @
40ac2064
...
...
@@ -1001,6 +1001,126 @@ Examples::
return
Array
<
Tensor
>
{
topi
::
flip
(
inputs
[
0
],
param
.
axis
)
};
});
// take
DMLC_REGISTER_PARAMETER
(
TakeParam
);
inline
bool
TakeInferShape
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
TShape
>*
in_shape
,
std
::
vector
<
TShape
>*
out_shape
)
{
CHECK_EQ
(
in_shape
->
size
(),
2U
);
CHECK_EQ
(
out_shape
->
size
(),
1U
);
const
TShape
&
dshape
=
(
*
in_shape
)[
0
];
const
TShape
&
indicesshape
=
(
*
in_shape
)[
1
];
if
(
dshape
.
ndim
()
==
0
)
return
false
;
if
(
indicesshape
.
ndim
()
==
0
)
return
false
;
const
TakeParam
&
param
=
nnvm
::
get
<
TakeParam
>
(
attrs
.
parsed
);
TShape
oshape
((
!
param
.
axis
?
0
:
dshape
.
ndim
()
-
1
)
+
indicesshape
.
ndim
());
if
(
!
param
.
axis
)
{
for
(
size_t
j
=
0
;
j
<
indicesshape
.
ndim
();
++
j
)
{
oshape
[
j
]
=
indicesshape
[
j
];
}
}
else
{
int
axis
=
param
.
axis
.
value
();
if
(
axis
<
0
)
{
axis
+=
dshape
.
ndim
();
}
CHECK_LT
(
axis
,
dshape
.
ndim
());
size_t
posi
=
0
;
for
(
size_t
i
=
0
;
i
<
dshape
.
ndim
();
++
i
)
{
if
(
static_cast
<
int
>
(
i
)
==
axis
)
{
for
(
size_t
j
=
0
;
j
<
indicesshape
.
ndim
();
++
j
)
{
oshape
[
posi
++
]
=
indicesshape
[
j
];
}
}
else
{
oshape
[
posi
++
]
=
dshape
[
i
];
}
}
}
NNVM_ASSIGN_INPUT_SHAPE
(
attrs
,
*
in_shape
,
0
,
dshape
);
NNVM_ASSIGN_INPUT_SHAPE
(
attrs
,
*
in_shape
,
1
,
indicesshape
);
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
0
,
oshape
);
return
dshape
.
Size
()
!=
0
;
}
inline
bool
TakeInferType
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
int
>*
in_attrs
,
std
::
vector
<
int
>*
out_attrs
)
{
CHECK_EQ
(
in_attrs
->
size
(),
2U
);
CHECK_EQ
(
out_attrs
->
size
(),
1U
);
CHECK_EQ
((
*
in_attrs
)[
1
],
kInt32
);
NNVM_ASSIGN_INPUT_TYPE
(
attrs
,
*
in_attrs
,
0
,
(
*
in_attrs
)[
0
]);
NNVM_ASSIGN_INPUT_TYPE
(
attrs
,
*
in_attrs
,
1
,
static_cast
<
int
>
(
kInt32
));
NNVM_ASSIGN_OUTPUT_TYPE
(
attrs
,
*
out_attrs
,
0
,
(
*
in_attrs
)[
0
]);
return
true
;
}
inline
bool
TakeCorrectLayout
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
Layout
>
*
ilayouts
,
const
std
::
vector
<
Layout
>
*
last_ilayouts
,
std
::
vector
<
Layout
>
*
olayouts
)
{
CHECK_EQ
(
ilayouts
->
size
(),
last_ilayouts
->
size
());
CHECK_EQ
(
olayouts
->
size
(),
1U
);
for
(
size_t
i
=
0
;
i
<
ilayouts
->
size
();
++
i
)
{
const
Layout
&
input
=
last_ilayouts
->
at
(
i
).
defined
()
?
last_ilayouts
->
at
(
i
)
:
ilayouts
->
at
(
i
);
NNVM_ASSIGN_LAYOUT
(
*
ilayouts
,
i
,
input
);
}
return
true
;
}
NNVM_REGISTER_OP
(
take
)
.
describe
(
R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.
**Note** that when axis is none the flattened input array is used.
Examples::
a = [[ 1, 2],
[ 3, 4]]
indices = [3, 0, 2]
take(a, indices) = [ 4, 1, 3]
a = [[ 1., 2.],
[ 3., 4.]]
indices = [1, 0]
take(a, indices, axis=1) = [[ 2., 1.],
[ 4., 3.]]
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"data"
,
"Tensor"
,
"Array to be indexed"
)
.
add_argument
(
"indices"
,
"Tensor"
,
"The indices of the values to extract"
)
.
add_arguments
(
TakeParam
::
__FIELDS__
())
.
set_attr_parser
(
ParamParser
<
TakeParam
>
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
TakeInferShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
TakeInferType
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
TakeCorrectLayout
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
set_support_level
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
const
TakeParam
&
param
=
nnvm
::
get
<
TakeParam
>
(
attrs
.
parsed
);
if
(
!
param
.
axis
)
{
return
Array
<
Tensor
>
{
topi
::
take
(
inputs
[
0
],
inputs
[
1
])
};
}
else
{
return
Array
<
Tensor
>
{
topi
::
take
(
inputs
[
0
],
inputs
[
1
],
param
.
axis
.
value
())
};
}
});
// SliceLike
DMLC_REGISTER_PARAMETER
(
SliceLikeParam
);
...
...
nnvm/tests/python/compiler/test_top_level1.py
View file @
40ac2064
...
...
@@ -365,6 +365,40 @@ def test_strided_slice():
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
],
[
4
,
4
,
3
])
def
verify_take
(
src_shape
,
indices_src
,
axis
=
None
):
src_dtype
=
"float32"
indices_dtype
=
"int32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
a
=
sym
.
Variable
(
"a"
)
indices
=
sym
.
Variable
(
"indices"
)
y
=
sym
.
take
(
a
,
indices
,
axis
=
axis
)
for
target
,
ctx
in
ctx_list
():
# set input
shape_dict
=
{
"a"
:
src_shape
,
"indices"
:
indices_src
.
shape
}
type_dict
=
{
"a"
:
src_dtype
,
"indices"
:
indices_dtype
}
graph
,
lib
,
_
=
nnvm
.
compiler
.
build
(
y
,
target
,
shape
=
shape_dict
,
dtype
=
type_dict
)
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
shape_size
=
1
for
i
in
range
(
len
(
src_shape
)):
shape_size
=
shape_size
*
src_shape
[
i
]
a_src
=
np
.
arange
(
shape_size
,
dtype
=
src_dtype
)
.
reshape
((
src_shape
))
out_np
=
np
.
take
(
a_src
,
indices_src
,
axis
=
axis
)
m
.
run
(
a
=
a_src
,
indices
=
indices_src
)
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
out_np
.
shape
,
dtype
=
src_dtype
))
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
out_np
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_take
():
verify_take
((
4
,),
[
1
])
verify_take
((
4
,),
[[
0
,
1
,
2
,
3
]])
verify_take
((
3
,
3
,
3
),
[[
11
,
25
]])
verify_take
((
4
,),
[[
0
,
1
],[
2
,
3
]])
verify_take
((
4
,),
[
1
],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
0
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
def
verify_squeeze
(
dshape
,
axis
):
x
=
sym
.
Variable
(
"x"
)
if
axis
:
...
...
@@ -481,6 +515,7 @@ if __name__ == "__main__":
test_softmax
()
test_squeeze
()
test_pad
()
test_take
()
test_lrn
()
test_l2_normalize
()
test_strided_slice
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment