Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c1c32758
Commit
c1c32758
authored
Jun 06, 2018
by
Pariksheet Pinjari
Committed by
Tianqi Chen
Jun 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI] Slice operator (#1165)
parent
a9313787
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
171 additions
and
0 deletions
+171
-0
docs/api/python/topi.rst
+6
-0
topi/include/topi/detail/constant_utils.h
+18
-0
topi/include/topi/transform.h
+81
-0
topi/python/topi/transform.py
+26
-0
topi/src/topi.cc
+5
-0
topi/tests/python/test_topi_transform.py
+35
-0
No files found.
docs/api/python/topi.rst
View file @
c1c32758
...
...
@@ -17,11 +17,14 @@ List of operators
topi.clip
topi.cast
topi.transpose
topi.flip
topi.strided_slice
topi.expand_dims
topi.reshape
topi.squeeze
topi.concatenate
topi.split
topi.take
topi.full
topi.full_like
topi.greater
...
...
@@ -72,11 +75,14 @@ topi
.. autofunction:: topi.clip
.. autofunction:: topi.cast
.. autofunction:: topi.transpose
.. autofunction:: topi.flip
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
.. autofunction:: topi.squeeze
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
.. autofunction:: topi.full
.. autofunction:: topi.full_like
.. autofunction:: topi.greater
...
...
topi/include/topi/detail/constant_utils.h
View file @
c1c32758
...
...
@@ -67,6 +67,24 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
}
/*!
* \brief Get the value of all the constant integer expressions in the given array
*
* \param exprs The array of expressions to get the values of
* \param var_name The name to be used when logging an error in the event that any
* of the expressions are not constant integers.
*
* \return A vector of the int64_t values
*/
inline
std
::
vector
<
int64_t
>
GetConstInt64Values
(
Array
<
Expr
>
exprs
,
const
std
::
string
&
var_name
)
{
std
::
vector
<
int64_t
>
result
;
for
(
auto
expr
:
exprs
)
{
CHECK
(
IsConstInt
(
expr
))
<<
"All elements of "
<<
var_name
<<
" must be constant integers"
;
result
.
push_back
(
GetConstInt
(
expr
));
}
return
result
;
}
/*!
* \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again
* \note This is stronger equality check than tvm::ir::Equal
*
...
...
topi/include/topi/transform.h
View file @
c1c32758
...
...
@@ -366,6 +366,87 @@ inline Array<Tensor> split(const Tensor& x,
}
/*!
* \brief strided_slice of a tensor
*
* \param x The input tensor
* \param begin The indices to begin with in the slicing
* \param end Indicies indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the split operation
*/
inline
Tensor
strided_slice
(
const
Tensor
&
x
,
const
Array
<
Expr
>&
begin
,
const
Array
<
Expr
>&
end
,
const
Array
<
Expr
>&
strides
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
size_t
src_tensor_dim
=
static_cast
<
size_t
>
(
x
->
shape
.
size
());
std
::
vector
<
int64_t
>
begin_vec
=
GetConstInt64Values
(
begin
,
"begin"
);
std
::
vector
<
int64_t
>
end_vec
=
GetConstInt64Values
(
end
,
"end"
);
std
::
vector
<
int64_t
>
stride_vec
=
GetConstInt64Values
(
strides
,
"strides"
);
// in case user has not provided begin indices for all the axes,
// then inflate it with default value = 0
for
(
size_t
i
=
begin_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
begin_vec
.
push_back
(
0
);
}
// in case user has not provided end indices for all the axes,
// then inflate it with default value = input_tensor.shape[axis]
for
(
size_t
i
=
end_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
end_vec
.
push_back
(
GetConstInt
(
x
->
shape
[
i
]));
}
// in case user has not provided stride values,
// then inflate it with default value = 1
for
(
size_t
i
=
stride_vec
.
size
();
i
<
src_tensor_dim
;
++
i
)
{
stride_vec
.
push_back
(
1
);
}
Array
<
Expr
>
out_shape
;
Array
<
Expr
>
begin_expr
;
Array
<
Expr
>
strides_expr
;
for
(
size_t
i
=
0
;
i
<
src_tensor_dim
;
++
i
)
{
int64_t
begin_range
=
stride_vec
[
i
]
<
0
?
-
1
:
0
;
int64_t
dim_i
=
GetConstInt
(
x
->
shape
[
i
]);
int64_t
end_range
=
stride_vec
[
i
]
<
0
?
dim_i
-
1
:
dim_i
;
// transform negative indices to positive value, clips on the correct range
auto
index_canonicalization
=
[
dim_i
,
begin_range
,
end_range
](
int64_t
index
)
{
if
(
index
<
0
)
{
index
+=
dim_i
;
}
return
std
::
min
(
std
::
max
(
index
,
begin_range
),
end_range
);
};
int64_t
begin_i
=
index_canonicalization
(
begin_vec
[
i
]);
int64_t
end_i
=
index_canonicalization
(
end_vec
[
i
]);
int
interval
=
std
::
abs
(
end_i
-
begin_i
);
int
slice_size
=
static_cast
<
int
>
((
interval
+
std
::
abs
(
stride_vec
[
i
])
-
1
)
/
std
::
abs
(
stride_vec
[
i
]));
CHECK
(
stride_vec
[
i
]
<
0
?
(
end_i
<
begin_i
)
:
(
begin_i
<
end_i
))
<<
": Input [Begin="
<<
begin_vec
[
i
]
<<
", End="
<<
end_vec
[
i
]
<<
"] is invalid for axis="
<<
i
;
begin_expr
.
push_back
(
make_const
(
begin
[
0
].
type
(),
begin_i
));
strides_expr
.
push_back
(
make_const
((
strides
.
size
()
!=
0
?
strides
[
0
].
type
()
:
begin
[
0
].
type
()),
stride_vec
[
i
]));
out_shape
.
push_back
(
slice_size
);
}
return
compute
(
out_shape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
Array
<
Expr
>
real_indices
;
for
(
size_t
i
=
0
;
i
<
src_tensor_dim
;
++
i
)
{
real_indices
.
push_back
(
indices
[
i
]
*
strides_expr
[
i
]
+
begin_expr
[
i
]);
}
return
x
(
real_indices
);
},
name
,
tag
);
}
/*!
* \brief Split a tensor into a number of sub-tensors
*
* \param x The input tensor
...
...
topi/python/topi/transform.py
View file @
c1c32758
...
...
@@ -130,6 +130,32 @@ def flip(a, axis=0):
return
cpp
.
flip
(
a
,
axis
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
strided_slice
(
a
,
begin
,
end
,
strides
=
None
):
"""Slice of an array.
Parameters
----------
a : tvm.Tensor
The tensor to be sliced.
begin: list of int
The indices to begin with in the slicing.
end: list of int
Indicies indicating end of the slice.
strides: list of int, optional
Specifies the stride values, it can be negative
in that case, the input tensor will be reversed
in that particular axis.
Returns
-------
ret : tvm.Tensor
"""
return
cpp
.
strided_slice
(
a
,
begin
,
end
,
strides
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
)
def
reshape
(
a
,
newshape
):
"""Reshape the array
...
...
topi/src/topi.cc
View file @
c1c32758
...
...
@@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
});
TVM_REGISTER_GLOBAL
(
"topi.strided_slice"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
strided_slice
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
/* Ops from nn/batch_norm.h */
TVM_REGISTER_GLOBAL
(
"topi.nn.batch_norm_inference"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
...
...
topi/tests/python/test_topi_transform.py
View file @
c1c32758
...
...
@@ -246,6 +246,40 @@ def verify_take(src_shape, indices_src, axis=None):
for
device
in
[
"llvm"
,
"opencl"
]:
check_device
(
device
)
def
verify_strided_slice
(
in_shape
,
begin
,
end
,
stride
=
None
):
stride
=
stride
if
stride
else
[
1
,
1
,
1
]
A
=
tvm
.
placeholder
(
shape
=
in_shape
,
name
=
"A"
)
B
=
topi
.
strided_slice
(
A
,
begin
,
end
,
stride
)
+
1
def
test_forward
(
x
,
begin
,
end
,
stride
):
return
x
[
begin
[
0
]:
end
[
0
]:
stride
[
0
],
begin
[
1
]:
end
[
1
]:
stride
[
1
],
begin
[
2
]:
end
[
2
]:
stride
[
2
]]
+
1
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
B
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
],
device
,
name
=
"stride_slice"
)
x_np
=
np
.
random
.
uniform
(
size
=
in_shape
)
.
astype
(
A
.
dtype
)
out_npy
=
test_forward
(
x_np
,
begin
,
end
,
stride
)
data_nd
=
tvm
.
nd
.
array
(
x_np
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npy
.
shape
,
ctx
=
ctx
,
dtype
=
A
.
dtype
)
foo
(
data_nd
,
out_nd
)
np
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npy
)
for
device
in
[
"llvm"
,
"opencl"
]:
check_device
(
device
)
def
test_strided_slice
():
verify_strided_slice
((
3
,
4
,
3
),
[
0
,
0
,
0
],
[
4
,
-
5
,
4
],
[
1
,
-
1
,
2
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
[
2
,
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
4
,
-
5
,
3
],
[
2
,
-
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
0
,
0
],
[
2
,
2
,
3
],
[
1
,
1
,
2
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
2
,
-
3
,
3
],
[
1
,
-
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
])
def
test_expand_dims
():
verify_expand_dims
((
3
,
10
),
(
3
,
10
,
1
,
1
),
2
,
2
)
...
...
@@ -322,3 +356,4 @@ if __name__ == "__main__":
test_flip
()
test_expand_like
()
test_take
()
test_strided_slice
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment