Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
add1f90e
Commit
add1f90e
authored
Oct 31, 2018
by
Haichen Shen
Committed by
Tianqi Chen
Oct 31, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM/TOPI][OP] gather_nd (#2041)
parent
2005f852
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
333 additions
and
8 deletions
+333
-8
docs/api/python/topi.rst
+2
-0
docs/nnvm_top.rst
+2
-0
nnvm/python/nnvm/frontend/mxnet.py
+1
-1
nnvm/python/nnvm/top/transform.py
+4
-0
nnvm/src/top/tensor/transform.cc
+109
-7
nnvm/tests/python/compiler/test_top_level1.py
+31
-0
nnvm/tests/python/unittest/test_infer_shape.py
+21
-0
topi/include/topi/transform.h
+54
-0
topi/python/topi/testing/__init__.py
+1
-0
topi/python/topi/testing/gather_nd_python.py
+36
-0
topi/python/topi/transform.py
+18
-0
topi/src/topi.cc
+5
-0
topi/tests/python/test_topi_transform.py
+49
-0
No files found.
docs/api/python/topi.rst
View file @
add1f90e
...
...
@@ -30,6 +30,7 @@ List of operators
topi.concatenate
topi.split
topi.take
topi.gather_nd
topi.full
topi.full_like
topi.nn.relu
...
...
@@ -103,6 +104,7 @@ topi
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.max
...
...
docs/nnvm_top.rst
View file @
add1f90e
...
...
@@ -61,6 +61,7 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.flip
nnvm.symbol.lrn
nnvm.symbol.where
nnvm.symbol.gather_nd
**Level 2: Convolutions**
...
...
@@ -197,6 +198,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.flip
.. autofunction:: nnvm.symbol.lrn
.. autofunction:: nnvm.symbol.where
.. autofunction:: nnvm.symbol.gather_nd
.. autofunction:: nnvm.symbol.conv2d
.. autofunction:: nnvm.symbol.conv2d_transpose
...
...
nnvm/python/nnvm/frontend/mxnet.py
View file @
add1f90e
...
...
@@ -290,7 +290,7 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'elemwise_div'
,
'elemwise_mul'
,
'elemwise_sub'
,
'exp'
,
'flatten'
,
'log'
,
'log_softmax'
,
'max'
,
'min'
,
'negative'
,
'ones_like'
,
'relu'
,
'sigmoid'
,
'slice_like'
,
'softmax'
,
'sum'
,
'tanh'
,
'transpose'
,
'zeros_like'
]
'sum'
,
'tanh'
,
'transpose'
,
'zeros_like'
,
'gather_nd'
]
_convert_map
=
{
'_copy'
:
_rename
(
'copy'
),
...
...
nnvm/python/nnvm/top/transform.py
View file @
add1f90e
...
...
@@ -86,3 +86,7 @@ reg.register_schedule("slice_like", _fschedule_injective)
# where
reg
.
register_pattern
(
"where"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"where"
,
_fschedule_injective
)
# gather_nd
reg
.
register_pattern
(
"gather_nd"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"gather_nd"
,
_fschedule_injective
)
nnvm/src/top/tensor/transform.cc
View file @
add1f90e
...
...
@@ -1003,7 +1003,7 @@ Examples::
[ 3, 4]]
flip(x) = [[ 3., 4.],
[ 1., 2.]]
[ 1., 2.]]
x = [[[ 1., 2.],
[ 3., 4.]],
...
...
@@ -1012,16 +1012,16 @@ Examples::
[ 7., 8.]]]
flip(x) = [[[ 5., 6.],
[ 7., 8.]],
[ 7., 8.]],
[[ 1., 2.],
[ 3., 4.]]]
[[ 1., 2.],
[ 3., 4.]]]
flip(x, axis=1) = [[[ 3., 4.],
[ 1., 2.]],
[ 1., 2.]],
[[ 7., 8.],
[ 5., 6.]]]
[[ 7., 8.],
[ 5., 6.]]]
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"data"
,
"Tensor"
,
"Source input"
)
.
add_arguments
(
FlipParam
::
__FIELDS__
())
...
...
@@ -1353,5 +1353,107 @@ Examples::
})
.
set_support_level
(
4
);
// gather_nd
inline
bool
GatherNDInferShape
(
const
nnvm
::
NodeAttrs
&
attrs
,
std
::
vector
<
TShape
>*
in_attrs
,
std
::
vector
<
TShape
>*
out_attrs
)
{
CHECK_EQ
(
in_attrs
->
size
(),
2U
);
CHECK_EQ
(
out_attrs
->
size
(),
1U
);
const
TShape
&
data_shape
=
in_attrs
->
at
(
0
);
const
TShape
&
indices_shape
=
in_attrs
->
at
(
1
);
CHECK_GT
(
indices_shape
.
ndim
(),
1
)
<<
"indices must have at least 2 dimensions"
;
CHECK_LE
(
indices_shape
[
0
],
data_shape
.
ndim
())
<<
"dim 0 of indices must be no more than rank of data"
;
std
::
vector
<
dim_t
>
oshape
;
for
(
size_t
i
=
1
;
i
<
indices_shape
.
ndim
();
++
i
)
{
oshape
.
push_back
(
indices_shape
[
i
]);
}
for
(
size_t
i
=
indices_shape
[
0
];
i
<
data_shape
.
ndim
();
++
i
)
{
oshape
.
push_back
(
data_shape
[
i
]);
}
if
(
oshape
.
size
()
==
0
)
{
oshape
.
push_back
(
1
);
}
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_attrs
,
0
,
TShape
(
oshape
.
begin
(),
oshape
.
end
()));
return
true
;
}
inline
bool
GatherNDInferType
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
int
>
*
in_attrs
,
std
::
vector
<
int
>
*
out_attrs
)
{
CHECK_EQ
(
in_attrs
->
size
(),
2U
);
CHECK_EQ
(
out_attrs
->
size
(),
1U
);
NNVM_ASSIGN_OUTPUT_TYPE
(
attrs
,
*
out_attrs
,
0
,
(
*
in_attrs
)[
0
]);
return
true
;
}
inline
bool
GatherNDCorrectLayout
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
Layout
>
*
ilayouts
,
const
std
::
vector
<
Layout
>
*
last_ilayouts
,
std
::
vector
<
Layout
>
*
olayouts
)
{
CHECK_EQ
(
ilayouts
->
size
(),
last_ilayouts
->
size
());
CHECK_EQ
(
olayouts
->
size
(),
1U
);
for
(
size_t
i
=
0
;
i
<
ilayouts
->
size
();
++
i
)
{
const
Layout
&
input
=
last_ilayouts
->
at
(
i
).
defined
()
?
last_ilayouts
->
at
(
i
)
:
ilayouts
->
at
(
i
);
NNVM_ASSIGN_LAYOUT
(
*
ilayouts
,
i
,
input
);
}
return
true
;
}
NNVM_REGISTER_OP
(
gather_nd
)
.
describe
(
R"code(
Gather elements or slices from ``data`` into a tensor specified by ``indices``.
The shape of output tensor is inferred from ``indices``. Given ``data`` with
shape ``(X0, X1, ..., X_{N-1})`` and ``indices`` with shape ``(Y_0, ...,
Y_{M-1})``, the output will have shape ``(Y_1, ..., Y_{M-1}, X_{Y_0}, ...,
X_{N-1})`` when ``Y_0 < N``, or ``(Y_1, ..., Y_{M-1})`` when ``Y_0 == N``. The
operator is invalid when ``Y_0 > N``.
The element in output is defined as follows::
output[y_1, ..., y_{M-1}, x_{Y_0}, ..., x_{N-1}] = data[indices[0, y_1, ..., y_{M-1}],
...,
indices[Y_0-1, y_1, ..., y_{M-1}],
x_{Y_0}, ..., x_{N-1}]
Examples::
data = [[0, 1], [2, 3]]
indices = [[1], [0]]
gather_nd(data, indices) = [2]
data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
gather_nd(data, indices) = [2, 3, 0]
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
gather_nd(data, indices) = [[3, 4], [5, 6]]
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
add_argument
(
"indices"
,
"Tensor"
,
"Indices of data"
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
GatherNDInferShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
GatherNDInferType
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
GatherNDCorrectLayout
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
return
Array
<
Tensor
>
{
topi
::
gather_nd
(
inputs
[
0
],
inputs
[
1
])
};
})
.
set_attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
NodeAttrs
&
attrs
)
{
return
std
::
vector
<
std
::
string
>
{
"data"
,
"indices"
};
})
.
set_support_level
(
3
);
}
// namespace top
}
// namespace nnvm
nnvm/tests/python/compiler/test_top_level1.py
View file @
add1f90e
...
...
@@ -533,6 +533,36 @@ def test_l2_normalize():
verify_l2_normalize
((
1
,
3
,
20
,
20
),
0.001
,
(
1
,))
verify_l2_normalize
((
1
,
3
,
20
,
20
),
0.001
,
(
1
,
2
))
def
verify_gather_nd
(
src_shape
,
indices_src
):
src_dtype
=
"float32"
indices_dtype
=
"int32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
a
=
sym
.
Variable
(
"a"
,
shape
=
src_shape
)
indices
=
sym
.
Variable
(
"indices"
,
shape
=
indices_src
.
shape
)
y
=
sym
.
gather_nd
(
a
,
indices
)
def
forward
(
a
,
indices
):
return
topi
.
testing
.
gather_nd_python
(
a
,
indices
)
a_src
=
np
.
arange
(
np
.
prod
(
src_shape
),
dtype
=
src_dtype
)
.
reshape
(
src_shape
)
check_function
(
y
,
forward
,
dtype
=
{
'a'
:
src_dtype
,
'indices'
:
indices_dtype
},
values
=
{
'a'
:
a_src
,
'indices'
:
indices_src
})
def
test_gather_nd
():
verify_gather_nd
((
4
,),
[[
1
]])
verify_gather_nd
((
4
,),
[[
1
,
3
,
2
]])
verify_gather_nd
((
2
,
3
),
[[
1
]])
verify_gather_nd
((
2
,
3
),
[[
1
],
[
0
]])
verify_gather_nd
((
2
,
3
),
[[
1
,
0
],
[
0
,
2
]])
verify_gather_nd
((
2
,
3
,
4
),
[[
1
,
0
],
[
0
,
2
]])
verify_gather_nd
((
2
,
3
,
4
),
[[
1
,
0
],
[
0
,
2
],
[
3
,
1
]])
verify_gather_nd
((
2
,
3
,
4
),
[[[
1
,
0
],
[
0
,
1
]],
[[
0
,
2
],
[
1
,
2
]],
[[
3
,
1
],
[
0
,
2
]]])
verify_gather_nd
((
2
,
3
,
4
,
5
),
[[
1
,
0
],
[
0
,
2
]])
verify_gather_nd
((
2
,
3
,
4
,
5
),
[[
1
,
0
],
[
2
,
1
],
[
3
,
2
],
[
4
,
2
]])
if
__name__
==
"__main__"
:
test_check_function
()
test_split
()
...
...
@@ -556,3 +586,4 @@ if __name__ == "__main__":
test_lrn
()
test_l2_normalize
()
test_strided_slice
()
test_gather_nd
()
nnvm/tests/python/unittest/test_infer_shape.py
View file @
add1f90e
...
...
@@ -356,6 +356,26 @@ def test_reduce():
check
((
4
,
5
,
10
),
(
1
,
5
,
1
),
axis
=
(
0
,
2
),
keepdims
=
True
)
def
test_gather_nd
():
def
check
(
data_shape
,
indices_shape
,
out_shape
):
x
=
sym
.
Variable
(
"x"
,
shape
=
data_shape
)
indices
=
sym
.
Variable
(
"indices"
,
shape
=
indices_shape
)
y
=
sym
.
gather_nd
(
x
,
indices
,
name
=
"y"
)
sdict
=
infer_shape
(
y
)
assert
(
tuple
(
sdict
[
"y"
][
0
])
==
tuple
(
out_shape
))
check
((
4
,),
(
1
,
1
),
(
1
,))
check
((
4
,),
(
1
,
3
),
(
3
,))
check
((
2
,
3
),
(
1
,
1
),
(
1
,
3
))
check
((
2
,
3
),
(
2
,
1
),
(
1
,))
check
((
2
,
3
),
(
2
,
5
,
6
),
(
5
,
6
))
check
((
2
,
3
,
4
),
(
1
,
1
),
(
1
,
3
,
4
))
check
((
2
,
3
,
4
),
(
2
,
1
),
(
1
,
4
))
check
((
2
,
3
,
4
),
(
2
,
5
),
(
5
,
4
))
check
((
2
,
3
,
4
),
(
2
,
5
,
6
),
(
5
,
6
,
4
))
check
((
2
,
3
,
4
,
5
),
(
2
,
6
,
7
),
(
6
,
7
,
4
,
5
))
if
__name__
==
"__main__"
:
test_conv2d_packed
()
test_expand_dims
()
...
...
@@ -376,3 +396,4 @@ if __name__ == "__main__":
test_transpose
()
test_prelu
()
test_squeeze
()
test_gather_nd
()
topi/include/topi/transform.h
View file @
add1f90e
...
...
@@ -640,6 +640,60 @@ inline Tensor where(const Tensor& condition,
}
/*!
* \brief Gather elements from a n-dimension array.
*
* \param data The source array.
* \param indices The indices of the values to extract.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the gather_nd operation
*/
inline
Tensor
gather_nd
(
const
Tensor
&
data
,
const
Tensor
&
indices
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
size_t
ndim_d
=
data
->
shape
.
size
();
size_t
ndim_i
=
indices
->
shape
.
size
();
CHECK_GT
(
ndim_i
,
1
)
<<
"indices tensor must have at least 2 dimensions"
;
size_t
indices_dim0
=
static_cast
<
size_t
>
(
GetConstInt
(
indices
->
shape
[
0
]));
CHECK_LE
(
indices_dim0
,
ndim_d
)
<<
"dim 0 of indices tensor must be no more "
<<
"than dimensions of data tensor"
;
Array
<
Expr
>
out_shape
;
for
(
size_t
i
=
1
;
i
<
ndim_i
;
++
i
)
{
out_shape
.
push_back
(
indices
->
shape
[
i
]);
}
for
(
size_t
i
=
indices_dim0
;
i
<
ndim_d
;
++
i
)
{
out_shape
.
push_back
(
data
->
shape
[
i
]);
}
if
(
out_shape
.
size
()
==
0
)
{
out_shape
.
push_back
(
make_const
(
Int
(
32
),
1
));
}
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
out_index
)
{
Array
<
Expr
>
indices_position
;
indices_position
.
push_back
(
0
);
for
(
size_t
i
=
0
;
i
<
ndim_i
-
1
;
++
i
)
{
indices_position
.
push_back
(
out_index
[
i
]);
}
Array
<
Expr
>
real_indices
;
for
(
size_t
i
=
0
;
i
<
indices_dim0
;
++
i
)
{
indices_position
.
Set
(
0
,
make_const
(
Int
(
32
),
i
));
if
(
indices
->
dtype
.
is_int
())
{
real_indices
.
push_back
(
indices
(
indices_position
));
}
else
{
real_indices
.
push_back
(
tvm
::
cast
(
tvm
::
Int
(
32
),
indices
(
indices_position
)));
}
}
for
(
size_t
i
=
ndim_i
-
1
;
i
<
out_index
.
size
();
++
i
)
{
real_indices
.
push_back
(
out_index
[
i
]);
}
return
data
(
real_indices
);
},
name
,
tag
);
}
/*!
* \brief Creates an operation that calculates a matrix multiplication
* (row-major notation):
* A(i, k) * B(k, j), if trans_a == trans_b
...
...
topi/python/topi/testing/__init__.py
View file @
add1f90e
...
...
@@ -18,3 +18,4 @@ from .region_python import region_python
from
.shortcut_python
import
shortcut_python
from
.lrn_python
import
lrn_python
from
.l2_normalize_python
import
l2_normalize_python
from
.gather_nd_python
import
gather_nd_python
topi/python/topi/testing/gather_nd_python.py
0 → 100644
View file @
add1f90e
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""gather_nd in python"""
import
numpy
as
np
def
gather_nd_python
(
a_np
,
indices_np
):
""" Python version of GatherND operator
Parameters
----------
a_np : numpy.ndarray
Numpy array
indices_np : numpy.ndarray
Numpy array
Returns
-------
b_np : numpy.ndarray
Numpy array
"""
a_shape
=
a_np
.
shape
indices_np
=
indices_np
.
astype
(
'int32'
)
indices_shape
=
indices_np
.
shape
assert
len
(
indices_shape
)
>
1
assert
indices_shape
[
0
]
<=
len
(
a_shape
)
b_shape
=
list
(
indices_shape
[
1
:])
for
i
in
range
(
indices_shape
[
0
],
len
(
a_shape
)):
b_shape
.
append
(
a_shape
[
i
])
b_np
=
np
.
zeros
(
b_shape
)
for
idx
in
np
.
ndindex
(
*
indices_shape
[
1
:]):
a_idx
=
[]
for
i
in
range
(
indices_shape
[
0
]):
indices_pos
=
tuple
([
i
]
+
list
(
idx
))
a_idx
.
append
(
indices_np
[
indices_pos
])
b_np
[
idx
]
=
a_np
[
tuple
(
a_idx
)]
return
b_np
topi/python/topi/transform.py
View file @
add1f90e
...
...
@@ -240,6 +240,24 @@ def take(a, indices, axis=None):
return
cpp
.
take
(
a
,
indices
,
int
(
axis
))
def
gather_nd
(
a
,
indices
):
"""Gather elements from a n-dimension array..
Parameters
----------
a : tvm.Tensor
The source array.
indices : tvm.Tensor
The indices of the values to extract.
Returns
-------
ret : tvm.Tensor
"""
return
cpp
.
gather_nd
(
a
,
indices
)
def
matmul
(
a
,
b
,
transp_a
=
False
,
transp_b
=
False
):
"""
Creates an operation that calculates a matrix multiplication (row-major notation):
...
...
topi/src/topi.cc
View file @
add1f90e
...
...
@@ -291,6 +291,11 @@ TVM_REGISTER_GLOBAL("topi.where")
*
rv
=
where
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_GLOBAL
(
"topi.gather_nd"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
gather_nd
(
args
[
0
],
args
[
1
]);
});
TVM_REGISTER_GLOBAL
(
"topi.matmul"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
switch
(
args
.
size
()
)
{
...
...
topi/tests/python/test_topi_transform.py
View file @
add1f90e
...
...
@@ -2,6 +2,7 @@
import
numpy
as
np
import
tvm
import
topi
import
topi.testing
from
common
import
get_all_backend
...
...
@@ -275,6 +276,38 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
for
device
in
[
"llvm"
,
"opencl"
,
"sdaccel"
,
"aocl_sw_emu"
]:
check_device
(
device
)
def
verify_gather_nd
(
src_shape
,
indices_src
,
indices_dtype
):
src_dtype
=
"float32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
A
=
tvm
.
placeholder
(
shape
=
src_shape
,
dtype
=
src_dtype
,
name
=
"A"
)
indices
=
tvm
.
placeholder
(
shape
=
indices_src
.
shape
,
dtype
=
indices_dtype
,
name
=
"indices"
)
out_tensor
=
topi
.
gather_nd
(
a
=
A
,
indices
=
indices
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
out_tensor
)
func
=
tvm
.
build
(
s
,
[
A
,
indices
,
out_tensor
]
,
device
,
name
=
"take"
)
shape_size
=
1
for
i
in
range
(
len
(
src_shape
)):
shape_size
=
shape_size
*
src_shape
[
i
]
data_npy
=
np
.
arange
(
shape_size
,
dtype
=
src_dtype
)
.
reshape
((
src_shape
))
out_npys
=
topi
.
testing
.
gather_nd_python
(
data_npy
,
indices_src
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
indices_nd
=
tvm
.
nd
.
array
(
indices_src
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npys
.
shape
,
ctx
=
ctx
,
dtype
=
src_dtype
)
func
(
data_nd
,
indices_nd
,
out_nd
)
tvm
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npys
)
for
device
in
get_all_backend
():
check_device
(
device
)
def
test_strided_slice
():
verify_strided_slice
((
3
,
4
,
3
),
[
0
,
0
,
0
],
[
4
,
-
5
,
4
],
[
1
,
-
1
,
2
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
[
2
,
1
,
1
])
...
...
@@ -363,6 +396,21 @@ def test_take():
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
def
test_gather_nd
():
for
indices_dtype
in
[
'int32'
,
'float32'
]:
verify_gather_nd
((
4
,),
[[
1.8
]],
indices_dtype
)
verify_gather_nd
((
4
,),
[[
1
,
3
,
2
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
),
[[
1
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
),
[[
1
],
[
0
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
),
[[
1
,
0
],
[
0
,
2
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
,
4
),
[[
1
,
0
],
[
0
,
2
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
,
4
),
[[
1
,
0
],
[
0
,
2
],
[
3
,
1
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
,
4
),
[[[
1
,
0
],
[
0
,
1
]],
[[
0
,
2
],
[
1
,
2
]],
[[
3
,
1
],
[
0
,
2
]]],
indices_dtype
)
verify_gather_nd
((
2
,
3
,
4
,
5
),
[[
1
,
0
],
[
0
,
2
]],
indices_dtype
)
verify_gather_nd
((
2
,
3
,
4
,
5
),
[[
1
,
0
],
[
2
,
1
],
[
3
,
2
],
[
4
,
2
]],
indices_dtype
)
if
__name__
==
"__main__"
:
test_concatenate
()
test_tranpose
()
...
...
@@ -374,3 +422,4 @@ if __name__ == "__main__":
test_expand_like
()
test_take
()
test_strided_slice
()
test_gather_nd
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment