Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
4fb58115
Commit
4fb58115
authored
Jun 28, 2018
by
Pariksheet Pinjari
Committed by
Tianqi Chen
Jun 28, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Strided_slice added in NNVM (#1318)
parent
2aa1f054
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
169 additions
and
0 deletions
+169
-0
nnvm/include/nnvm/top/tensor.h
+16
-0
nnvm/python/nnvm/top/transform.py
+4
-0
nnvm/src/top/tensor/transform.cc
+113
-0
nnvm/tests/python/compiler/test_top_level1.py
+36
-0
No files found.
nnvm/include/nnvm/top/tensor.h
View file @
4fb58115
...
...
@@ -48,6 +48,22 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
}
};
struct
StridedSliceParam
:
public
dmlc
::
Parameter
<
StridedSliceParam
>
{
// numpy convention, only support indices, not support list.
Tuple
<
int64_t
>
begin
;
Tuple
<
int64_t
>
end
;
Tuple
<
int64_t
>
stride
;
DMLC_DECLARE_PARAMETER
(
StridedSliceParam
)
{
DMLC_DECLARE_FIELD
(
begin
)
.
describe
(
"Indices for begin of slice"
);
DMLC_DECLARE_FIELD
(
end
)
.
describe
(
"Indices for end of the slice"
);
DMLC_DECLARE_FIELD
(
stride
).
set_default
(
Tuple
<
int64_t
>
())
.
describe
(
"Stride values of the slice"
);
}
};
enum
TypeFlag
{
kFloat32
=
0
,
kFloat64
=
1
,
...
...
nnvm/python/nnvm/top/transform.py
View file @
4fb58115
...
...
@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective)
reg
.
register_pattern
(
"split"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"split"
,
_fschedule_injective
)
# strided_slice
reg
.
register_pattern
(
"strided_slice"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"strided_slice"
,
_fschedule_injective
)
# slice_like
reg
.
register_pattern
(
"slice_like"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"slice_like"
,
_fschedule_injective
)
nnvm/src/top/tensor/transform.cc
View file @
4fb58115
...
...
@@ -829,6 +829,119 @@ Examples::
};
});
// strided_slice
DMLC_REGISTER_PARAMETER
(
StridedSliceParam
);
inline
void
StridedSliceParamParser
(
nnvm
::
NodeAttrs
*
attrs
)
{
StridedSliceParam
param
;
param
.
Init
(
attrs
->
dict
);
attrs
->
parsed
=
std
::
move
(
param
);
}
inline
bool
StridedSliceInferShape
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
TShape
>*
in_shape
,
std
::
vector
<
TShape
>*
out_shape
)
{
const
StridedSliceParam
&
param
=
nnvm
::
get
<
StridedSliceParam
>
(
attrs
.
parsed
);
const
TShape
&
dshape
=
(
*
in_shape
)[
0
];
if
(
dshape
.
ndim
()
==
0
)
return
false
;
TShape
oshape
=
dshape
;
dim_t
num_axis
=
dshape
.
ndim
();
std
::
vector
<
int64_t
>
begin_vec
;
std
::
copy
(
param
.
begin
.
begin
(),
param
.
begin
.
end
(),
std
::
back_inserter
(
begin_vec
));
for
(
dim_t
i
=
begin_vec
.
size
();
i
<
num_axis
;
++
i
)
{
begin_vec
.
push_back
(
0
);
}
std
::
vector
<
int64_t
>
end_vec
;
std
::
copy
(
param
.
end
.
begin
(),
param
.
end
.
end
(),
std
::
back_inserter
(
end_vec
));
for
(
dim_t
i
=
end_vec
.
size
();
i
<
num_axis
;
++
i
)
{
end_vec
.
push_back
(
dshape
[
i
]);
}
std
::
vector
<
int64_t
>
stride_vec
;
std
::
copy
(
param
.
stride
.
begin
(),
param
.
stride
.
end
(),
std
::
back_inserter
(
stride_vec
));
for
(
dim_t
i
=
stride_vec
.
size
();
i
<
num_axis
;
++
i
)
{
stride_vec
.
push_back
(
1
);
}
for
(
dim_t
i
=
0
;
i
<
num_axis
;
++
i
)
{
int64_t
begin_range
=
stride_vec
[
i
]
<
0
?
-
1
:
0
;
int64_t
end_range
=
stride_vec
[
i
]
<
0
?
dshape
[
i
]
-
1
:
dshape
[
i
];
int64_t
begin
=
begin_vec
[
i
]
<
0
?
dshape
[
i
]
+
begin_vec
[
i
]
:
begin_vec
[
i
];
int64_t
end
=
end_vec
[
i
]
<
0
?
dshape
[
i
]
+
end_vec
[
i
]
:
end_vec
[
i
];
begin
=
std
::
min
(
std
::
max
(
begin
,
begin_range
),
end_range
);
end
=
std
::
min
(
std
::
max
(
end
,
begin_range
),
end_range
);
int
interval
=
std
::
abs
(
end
-
begin
);
int
slice_size
=
static_cast
<
int
>
((
interval
+
std
::
abs
(
stride_vec
[
i
])
-
1
)
/
std
::
abs
(
stride_vec
[
i
]));
CHECK
(
stride_vec
[
i
]
<
0
?
(
end
<
begin
)
:
(
begin
<
end
))
<<
": Input [Begin="
<<
begin_vec
[
i
]
<<
", End="
<<
end_vec
[
i
]
<<
"] is invalid for axis="
<<
i
;
oshape
[
i
]
=
slice_size
;
}
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_shape
,
0
,
oshape
);
return
true
;
}
NNVM_REGISTER_OP
(
strided_slice
)
.
describe
(
R"code(Strided slice of an array.
Examples::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.],
[ 5., 8., 11.]]
x = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]]]
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"data"
,
"Tensor"
,
"Array to be sliced"
)
.
add_arguments
(
StridedSliceParam
::
__FIELDS__
())
.
set_attr_parser
(
StridedSliceParamParser
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
StridedSliceInferShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
ElemwiseType
<
1
,
1
>
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
ElemwiseArbitraryLayout
<
1
,
1
>
)
.
set_num_inputs
(
1
)
.
set_num_outputs
(
1
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
const
StridedSliceParam
&
param
=
nnvm
::
get
<
StridedSliceParam
>
(
attrs
.
parsed
);
Array
<
Expr
>
begin
;
Array
<
Expr
>
end
;
Array
<
Expr
>
stride
;
for
(
int64_t
i
:
param
.
begin
)
{
begin
.
push_back
(
tvm
::
make_const
(
tvm
::
Int
(
32
),
i
));
}
for
(
int64_t
i
:
param
.
end
)
{
end
.
push_back
(
tvm
::
make_const
(
tvm
::
Int
(
32
),
i
));
}
for
(
int64_t
i
:
param
.
stride
)
{
stride
.
push_back
(
tvm
::
make_const
(
tvm
::
Int
(
32
),
i
));
}
return
Array
<
Tensor
>
{
topi
::
strided_slice
(
inputs
[
0
],
begin
,
end
,
stride
)
};
})
.
set_support_level
(
1
);
// Flip
DMLC_REGISTER_PARAMETER
(
FlipParam
);
...
...
nnvm/tests/python/compiler/test_top_level1.py
View file @
4fb58115
...
...
@@ -329,6 +329,41 @@ def test_split():
verify_split
((
5
,
3
),
[
3
],
axis
=
0
)
verify_split
((
5
,
9
,
3
),
[
3
,
4
],
axis
=
1
)
def
verify_strided_slice
(
ishape
,
begin
,
end
,
strideinp
=
None
):
stride
=
strideinp
if
strideinp
else
[
1
,
1
,
1
]
x
=
sym
.
Variable
(
"x"
)
if
strideinp
:
y
=
sym
.
strided_slice
(
x
,
begin
=
begin
,
end
=
end
,
stride
=
stride
)
+
1
else
:
y
=
sym
.
strided_slice
(
x
,
begin
=
begin
,
end
=
end
)
+
1
x_np
=
np
.
random
.
uniform
(
size
=
ishape
)
.
astype
(
"float32"
)
for
i
in
range
(
len
(
begin
),
3
):
begin
.
append
(
0
)
for
i
in
range
(
len
(
end
),
3
):
end
.
append
(
ishape
[
i
])
def
test_forward
(
x
,
begin
,
end
,
stride
):
return
x
[
begin
[
0
]:
end
[
0
]:
stride
[
0
],
begin
[
1
]:
end
[
1
]:
stride
[
1
],
begin
[
2
]:
end
[
2
]:
stride
[
2
]]
+
1
for
target
,
ctx
in
ctx_list
():
# set input
graph
,
lib
,
_
=
nnvm
.
compiler
.
build
(
y
,
target
,
{
"x"
:
ishape
})
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
.
run
(
x
=
x_np
)
res
=
test_forward
(
x_np
,
begin
,
end
,
stride
)
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
res
.
shape
))
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
res
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_strided_slice
():
verify_strided_slice
((
3
,
4
,
3
),
[
0
,
0
,
0
],
[
4
,
-
5
,
4
],
[
1
,
-
1
,
2
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
[
2
,
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
4
,
-
5
,
3
],
[
2
,
-
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
0
,
0
],
[
2
,
2
,
3
],
[
1
,
1
,
2
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
2
,
-
3
,
3
],
[
1
,
-
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
1000
,
3
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
],
[
4
,
4
,
3
])
def
verify_squeeze
(
dshape
,
axis
):
x
=
sym
.
Variable
(
"x"
)
...
...
@@ -448,3 +483,4 @@ if __name__ == "__main__":
test_pad
()
test_lrn
()
test_l2_normalize
()
test_strided_slice
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment