Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6d88c987
Commit
6d88c987
authored
Dec 03, 2019
by
abergeron
Committed by
Tianqi Chen
Dec 03, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI][Relay][OP] Add a strided_set operation. (#4303)
parent
e3eff20d
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
386 additions
and
2 deletions
+386
-2
python/tvm/relay/op/_transform.py
+6
-0
python/tvm/relay/op/transform.py
+30
-0
src/relay/op/tensor/transform.cc
+48
-0
tests/python/relay/test_op_level4.py
+40
-0
topi/python/topi/testing/__init__.py
+1
-1
topi/python/topi/testing/strided_slice_python.py
+39
-1
topi/python/topi/transform.py
+93
-0
topi/python/topi/util.py
+72
-0
topi/tests/python/test_topi_transform.py
+57
-0
No files found.
python/tvm/relay/op/_transform.py
View file @
6d88c987
...
...
@@ -48,6 +48,7 @@ _reg.register_schedule("cast", schedule_injective)
_reg
.
register_schedule
(
"cast_like"
,
schedule_injective
)
_reg
.
register_schedule
(
"reinterpret"
,
schedule_injective
)
_reg
.
register_schedule
(
"strided_slice"
,
schedule_injective
)
_reg
.
register_schedule
(
"strided_set"
,
schedule_injective
)
_reg
.
register_schedule
(
"slice_like"
,
schedule_injective
)
_reg
.
register_schedule
(
"split"
,
schedule_injective
)
_reg
.
register_schedule
(
"take"
,
schedule_injective
)
...
...
@@ -304,6 +305,11 @@ def compute_argwhere(attrs, inputs, output_type, _):
new_output_type
=
tvm
.
relay
.
ty
.
TensorType
(
output_shape
,
"int32"
)
return
[
topi
.
argwhere
(
new_output_type
,
inputs
[
0
])]
@_reg.register_compute
(
"strided_set"
)
def
compute_strided_set
(
attrs
,
inputs
,
output_type
,
_
):
"""Compute definition of strided_set"""
return
[
topi
.
strided_set
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
3
],
inputs
[
4
])]
@script
def
_layout_transform_shape_func
(
data_shape
,
out_layout_len
,
...
...
python/tvm/relay/op/transform.py
View file @
6d88c987
...
...
@@ -631,6 +631,36 @@ def strided_slice(data, begin, end, strides=None):
return
_make
.
strided_slice
(
data
,
list
(
begin
),
list
(
end
),
list
(
strides
))
def
strided_set
(
data
,
v
,
begin
,
end
,
strides
=
None
):
"""Strided set of an array.
Parameters
----------
data : relay.Expr
The source array to be sliced.
v : relay.Expr
The data to be set.
begin: relay.Expr
The indices to begin with in the slicing.
end: relay.Expr
Indices indicating end of the slice.
strides: relay.Expr, optional
Specifies the stride values, it can be negative in that case,
the input tensor will be reversed in that particular axis.
Returns
-------
ret : relay.Expr
The computed result.
"""
strides
=
strides
or
const
([
1
],
dtype
=
"int32"
)
return
_make
.
strided_set
(
data
,
v
,
begin
,
end
,
strides
)
def
slice_like
(
data
,
shape_like
,
axes
=
None
):
"""Slice the first input with respect to the second input.
...
...
src/relay/op/tensor/transform.cc
View file @
6d88c987
...
...
@@ -2049,6 +2049,54 @@ Examples::
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kInjective
)
.
set_attr
<
FInferCorrectLayout
>
(
"FInferCorrectLayout"
,
StridedSliceInferCorrectLayout
);
// strided_set
bool
StridedSetRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
6
);
reporter
->
Assign
(
types
[
5
],
types
[
0
]);
return
true
;
}
Expr
MakeStridedSet
(
Expr
data
,
Expr
v
,
Expr
begin
,
Expr
end
,
Expr
strides
)
{
static
const
Op
&
op
=
Op
::
Get
(
"strided_set"
);
return
CallNode
::
make
(
op
,
{
data
,
v
,
begin
,
end
,
strides
},
{});
}
TVM_REGISTER_API
(
"relay.op._make.strided_set"
)
.
set_body_typed
(
MakeStridedSet
);
RELAY_REGISTER_OP
(
"strided_set"
)
.
describe
(
R"code(Strided set of an array.
Example::
x = [[ 1., 4., 7., 10.],
[ 2., 5., 8., 11.],
[ 3., 6., 9., 12.]]
v = [[ 11., 22., 33.]
[ 44., 55., 66.]]
strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \
[[ 1., 11., 22., 33.],
[ 2., 44., 55., 66.],
[ 3., 6., 9., 12.]]
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
5
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"v"
,
"Tensor"
,
"The data to set."
)
.
add_argument
(
"begin"
,
"Tensor"
,
"Indices for the start of the slice."
)
.
add_argument
(
"end"
,
"Tensor"
,
"Indices indicating the end of the slice."
)
.
add_argument
(
"strides"
,
"Tensor"
,
"The strides values."
)
.
set_support_level
(
4
)
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
kInjective
)
.
add_type_rel
(
"StridedSet"
,
StridedSetRel
);
// relay.split
TVM_REGISTER_NODE_TYPE
(
SplitAttrs
);
...
...
tests/python/relay/test_op_level4.py
View file @
6d88c987
...
...
@@ -300,8 +300,48 @@ def test_strided_slice():
verify
((
3
,
4
,
3
),
[
1
,
1
],
[
4
,
4
,
3
],
None
,
(
2
,
3
,
3
))
def
test_strided_set
():
def
verify
(
dshape
,
begin
,
end
,
strides
,
vshape
,
test_ref
=
True
):
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
(
dshape
,
"float32"
))
v
=
relay
.
var
(
"v"
,
relay
.
TensorType
(
vshape
,
"float32"
))
begin_c
=
relay
.
const
(
begin
,
dtype
=
"int32"
)
end_c
=
relay
.
const
(
end
,
dtype
=
"int32"
)
if
strides
:
strides_c
=
relay
.
const
(
strides
,
dtype
=
"int32"
)
z
=
relay
.
strided_set
(
x
,
v
,
begin
=
begin_c
,
end
=
end_c
,
strides
=
strides_c
)
else
:
z
=
relay
.
strided_set
(
x
,
v
,
begin
=
begin_c
,
end
=
end_c
)
func
=
relay
.
Function
([
x
,
v
],
z
)
func
=
run_infer_type
(
func
)
text
=
func
.
astext
()
assert
"strided_set"
in
text
print
(
text
)
assert
func
.
body
.
checked_type
==
relay
.
ty
.
TensorType
(
dshape
,
"float32"
)
if
not
test_ref
:
return
x_data
=
np
.
random
.
uniform
(
size
=
dshape
)
.
astype
(
"float32"
)
v_data
=
np
.
random
.
uniform
(
size
=
vshape
)
.
astype
(
"float32"
)
ref_res
=
topi
.
testing
.
strided_set_python
(
x_data
,
v_data
,
begin
,
end
,
strides
)
for
target
,
ctx
in
ctx_list
():
intrp
=
relay
.
create_executor
(
"graph"
,
ctx
=
ctx
,
target
=
target
)
op_res
=
intrp
.
evaluate
(
func
)(
x_data
,
v_data
)
tvm
.
testing
.
assert_allclose
(
op_res
.
asnumpy
(),
ref_res
)
verify
((
3
,
4
,
3
),
[
0
,
0
,
0
],
[
4
,
-
5
,
4
],
[
1
,
-
1
,
2
],
(
3
,
1
,
2
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
[
2
,
1
,
1
],
(
1
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
4
,
-
5
,
3
],
[
2
,
-
1
,
1
],
(
1
,
4
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
0
,
0
],
[
2
,
2
,
3
],
[
1
,
1
,
2
],
(
1
,
2
,
2
))
verify
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
2
,
-
3
,
3
],
[
1
,
-
1
,
1
],
(
1
,
2
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
None
,
(
2
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
1000
,
3
],
None
,
(
2
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
],
None
,
(
2
,
3
,
3
))
verify
((
3
,
4
,
3
),
[
1
,
1
],
[
4
,
4
,
3
],
None
,
(
2
,
3
,
3
))
if
__name__
==
"__main__"
:
test_strided_slice
()
test_strided_set
()
test_binary_op
()
test_cmp_type
()
test_binary_int_broadcast
()
...
...
topi/python/topi/testing/__init__.py
View file @
6d88c987
...
...
@@ -37,7 +37,7 @@ from .roi_pool_python import roi_pool_nchw_python
from
.lrn_python
import
lrn_python
from
.l2_normalize_python
import
l2_normalize_python
from
.gather_nd_python
import
gather_nd_python
from
.strided_slice_python
import
strided_slice_python
from
.strided_slice_python
import
strided_slice_python
,
strided_set_python
from
.batch_matmul
import
batch_matmul
from
.slice_axis_python
import
slice_axis_python
from
.sequence_mask_python
import
sequence_mask
...
...
topi/python/topi/testing/strided_slice_python.py
View file @
6d88c987
...
...
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""gather_nd in python"""
"""strided_slice/set in python"""
def
strided_slice_python
(
data
,
begin
,
end
,
strides
):
"""Python version of strided slice operator.
...
...
@@ -46,3 +47,40 @@ def strided_slice_python(data, begin, end, strides):
end
[
i
]
if
i
<
len
(
end
)
else
None
,
strides
[
i
]
if
i
<
len
(
strides
)
else
None
))
return
data
[
tuple
(
slices
)]
def
strided_set_python
(
data
,
v
,
begin
,
end
,
strides
):
"""Python version of strided slice operator.
Parameters
----------
data : numpy.ndarray
Input data
v : numpy.ndarray
Value data
begin : list
Begining of the slices.
end : list
End of the slices.
strides : list
The stride of each slice.
Returns
-------
result : numpy.ndarray
The updated result.
"""
strides
=
[]
if
strides
is
None
else
strides
slices
=
[]
res
=
data
.
copy
()
for
i
in
range
(
len
(
data
.
shape
)):
slices
.
append
(
slice
(
begin
[
i
]
if
i
<
len
(
begin
)
else
None
,
end
[
i
]
if
i
<
len
(
end
)
else
None
,
strides
[
i
]
if
i
<
len
(
strides
)
else
None
))
res
[
tuple
(
slices
)]
=
v
return
res
topi/python/topi/transform.py
View file @
6d88c987
...
...
@@ -20,6 +20,8 @@ from __future__ import absolute_import as _abs
import
tvm
import
topi
from
.
import
cpp
from
.
import
tag
from
.util
import
within_index
,
make_idx
def
expand_dims
(
a
,
axis
,
num_newaxis
=
1
):
...
...
@@ -155,6 +157,97 @@ def strided_slice(a, begin, end, strides=None):
strides
=
[]
return
cpp
.
strided_slice
(
a
,
begin
,
end
,
strides
)
@tvm.tag_scope
(
tag
=
tag
.
INJECTIVE
+
",strided_set"
)
def
strided_set
(
a
,
v
,
begin
,
end
,
strides
=
None
):
"""Set slice of an array.
Parameters
----------
a : tvm.Tensor
The tensor to be sliced.
v : tvm.Tensor
The values to set
begin: tvm.Tensor
The indices to begin with in the slicing.
end: tvm.Tensor
Indicies indicating end of the slice.
strides: tvm.Tensor, optional
Specifies the stride values, it can be negative
in that case, the input tensor will be reversed
in that particular axis.
Returns
-------
ret : tvm.Tensor
"""
n
=
len
(
a
.
shape
)
if
len
(
begin
.
shape
)
!=
1
:
raise
ValueError
(
"begin should be a vector"
)
if
not
begin
.
dtype
==
'int32'
:
raise
TypeError
(
"begin should be int32"
)
if
len
(
end
.
shape
)
!=
1
:
raise
ValueError
(
"end should be a vector"
)
if
not
end
.
dtype
==
'int32'
:
raise
TypeError
(
"end should be int32"
)
if
strides
is
not
None
:
if
len
(
strides
.
shape
)
!=
1
:
raise
ValueError
(
"strides should be a vector"
)
if
not
strides
.
dtype
==
'int32'
:
raise
TypeError
(
"strides should be int32"
)
def
_max
(
a
,
b
):
return
tvm
.
expr
.
Select
(
a
>
b
,
a
,
b
)
if
strides
is
None
:
strides
=
[
tvm
.
const
(
1
,
'int32'
)]
*
n
else
:
strides
=
[
tvm
.
if_then_else
(
strides
.
shape
[
0
]
>
i
,
strides
[
i
],
tvm
.
const
(
1
,
'int32'
))
for
i
in
range
(
n
)]
begin
=
[
tvm
.
if_then_else
(
begin
.
shape
[
0
]
>
i
,
begin
[
i
],
tvm
.
expr
.
Select
(
strides
[
i
]
>
0
,
tvm
.
const
(
0
,
'int32'
),
a
.
shape
[
i
]))
for
i
in
range
(
n
)]
end
=
[
tvm
.
if_then_else
(
end
.
shape
[
0
]
>
i
,
end
[
i
],
tvm
.
expr
.
Select
(
strides
[
i
]
>
0
,
a
.
shape
[
i
]
+
1
,
-
(
a
.
shape
[
i
]
+
1
)))
for
i
in
range
(
n
)]
# Convert negative indexes
for
i
in
range
(
n
):
begin
[
i
]
=
tvm
.
if_then_else
(
begin
[
i
]
<
0
,
begin
[
i
]
+
a
.
shape
[
i
],
begin
[
i
])
end
[
i
]
=
tvm
.
if_then_else
(
end
[
i
]
<
0
,
end
[
i
]
+
a
.
shape
[
i
],
end
[
i
])
def
_select
(
*
indices
):
from_val
=
[]
index_tuple
=
[]
for
i
in
range
(
n
):
from_val
.
append
(
within_index
(
begin
[
i
],
end
[
i
],
strides
[
i
],
indices
[
i
]))
index_tuple
.
append
(
make_idx
(
begin
[
i
],
end
[
i
],
strides
[
i
],
a
.
shape
[
i
],
indices
[
i
]))
return
tvm
.
if_then_else
(
tvm
.
all
(
*
from_val
),
v
(
*
index_tuple
),
a
(
*
indices
))
return
tvm
.
compute
(
a
.
shape
,
_select
,
name
=
"strided_set"
)
def
reshape
(
a
,
newshape
):
"""Reshape the array
...
...
topi/python/topi/util.py
View file @
6d88c987
...
...
@@ -345,3 +345,75 @@ def get_shape(src_shape, src_layout, dst_layout):
tvm
.
convert
([
i
for
i
in
range
(
len
(
src_layout
))]))
return
get_const_tuple
(
tuple
([
src_shape
[
i
.
value
]
for
i
in
dst_indices
]))
def
within_index
(
b
,
e
,
s
,
i
):
"""Return a boolean value that indicates if i is within the given index.
Parameter
---------
b : Expr
beginning of the index
e : Expr
end of the index
s : Expr
strides of index
i : Expr
array position
Returns
-------
selected: Expr
bool expression that is True is the array position would be selected
by the index and False otherwise
"""
bc
=
tvm
.
expr
.
Select
(
s
<
0
,
i
<=
e
,
i
<
b
)
ec
=
tvm
.
expr
.
Select
(
s
<
0
,
i
>
b
,
i
>=
e
)
ss
=
tvm
.
if_then_else
(
s
<
0
,
((
i
-
e
)
+
(
e
%
tvm
.
abs
(
s
))
+
1
)
%
tvm
.
abs
(
s
),
(
i
-
b
)
%
s
)
return
tvm
.
expr
.
Select
(
tvm
.
expr
.
Or
(
bc
,
ec
),
tvm
.
const
(
False
),
ss
.
equal
(
0
))
def
make_idx
(
b
,
e
,
s
,
z
,
i
):
"""Return the array position in the selection that corresponds to an
array position in the full array.
The returned value is only meaningful if within_index() returns True
for the same set of parameters.
Parameter
---------
b : Expr
beginning of the index
e : Expr
end of the index
s : Expr
strides of index
z : Expr
size of the indexed dimension
i : Expr
array position
Returns
-------
postion: Expr
int expression that corresponds to an array position in the selection.
"""
bc
=
tvm
.
expr
.
Select
(
s
<
0
,
i
<=
e
,
i
<
b
)
ec
=
tvm
.
expr
.
Select
(
s
<
0
,
i
>
b
,
i
>=
e
)
# Clamp to array size
b
=
tvm
.
expr
.
Select
(
z
<
b
,
z
-
1
,
b
)
ss
=
tvm
.
if_then_else
(
s
<
0
,
(
b
-
i
)
//
tvm
.
abs
(
s
),
(
i
-
b
)
//
s
)
return
tvm
.
if_then_else
(
tvm
.
expr
.
Or
(
bc
,
ec
),
88
,
ss
)
topi/tests/python/test_topi_transform.py
View file @
6d88c987
...
...
@@ -342,6 +342,52 @@ def verify_strided_slice(in_shape, begin, end, strides=None):
for
device
in
[
"llvm"
,
"opencl"
,
"sdaccel"
,
"aocl_sw_emu"
]:
check_device
(
device
)
def
verify_strided_set
(
in_shape
,
v_shape
,
begin
,
end
,
strides
=
None
):
A
=
tvm
.
placeholder
(
shape
=
in_shape
,
name
=
"A"
)
V
=
tvm
.
placeholder
(
shape
=
v_shape
,
name
=
"V"
)
b
=
tvm
.
placeholder
(
shape
=
(
len
(
begin
),),
name
=
"b"
,
dtype
=
'int32'
)
e
=
tvm
.
placeholder
(
shape
=
(
len
(
end
),),
name
=
"e"
,
dtype
=
'int32'
)
if
strides
is
not
None
:
st
=
tvm
.
placeholder
(
shape
=
(
len
(
strides
),),
name
=
"st"
,
dtype
=
'int32'
)
B
=
topi
.
strided_set
(
A
,
V
,
b
,
e
,
st
)
+
1
else
:
B
=
topi
.
strided_set
(
A
,
V
,
b
,
e
)
+
1
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
B
)
if
strides
is
not
None
:
foo
=
tvm
.
build
(
s
,
[
A
,
V
,
b
,
e
,
st
,
B
],
device
,
name
=
"stride_set"
)
s_np
=
np
.
asarray
(
strides
)
.
astype
(
'int32'
)
s_nd
=
tvm
.
nd
.
array
(
s_np
,
ctx
)
else
:
foo
=
tvm
.
build
(
s
,
[
A
,
V
,
b
,
e
,
B
],
device
,
name
=
"stride_set"
)
x_np
=
np
.
random
.
uniform
(
size
=
in_shape
)
.
astype
(
A
.
dtype
)
v_np
=
np
.
random
.
uniform
(
size
=
v_shape
)
.
astype
(
V
.
dtype
)
b_np
=
np
.
asarray
(
begin
)
.
astype
(
'int32'
)
e_np
=
np
.
asarray
(
end
)
.
astype
(
'int32'
)
out_npy
=
topi
.
testing
.
strided_set_python
(
x_np
,
v_np
,
begin
,
end
,
strides
)
+
1
data_nd
=
tvm
.
nd
.
array
(
x_np
,
ctx
)
v_nd
=
tvm
.
nd
.
array
(
v_np
,
ctx
)
b_nd
=
tvm
.
nd
.
array
(
b_np
,
ctx
)
e_nd
=
tvm
.
nd
.
array
(
e_np
,
ctx
)
out_nd
=
tvm
.
nd
.
empty
(
out_npy
.
shape
,
ctx
=
ctx
,
dtype
=
A
.
dtype
)
if
strides
is
not
None
:
foo
(
data_nd
,
v_nd
,
b_nd
,
e_nd
,
s_nd
,
out_nd
)
else
:
foo
(
data_nd
,
v_nd
,
b_nd
,
e_nd
,
out_nd
)
tvm
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npy
)
for
device
in
[
"llvm"
,
"opencl"
,
"sdaccel"
,
"aocl_sw_emu"
]:
check_device
(
device
)
def
verify_gather_nd
(
src_shape
,
indices_src
,
indices_dtype
):
src_dtype
=
"float32"
indices_src
=
np
.
array
(
indices_src
,
dtype
=
indices_dtype
)
...
...
@@ -510,6 +556,17 @@ def test_strided_slice():
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
-
1
,
0
],
[
2
,
-
3
,
3
],
[
1
,
-
1
,
1
])
verify_strided_slice
((
3
,
4
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
])
def
test_strided_set
():
verify_strided_set
((
3
,
4
,
3
),
(
3
,
2
,
2
),
[
0
,
3
,
0
],
[
4
,
1
,
4
],
[
1
,
-
1
,
2
])
verify_strided_set
((
3
,
4
,
3
),
(
3
,
1
,
2
),
[
0
,
0
,
0
],
[
4
,
-
5
,
4
],
[
1
,
-
1
,
2
])
verify_strided_set
((
3
,
4
,
3
),
(
1
,
3
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
],
[
2
,
1
,
1
])
verify_strided_set
((
3
,
4
,
3
),
(
1
,
4
,
3
),
[
1
,
-
1
,
0
],
[
4
,
-
5
,
3
],
[
2
,
-
1
,
1
])
verify_strided_set
((
3
,
4
,
3
),
(
1
,
2
,
2
),
[
1
,
0
,
0
],
[
2
,
2
,
3
],
[
1
,
1
,
2
])
verify_strided_set
((
3
,
4
,
3
),
(
1
,
2
,
3
),
[
1
,
-
1
,
0
],
[
2
,
-
3
,
3
],
[
1
,
-
1
,
1
])
verify_strided_set
((
3
,
4
,
3
),
(
1
,
2
,
3
),
[
1
,
1
,
0
],
[
2
,
3
,
3
],
[
1
])
verify_strided_set
((
3
,
4
,
3
),
(
2
,
3
,
3
),
[
1
,
1
,
0
],
[
4
,
4
,
3
])
verify_strided_set
((
3
,
4
,
3
),
(
2
,
3
,
3
),
[
1
,
1
],
[
4
,
4
,
3
])
def
test_expand_dims
():
verify_expand_dims
((
3
,
10
),
(
3
,
10
,
1
,
1
),
2
,
2
)
verify_expand_dims
((
3
,
10
),
(
1
,
3
,
10
),
-
3
,
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment