Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1f2c8156
Unverified
Commit
1f2c8156
authored
Nov 13, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][OP] strided_slice (#2094)
parent
4369b7f6
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
371 additions
and
37 deletions
+371
-37
docs/langref/relay_op.rst
+2
-0
include/tvm/relay/attrs/transform.h
+15
-0
nnvm/src/top/tensor/transform.cc
+22
-8
python/tvm/_ffi/node_generic.py
+2
-0
python/tvm/relay/op/__init__.py
+1
-0
python/tvm/relay/op/_transform.py
+8
-0
python/tvm/relay/op/transform.py
+27
-0
src/api/api_lang.cc
+5
-1
src/relay/ir/text_printer.cc
+5
-1
src/relay/op/tensor/transform.cc
+168
-0
tests/python/relay/test_op_level4.py
+37
-1
topi/include/topi/transform.h
+38
-17
topi/python/topi/testing/__init__.py
+1
-0
topi/python/topi/testing/strided_slice_python.py
+32
-0
topi/tests/python/test_topi_transform.py
+8
-9
No files found.
docs/langref/relay_op.rst
View file @
1f2c8156
...
...
@@ -123,6 +123,7 @@ This level enables additional math and transform operators.
tvm.relay.min
tvm.relay.mean
tvm.relay.prod
tvm.relay.strided_slice
**Level 5: Vision/Image Operators**
...
...
@@ -227,6 +228,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod
.. autofunction:: tvm.relay.strided_slice
...
...
include/tvm/relay/attrs/transform.h
View file @
1f2c8156
...
...
@@ -123,6 +123,21 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
}
};
/*! \brief Attributes for StridedSlice operator */
struct
StridedSliceAttrs
:
public
tvm
::
AttrsNode
<
StridedSliceAttrs
>
{
Array
<
Integer
>
begin
;
Array
<
Integer
>
end
;
Array
<
Integer
>
strides
;
TVM_DECLARE_ATTRS
(
StridedSliceAttrs
,
"relay.attrs.StridedSliceAttrs"
)
{
TVM_ATTR_FIELD
(
begin
)
.
describe
(
"Indices for begin of slice, begin index is also inclusive"
);
TVM_ATTR_FIELD
(
end
)
.
describe
(
"Indices for end of slice, end index is also inclusive"
);
TVM_ATTR_FIELD
(
strides
).
set_default
(
Array
<
Integer
>
({}))
.
describe
(
"Stride values of the slice"
);
}
};
}
// namespace relay
}
// namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
nnvm/src/top/tensor/transform.cc
View file @
1f2c8156
...
...
@@ -980,23 +980,25 @@ Examples::
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
const
StridedSliceParam
&
param
=
nnvm
::
get
<
StridedSliceParam
>
(
attrs
.
parsed
);
Array
<
Exp
r
>
begin
;
Array
<
Exp
r
>
end
;
Array
<
Exp
r
>
stride
;
Array
<
Intege
r
>
begin
;
Array
<
Intege
r
>
end
;
Array
<
Intege
r
>
stride
;
for
(
int64_t
i
:
param
.
begin
)
{
begin
.
push_back
(
tvm
::
make_const
(
tvm
::
Int
(
32
),
i
));
begin
.
push_back
(
static_cast
<
int
>
(
i
));
}
for
(
int64_t
i
:
param
.
end
)
{
end
.
push_back
(
tvm
::
make_const
(
tvm
::
Int
(
32
),
i
));
end
.
push_back
(
static_cast
<
int
>
(
i
));
}
for
(
int64_t
i
:
param
.
stride
)
{
stride
.
push_back
(
tvm
::
make_const
(
tvm
::
Int
(
32
),
i
));
stride
.
push_back
(
static_cast
<
int
>
(
i
));
}
return
Array
<
Tensor
>
{
topi
::
strided_slice
(
inputs
[
0
],
begin
,
end
,
stride
)
};
return
Array
<
Tensor
>
{
topi
::
strided_slice
(
inputs
[
0
],
begin
,
end
,
stride
)
};
})
.
set_support_level
(
1
);
...
...
@@ -1210,6 +1212,15 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
return
true
;
}
// Adapter function to make int array.
Array
<
Integer
>
GetIntArray
(
Array
<
Expr
>
arr
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
CHECK
(
!
arr
[
i
].
defined
()
||
arr
[
i
].
as
<
IntImm
>
())
<<
"Expect an int array"
;
}
return
Array
<
Integer
>
(
arr
.
node_
);
}
NNVM_REGISTER_OP
(
slice_like
)
.
describe
(
R"code(Slice the first input respect to the second input.
)code"
NNVM_ADD_FILELINE
)
...
...
@@ -1261,7 +1272,10 @@ NNVM_REGISTER_OP(slice_like)
}
}
return
Array
<
Tensor
>
{
topi
::
strided_slice
(
inputs
[
0
],
begin_idx
,
end_idx
,
strides
)
topi
::
strided_slice
(
inputs
[
0
],
GetIntArray
(
begin_idx
),
GetIntArray
(
end_idx
),
GetIntArray
(
strides
))
};
})
.
set_attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
NodeAttrs
&
attrs
)
{
...
...
python/tvm/_ffi/node_generic.py
View file @
1f2c8156
...
...
@@ -56,6 +56,8 @@ def convert_to_node(value):
return
_api_internal
.
_Map
(
*
vlist
)
elif
isinstance
(
value
,
NodeGeneric
):
return
value
.
asnode
()
elif
value
is
None
:
return
None
else
:
raise
ValueError
(
"don't know how to convert type
%
s to node"
%
type
(
value
))
...
...
python/tvm/relay/op/__init__.py
View file @
1f2c8156
...
...
@@ -13,6 +13,7 @@ from . import vision
# operator registry
from
.
import
_tensor
from
.
import
_transform
from
..expr
import
Expr
from
..base
import
register_relay_node
...
...
python/tvm/relay/op/_transform.py
0 → 100644
View file @
1f2c8156
#pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from
__future__
import
absolute_import
from
.
import
op
as
_reg
from
.op
import
schedule_injective
# strided_slice
_reg
.
register_schedule
(
"strided_slice"
,
schedule_injective
)
python/tvm/relay/op/transform.py
View file @
1f2c8156
...
...
@@ -334,3 +334,30 @@ def split(data, indices_or_sections, axis=0):
else
:
ret_size
=
len
(
indices_or_sections
)
+
1
return
TupleWrapper
(
_make
.
split
(
data
,
indices_or_sections
,
axis
),
ret_size
)
def
strided_slice
(
data
,
begin
,
end
,
strides
=
None
):
"""Strided slice of an array..
Parameters
----------
data : relay.Expr
The source array to be sliced.
begin: list of int
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
Returns
-------
ret : relay.Expr
The computed result.
"""
strides
=
strides
or
[]
return
_make
.
strided_slice
(
data
,
list
(
begin
),
list
(
end
),
list
(
strides
))
src/api/api_lang.cc
View file @
1f2c8156
...
...
@@ -47,7 +47,11 @@ TVM_REGISTER_API("_Array")
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
vector
<
NodePtr
<
Node
>
>
data
;
for
(
int
i
=
0
;
i
<
args
.
size
();
++
i
)
{
data
.
push_back
(
args
[
i
].
node_sptr
());
if
(
args
[
i
].
type_code
()
!=
kNull
)
{
data
.
push_back
(
args
[
i
].
node_sptr
());
}
else
{
data
.
push_back
(
NodePtr
<
Node
>
(
nullptr
));
}
}
auto
node
=
make_node
<
ArrayNode
>
();
node
->
data
=
std
::
move
(
data
);
...
...
src/relay/ir/text_printer.cc
View file @
1f2c8156
...
...
@@ -403,7 +403,11 @@ class TextPrinter :
* \param os The output type.
*/
void
PrintAttr
(
const
NodeRef
&
value
,
std
::
ostream
&
os
)
{
// NOLINT(*)
this
->
VisitAttr
(
value
,
os
);
if
(
value
.
defined
())
{
this
->
VisitAttr
(
value
,
os
);
}
else
{
os
<<
"None"
;
}
}
//------------------------------------
// Overload of Attr printing functions
...
...
src/relay/op/tensor/transform.cc
View file @
1f2c8156
...
...
@@ -7,6 +7,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <topi/transform.h>
#include <vector>
#include "../op_common.h"
...
...
@@ -890,6 +891,173 @@ RELAY_REGISTER_OP("broadcast_to_like")
.
set_support_level
(
10
)
.
add_type_rel
(
"BroadCastToLike"
,
BroadCastToLikeRel
);
// strided_slice
TVM_REGISTER_NODE_TYPE
(
StridedSliceAttrs
);
bool
StridedSliceRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
2
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
return
false
;
const
StridedSliceAttrs
*
param
=
attrs
.
as
<
StridedSliceAttrs
>
();
CHECK
(
param
!=
nullptr
);
auto
dshape
=
data
->
shape
;
auto
num_axis
=
dshape
.
size
();
std
::
vector
<
int64_t
>
stride_vec
;
for
(
Integer
i
:
param
->
strides
)
{
CHECK
(
i
.
defined
());
stride_vec
.
push_back
(
i
->
value
);
}
for
(
size_t
i
=
stride_vec
.
size
();
i
<
num_axis
;
++
i
)
{
stride_vec
.
push_back
(
1
);
}
const
int64_t
max_range
=
std
::
numeric_limits
<
int64_t
>::
max
();
std
::
vector
<
int64_t
>
begin_vec
;
for
(
size_t
i
=
0
;
i
<
param
->
begin
.
size
();
++
i
)
{
if
(
!
param
->
begin
[
i
].
defined
())
{
// value=None
begin_vec
.
push_back
(
stride_vec
[
i
]
>
0
?
0
:
max_range
);
}
else
{
begin_vec
.
push_back
(
param
->
begin
[
i
]
->
value
);
}
}
for
(
size_t
i
=
begin_vec
.
size
();
i
<
num_axis
;
++
i
)
{
begin_vec
.
push_back
(
stride_vec
[
i
]
>
0
?
0
:
max_range
);
}
std
::
vector
<
int64_t
>
end_vec
;
for
(
size_t
i
=
0
;
i
<
param
->
end
.
size
();
++
i
)
{
// allow end to be None
if
(
!
param
->
end
[
i
].
defined
())
{
end_vec
.
push_back
(
stride_vec
[
i
]
<
0
?
0
:
max_range
);
}
else
{
end_vec
.
push_back
(
param
->
end
[
i
]
->
value
);
}
}
for
(
size_t
i
=
end_vec
.
size
();
i
<
num_axis
;
++
i
)
{
end_vec
.
push_back
(
stride_vec
[
i
]
<
0
?
0
:
max_range
);
}
std
::
vector
<
IndexExpr
>
oshape
(
dshape
.
size
());
for
(
size_t
i
=
0
;
i
<
num_axis
;
++
i
)
{
int64_t
stride_v
=
stride_vec
[
i
];
int64_t
begin_v
=
begin_vec
[
i
];
int64_t
end_v
=
end_vec
[
i
];
if
((
stride_v
==
1
&&
begin_v
==
0
&&
end_v
==
max_range
)
||
(
stride_v
==
-
1
&&
begin_v
==
max_range
&&
end_v
==
0
))
{
// Quick path, do not slice this dimension.
oshape
[
i
]
=
dshape
[
i
];
continue
;
}
// Normal path, require the shape to be concrete integer.
// Require concrete integer as symbolic inference of min/max
// can get complicated and not very helpful.
const
int64_t
*
p_dim_size
=
as_const_int
(
dshape
[
i
]);
CHECK
(
p_dim_size
)
<<
"strided_slice requires sliced dimension to be concrete int"
;
int64_t
dim_size
=
p_dim_size
[
0
];
begin_v
=
(
begin_v
<
0
)
?
dim_size
+
begin_v
:
begin_v
;
end_v
=
(
end_v
<
0
)
?
dim_size
+
end_v
:
end_v
;
int64_t
slice_range
,
step
;
if
(
stride_v
<
0
)
{
if
(
end_v
<
-
1
)
end_v
=
-
1
;
CHECK_LT
(
end_v
,
begin_v
)
<<
"strided_slice get empty slice at axis "
<<
i
;
begin_v
=
std
::
min
(
dim_size
-
1
,
begin_v
);
slice_range
=
begin_v
-
end_v
;
step
=
-
stride_v
;
}
else
{
if
(
begin_v
<
0
)
begin_v
=
0
;
CHECK_GE
(
stride_v
,
0
);
CHECK_LT
(
begin_v
,
end_v
)
<<
"strided_slice get empty slice at axis "
<<
i
;
end_v
=
std
::
min
(
dim_size
,
end_v
);
slice_range
=
end_v
-
begin_v
;
step
=
stride_v
;
}
oshape
[
i
]
=
make_const
(
dshape
[
i
].
type
(),
(
slice_range
+
step
-
1
)
/
step
);
}
reporter
->
Assign
(
types
[
1
],
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
));
return
true
;
}
// Positional relay function to create StridedSlice operator used by frontend FFI.
Expr
MakeStridedSlice
(
Expr
data
,
Array
<
Integer
>
begin
,
Array
<
Integer
>
end
,
Array
<
Integer
>
strides
)
{
auto
attrs
=
make_node
<
StridedSliceAttrs
>
();
attrs
->
begin
=
std
::
move
(
begin
);
attrs
->
end
=
std
::
move
(
end
);
attrs
->
strides
=
std
::
move
(
strides
);
static
const
Op
&
op
=
Op
::
Get
(
"strided_slice"
);
return
CallNode
::
make
(
op
,
{
data
},
Attrs
(
attrs
),
{});
}
Array
<
Tensor
>
StridedSliceCompute
(
const
Attrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Type
&
out_type
,
const
Target
&
target
)
{
const
StridedSliceAttrs
*
param
=
attrs
.
as
<
StridedSliceAttrs
>
();
CHECK
(
param
!=
nullptr
);
return
Array
<
Tensor
>
{
topi
::
strided_slice
(
inputs
[
0
],
param
->
begin
,
param
->
end
,
param
->
strides
)
};
}
TVM_REGISTER_API
(
"relay.op._make.strided_slice"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
4
>
(
MakeStridedSlice
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"strided_slice"
)
.
describe
(
R"code(Strided slice of an array.
Examples::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.],
[ 5., 8., 11.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
1
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
set_support_level
(
4
)
.
set_attrs_type_key
(
"relay.attrs.StridedSliceAttrs"
)
.
add_type_rel
(
"StridedSlice"
,
StridedSliceRel
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
StridedSliceCompute
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kInjective
);
// Split
TVM_REGISTER_NODE_TYPE
(
SplitAttrs
);
...
...
tests/python/relay/test_op_level4.py
View file @
1f2c8156
...
...
@@ -2,7 +2,7 @@ import tvm
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay.testing
import
ctx_list
import
topi.testing
def
test_binary_op
():
def
check_binary_op
(
opfunc
,
ref
):
...
...
@@ -142,7 +142,43 @@ def test_reduce_functions():
verify_reduce
(
func
,
(
128
,
24
,
128
),
(
0
,
1
),
True
,
False
,
(
1
,
1
,
128
))
verify_reduce
(
func
,
(
128
,
24
,
128
),
(
0
,
2
),
True
,
False
,
(
1
,
24
,
1
))
def
test_strided_slice
():
def
verify
(
dshape
,
begin
,
end
,
strides
,
output
,
test_ref
=
True
):
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
(
dshape
,
"float32"
))
z
=
relay
.
strided_slice
(
x
,
begin
=
begin
,
end
=
end
,
strides
=
strides
)
func
=
relay
.
Function
([
x
],
z
)
func
=
relay
.
ir_pass
.
infer_type
(
func
)
text
=
func
.
astext
()
assert
"begin="
in
text
assert
"end="
in
text
if
output
:
assert
func
.
body
.
checked_type
==
relay
.
ty
.
TensorType
(
output
,
"float32"
)
if
not
test_ref
:
return
x_data
=
np
.
random
.
uniform
(
size
=
dshape
)
.
astype
(
"float32"
)
ref_res
=
topi
.
testing
.
strided_slice_python
(
x_data
,
begin
,
end
,
strides
)
for
target
,
ctx
in
ctx_list
():
intrp
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
op_res
=
intrp
.
evaluate
(
func
)(
x_data
)
tvm
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
)
d1
,
d2
,
d3
,
d4
=
tvm
.
var
(
"d1"
),
tvm
.
var
(
"d2"
),
tvm
.
var
(
"d3"
),
tvm
.
var
(
"d4"
)
verify
((
d1
,
d2
,
3
),
[
None
,
None
,
1
],
[
None
,
None
,
2
],
None
,
(
d1
,
d2
,
1
),
False
)
verify
((
3
,
4
,
3
),
[
0
,
0
,
0
],
[
4
,
-
5
,
4
],
[
1
,
-
1
,
2
],
(
3
,
1
,
2
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
[
2
,
1
,
1
],
(
1
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
4
,
-
5
,
3
],
[
2
,
-
1
,
1
],
(
1
,
4
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
0
,
0
],
[
2
,
2
,
3
],
[
1
,
1
,
2
],
(
1
,
2
,
2
))
verify
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
2
,
-
3
,
3
],
[
1
,
-
1
,
1
],
(
1
,
2
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
None
,
(
2
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
1000
,
3
],
None
,
(
2
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
],
None
,
(
2
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
],
[
4
,
4
,
3
],
None
,
(
2
,
3
,
3
))
if
__name__
==
"__main__"
:
test_strided_slice
()
test_binary_op
()
test_cmp_type
()
test_binary_int_broadcast
()
...
...
topi/include/topi/transform.h
View file @
1f2c8156
...
...
@@ -10,6 +10,7 @@
#include <vector>
#include <iterator>
#include <algorithm>
#include <limits>
#include "topi/tags.h"
#include "topi/detail/ravel_unravel.h"
...
...
@@ -403,31 +404,51 @@ inline Array<Tensor> split(const Tensor& x,
* \return A Tensor whose op member is the split operation
*/
inline
Tensor
strided_slice
(
const
Tensor
&
x
,
const
Array
<
Exp
r
>&
begin
,
const
Array
<
Exp
r
>&
end
,
const
Array
<
Exp
r
>&
strides
,
const
Array
<
Intege
r
>&
begin
,
const
Array
<
Intege
r
>&
end
,
const
Array
<
Intege
r
>&
strides
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
size_t
src_tensor_dim
=
static_cast
<
size_t
>
(
x
->
shape
.
size
());
std
::
vector
<
int64_t
>
begin_vec
=
GetConstInt64Values
(
begin
,
"begin"
);
std
::
vector
<
int64_t
>
end_vec
=
GetConstInt64Values
(
end
,
"end"
);
std
::
vector
<
int64_t
>
stride_vec
=
GetConstInt64Values
(
strides
,
"strides"
);
// in case user has not provided begin indices for all the axes,
// then inflate it with default value = 0
for
(
size_t
i
=
begin_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
begin_vec
.
push_back
(
0
);
}
// in case user has not provided end indices for all the axes,
// then inflate it with default value = input_tensor.shape[axis]
for
(
size_t
i
=
end_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
end_vec
.
push_back
(
GetConstInt
(
x
->
shape
[
i
]));
// Setup the ranges.
// NOTE: this code duplicates the shape inference logic relay.op
// Consider to refactor in the future.
std
::
vector
<
int64_t
>
stride_vec
;
for
(
Integer
i
:
strides
)
{
CHECK
(
i
.
defined
());
stride_vec
.
push_back
(
i
->
value
);
}
// in case user has not provided stride values,
// then inflate it with default value = 1
for
(
size_t
i
=
stride_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
stride_vec
.
push_back
(
1
);
}
const
int64_t
max_range
=
std
::
numeric_limits
<
int64_t
>::
max
();
std
::
vector
<
int64_t
>
begin_vec
;
for
(
size_t
i
=
0
;
i
<
begin
.
size
();
++
i
)
{
if
(
!
begin
[
i
].
defined
())
{
// value=None
begin_vec
.
push_back
(
stride_vec
[
i
]
>
0
?
0
:
max_range
);
}
else
{
begin_vec
.
push_back
(
begin
[
i
]
->
value
);
}
}
for
(
size_t
i
=
begin_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
begin_vec
.
push_back
(
stride_vec
[
i
]
>
0
?
0
:
max_range
);
}
std
::
vector
<
int64_t
>
end_vec
;
for
(
size_t
i
=
0
;
i
<
end
.
size
();
++
i
)
{
// allow end to be None
if
(
!
end
[
i
].
defined
())
{
end_vec
.
push_back
(
stride_vec
[
i
]
<
0
?
0
:
max_range
);
}
else
{
end_vec
.
push_back
(
end
[
i
]
->
value
);
}
}
for
(
size_t
i
=
end_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
end_vec
.
push_back
(
stride_vec
[
i
]
<
0
?
0
:
max_range
);
}
// Compute
Array
<
Expr
>
out_shape
;
Array
<
Expr
>
begin_expr
;
Array
<
Expr
>
strides_expr
;
...
...
topi/python/topi/testing/__init__.py
View file @
1f2c8156
...
...
@@ -19,3 +19,4 @@ from .shortcut_python import shortcut_python
from
.lrn_python
import
lrn_python
from
.l2_normalize_python
import
l2_normalize_python
from
.gather_nd_python
import
gather_nd_python
from
.strided_slice_python
import
strided_slice_python
topi/python/topi/testing/strided_slice_python.py
0 → 100644
View file @
1f2c8156
"""gather_nd in python"""
def
strided_slice_python
(
data
,
begin
,
end
,
strides
):
"""Python version of strided slice operator.
Parameters
----------
data : numpy.ndarray
Input data
begin : list
Begining of the slices.
end : list
End of the slices.
strides : list
The stride of each slice.
Returns
-------
result : numpy.ndarray
The sliced result.
"""
strides
=
[]
if
strides
is
None
else
strides
slices
=
[]
for
i
in
range
(
len
(
data
.
shape
)):
slices
.
append
(
slice
(
begin
[
i
]
if
i
<
len
(
begin
)
else
None
,
end
[
i
]
if
i
<
len
(
end
)
else
None
,
strides
[
i
]
if
i
<
len
(
strides
)
else
None
))
return
data
[
tuple
(
slices
)]
topi/tests/python/test_topi_transform.py
View file @
1f2c8156
...
...
@@ -249,13 +249,11 @@ def verify_take(src_shape, indices_src, axis=None):
for
device
in
[
"llvm"
,
"opencl"
,
"sdaccel"
,
"aocl_sw_emu"
]:
check_device
(
device
)
def
verify_strided_slice
(
in_shape
,
begin
,
end
,
stride
=
None
):
stride
=
stride
if
stride
else
[
1
,
1
,
1
]
def
verify_strided_slice
(
in_shape
,
begin
,
end
,
strides
=
None
):
A
=
tvm
.
placeholder
(
shape
=
in_shape
,
name
=
"A"
)
B
=
topi
.
strided_slice
(
A
,
begin
,
end
,
stride
)
+
1
def
test_forward
(
x
,
begin
,
end
,
stride
):
return
x
[
begin
[
0
]:
end
[
0
]:
stride
[
0
],
begin
[
1
]:
end
[
1
]:
stride
[
1
],
begin
[
2
]:
end
[
2
]:
stride
[
2
]]
+
1
strides
=
[
1
,
1
,
1
]
if
strides
is
None
else
strides
B
=
topi
.
strided_slice
(
A
,
begin
,
end
,
strides
)
+
1
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
...
...
@@ -267,7 +265,8 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"stride_slice"
)
x_np
=
np
.
random
.
uniform
(
size
=
in_shape
)
.
astype
(
A
.
dtype
)
out_npy
=
test_forward
(
x_np
,
begin
,
end
,
stride
)
out_npy
=
topi
.
testing
.
strided_slice_python
(
x_np
,
begin
,
end
,
strides
)
+
1
data_nd
=
tvm
.
nd
.
array
(
x_np
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npy
.
shape
,
ctx
=
ctx
,
dtype
=
A
.
dtype
)
foo
(
data_nd
,
out_nd
)
...
...
@@ -298,7 +297,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
shape_size
=
shape_size
*
src_shape
[
i
]
data_npy
=
np
.
arange
(
shape_size
,
dtype
=
src_dtype
)
.
reshape
((
src_shape
))
out_npys
=
topi
.
testing
.
gather_nd_python
(
data_npy
,
indices_src
)
data_nd
=
tvm
.
nd
.
array
(
data_npy
,
ctx
)
indices_nd
=
tvm
.
nd
.
array
(
indices_src
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npys
.
shape
,
ctx
=
ctx
,
dtype
=
src_dtype
)
...
...
@@ -412,6 +411,7 @@ def test_gather_nd():
indices_dtype
)
if
__name__
==
"__main__"
:
test_strided_slice
()
test_concatenate
()
test_tranpose
()
test_expand_dims
()
...
...
@@ -421,5 +421,4 @@ if __name__ == "__main__":
test_flip
()
test_expand_like
()
test_take
()
test_strided_slice
()
test_gather_nd
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment