wenyuanbo / tic · Commits

Commit 71f88611
Authored Jun 26, 2018 by Yao Wang; committed by Tianqi Chen, Jun 26, 2018

Improve schedule load, add slice_like (#1299)

parent 84bd230c

Showing 11 changed files with 269 additions and 25 deletions (+269 / -25)
nnvm/include/nnvm/top/tensor.h                 +10   -0
nnvm/python/nnvm/frontend/mxnet.py             +18   -1
nnvm/python/nnvm/top/nn.py                      +8   -2
nnvm/python/nnvm/top/transform.py               +4   -0
nnvm/src/top/elemwise_op_common.h               +1   -1
nnvm/src/top/tensor/transform.cc              +101   -0
nnvm/tests/python/compiler/test_top_level4.py  +55   -0
topi/python/topi/generic/nn.py                 +24  -13
topi/python/topi/nn/conv2d.py                  +30   -1
topi/python/topi/x86/conv2d.py                 +17   -7
topi/tests/python_cpp/test_topi_transform.py    +1   -0
nnvm/include/nnvm/top/tensor.h

@@ -259,6 +259,16 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
   }
 };
 
+struct SliceLikeParam : public dmlc::Parameter<SliceLikeParam> {
+  Tuple<int> axis;
+  DMLC_DECLARE_PARAMETER(SliceLikeParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(Tuple<int>())
+      .describe("List of axes on which input data will be sliced according to the "
+                "corresponding size of the second input. By default will slice "
+                "on all axes. Negative axes are supported.");
+  }
+};
+
 }  // namespace top
 }  // namespace nnvm
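For orientation, the semantics the `axis` field describes can be sketched in NumPy. This mirrors the `np_slice_like` reference helper added in the tests later in this commit; it is an editorial illustration, not part of the diff:

```python
import numpy as np

def slice_like_ref(data, shape_like, axis=()):
    # Slice `data` from index 0 so each selected axis matches the extent
    # of `shape_like`; with no axis given, slice on all axes. Negative
    # axes are normalized by adding the input rank.
    end = list(data.shape)
    axes = range(data.ndim) if len(axis) == 0 else \
        [a if a >= 0 else data.ndim + a for a in axis]
    for a in axes:
        if a < len(shape_like.shape):
            end[a] = shape_like.shape[a]
    return data[tuple(slice(0, e) for e in end)]

x = np.zeros((1, 3, 224, 224))
like = np.zeros((1, 3, 112, 112))
assert slice_like_ref(x, like, axis=(2, 3)).shape == (1, 3, 112, 112)
```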
nnvm/python/nnvm/frontend/mxnet.py

@@ -240,6 +240,21 @@ def _elemwise_sum(inputs, _):
     new_attrs = {'num_args': len(inputs)}
     return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)
 
+def _crop_like(inputs, attrs):
+    new_attrs = {}
+    offsets = \
+        tuple([float(x.strip()) for x in attrs.get('offsets').strip('()').split(',')]) \
+        if attrs.get('offsets') is not None else (0, 0)
+    if offsets != (0, 0):
+        raise RuntimeError("Currently only supports offsets to be zero.")
+    center_crop = _parse_bool_str(attrs, 'center_crop', default="False")
+    if center_crop:
+        raise RuntimeError("center crop is not supported.")
+    if len(inputs) < 2:
+        raise RuntimeError("Only support crop_like pattern.")
+    new_attrs["axis"] = [2, 3]
+    return _get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs)
+
 def _expand_dims(inputs, attrs):
     op_name, new_attrs = "expand_dims", {}
 
@@ -255,7 +270,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
                   'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
                   'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
-                  'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
+                  'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh',
+                  'transpose']
 
 _convert_map = {
     '_copy'    : _rename('copy'),
 
@@ -274,6 +290,7 @@ _convert_map = {
     'Concat'        : _concat,
     'Convolution'   : _conv2d,
     'Convolution_v1': _conv2d,
+    'Crop'          : _crop_like,
     'Deconvolution' : _conv2d_transpose,
     'Dropout'       : _dropout,
     'Flatten'       : _rename('flatten'),
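For context, `_crop_like` maps MXNet's two-input `Crop` onto the new op: `offsets` must parse to `(0, 0)`, and the slice is pinned to the spatial axes of an NCHW tensor. A small sketch of the attribute parsing; the `attrs` dict below is a hypothetical stand-in for the frontend's attribute object:

```python
# Hypothetical raw attributes as serialized by MXNet for Crop(data, ref).
attrs = {'offsets': '(0, 0)', 'center_crop': 'False'}

offsets = tuple(float(x.strip())
                for x in attrs['offsets'].strip('()').split(','))
assert offsets == (0.0, 0.0)  # any other value raises RuntimeError above
# The converter then emits slice_like(data, ref, axis=[2, 3]), i.e. it
# crops only the H and W axes of an NCHW tensor.
```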
nnvm/python/nnvm/top/nn.py

@@ -155,10 +155,13 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
     kh, kw = attrs.get_int_tuple('kernel_size')
     groups = attrs.get_int("groups")
     channels = attrs.get_int("channels")
     layout = attrs.get_string("layout")
+    out_layout = attrs.get_string("out_layout")
     assert dilation == (1, 1), "not support dilate now"
     if groups == 1:
         # pylint: disable=assignment-from-no-return
-        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), strides, padding)
+        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw),
+                                   strides, padding, layout, out_layout)
         # pylint: enable=assignment-from-no-return
     else:
         raise ValueError("not support arbitrary group number > 1 for now")
 
@@ -176,9 +179,12 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
     oc = attrs.get_int("channels")
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
     layout = attrs.get_string("layout")
+    out_layout = attrs.get_string("out_layout")
     with tvm.target.create(target):
         if groups == 1:
-            return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, outs)
+            return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw),
+                                                      strides, padding,
+                                                      layout, out_layout, outs)
         else:
             raise ValueError("not support group number > 1 for now")
nnvm/python/nnvm/top/transform.py

@@ -60,3 +60,7 @@ reg.register_schedule("concatenate", _fschedule_injective)
 # split
 reg.register_pattern("split", OpPattern.INJECTIVE)
 reg.register_schedule("split", _fschedule_injective)
+
+# slice_like
+reg.register_pattern("slice_like", OpPattern.INJECTIVE)
+reg.register_schedule("slice_like", _fschedule_injective)
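With the op registered as INJECTIVE and given the generic injective schedule, `slice_like` is callable from the NNVM symbol API. A usage sketch distilled from the tests added in this commit, assuming an llvm-enabled build:

```python
import numpy as np
import nnvm
import nnvm.symbol as sym
import nnvm.compiler
import tvm
from tvm.contrib import graph_runtime

data = sym.Variable("data")
ref = sym.Variable("ref")
net = sym.slice_like(data=data, slice_like=ref, axis=(2, 3))

# Compile and run: output takes ref's extents on axes 2 and 3.
graph, lib, _ = nnvm.compiler.build(
    net, "llvm", {"data": (1, 3, 224, 224), "ref": (1, 3, 112, 112)})
m = graph_runtime.create(graph, lib, tvm.cpu(0))
m.set_input(data=np.random.uniform(size=(1, 3, 224, 224)).astype("float32"),
            ref=np.zeros((1, 3, 112, 112), dtype="float32"))
m.run()
out = m.get_output(0, tvm.nd.empty((1, 3, 112, 112), "float32"))
```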
nnvm/src/top/elemwise_op_common.h

@@ -320,7 +320,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
   .set_attr<nnvm::FInferShape>("FInferShape",                          \
                                ElementWiseReduceShape)                 \
   .set_attr<FCorrectLayout>("FCorrectLayout",                          \
-                            ElemwiseFixedLayoutCopyToOut<1, 1>)        \
+                            ElemwiseFixedLayoutCopyToOut<-1, 1>)       \
   .set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType)     \
   .add_argument("args", "Symbol[]", "Positional input arguments")
nnvm/src/top/tensor/transform.cc

@@ -15,6 +15,7 @@
 #include "../elemwise_op_common.h"
 #include "topi/nn/flatten.h"
 #include "topi/transform.h"
+#include "topi/detail/constant_utils.h"
 
 namespace nnvm {
 namespace top {
 
@@ -877,5 +878,105 @@ Examples::
     return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
 });
 
+// SliceLike
+DMLC_REGISTER_PARAMETER(SliceLikeParam);
+
+inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape>* in_attrs,
+                           std::vector<TShape>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
+  const TShape& src_shape = in_attrs->at(0);
+  const TShape& target_shape = in_attrs->at(1);
+  Tuple<dim_t> end_idx;
+  end_idx = Tuple<dim_t>(src_shape);
+  if (param.axis.ndim() == 0) {
+    for (size_t i = 0; i < src_shape.ndim(); ++i) {
+      if (i < target_shape.ndim()) {
+        end_idx[i] = target_shape[i];
+        CHECK_LE(end_idx[i], src_shape[i])
+          << "End index of axis " << i << " exceeds input shape: "
+          << end_idx[i] << " vs " << src_shape[i];
+      }
+    }
+  } else {
+    for (auto i : param.axis) {
+      if (i < 0) {
+        i = src_shape.ndim() + i;
+      }
+      CHECK_LT(i, target_shape.ndim())
+        << "Axis " << i << " exceeds dimension "
+        << target_shape.ndim() << " of target_shape.";
+      end_idx[i] = target_shape[i];
+      CHECK_LE(end_idx[i], src_shape[i])
+        << "End index of axis " << i << " exceeds input shape: "
+        << end_idx[i] << " vs " << src_shape[i];
+    }
+  }
+  TShape out_shape = TShape(std::move(end_idx));
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape);
+  return true;
+}
+
+NNVM_REGISTER_OP(slice_like)
+.describe(R"code(Slice the first input respect to the second input.
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "Input data to be sliced.")
+.add_argument("slice_like", "Tensor", "Tensor with target shape")
+.set_num_inputs(2)
+.set_num_outputs(1)
+.add_arguments(SliceLikeParam::__FIELDS__())
+.set_attr_parser(ParamParser<SliceLikeParam>)
+.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SliceLikeParam>)
+.set_attr<FInferShape>("FInferShape", SliceLikeShape)
+.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseBinaryKeepLeftLayout)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const auto& param = nnvm::get<SliceLikeParam>(attrs.parsed);
+    Array<Expr> src_shape = inputs[0]->shape;
+    Array<Expr> target_shape = inputs[1]->shape;
+    Array<Expr> begin_idx, end_idx, strides;
+    for (size_t i = 0; i < src_shape.size(); ++i) {
+      begin_idx.push_back(make_const(tvm::Int(32), 0));
+      strides.push_back(make_const(tvm::Int(32), 1));
+    }
+    end_idx = Array<Expr>(src_shape);
+    if (param.axis.ndim() == 0) {
+      for (size_t i = 0; i < src_shape.size(); ++i) {
+        if (i < target_shape.size()) {
+          end_idx.Set(i, target_shape[i]);
+          CHECK_LE(topi::GetConstInt(end_idx[i]),
+                   topi::GetConstInt(src_shape[i]))
+            << "End index of axis " << i << " exceeds input shape: "
+            << topi::GetConstInt(end_idx[i]) << " vs "
+            << topi::GetConstInt(src_shape[i]);
+        }
+      }
+    } else {
+      for (int axis : param.axis) {
+        if (axis < 0) {
+          axis = static_cast<int>(src_shape.size()) + axis;
+        }
+        end_idx.Set(static_cast<size_t>(axis), target_shape[axis]);
+        CHECK_LE(topi::GetConstInt(end_idx[axis]),
+                 topi::GetConstInt(src_shape[axis]))
+          << "End index of axis " << axis << " exceeds input shape: "
+          << topi::GetConstInt(end_idx[axis]) << " vs "
+          << topi::GetConstInt(src_shape[axis]);
+      }
+    }
+    return Array<Tensor>{
+      topi::strided_slice(inputs[0], begin_idx, end_idx, strides)
+    };
+})
+.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "slice_like"};
+})
+.set_support_level(4);
+
 }  // namespace top
 }  // namespace nnvm
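The compute rule above lowers `slice_like` to `topi::strided_slice` with zero begin indices and unit strides on every axis, so only the end indices vary. In Python-level TVM terms the effect is roughly the following; this is a sketch under those assumptions, not code from the commit:

```python
import tvm

def slice_like_lowering(data, end_idx):
    # With begin = 0 and stride = 1 everywhere, the strided slice
    # reduces to a plain sub-rectangle read:
    # out[i0, ..., ik] = data[i0, ..., ik].
    return tvm.compute(tuple(end_idx), lambda *idx: data(*idx),
                       name="slice_like")

data = tvm.placeholder((1, 3, 224, 224), name="data")
out = slice_like_lowering(data, (1, 3, 112, 112))
print(out.shape)  # [1, 3, 112, 112]
```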
nnvm/tests/python/compiler/test_top_level4.py

@@ -541,6 +541,60 @@ def test_nms():
     out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
     np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
 
+def np_slice_like(np_data, np_shape_like, axis=[]):
+    begin_idx = [0 for _ in np_data.shape]
+    end_idx = list(np_data.shape)
+    if len(axis) > 0:
+        for i in axis:
+            if i < 0:
+                i = len(np_data.shape) + i
+            end_idx[i] = np_shape_like.shape[i]
+    else:
+        for i in range(len(np_data.shape)):
+            if i < len(np_shape_like.shape):
+                end_idx[i] = np_shape_like.shape[i]
+    slice_idx = []
+    for b, e in zip(begin_idx, end_idx):
+        slice_idx.append(slice(b, e))
+    np_result = np_data[slice_idx]
+    return np_result
+
+def verify_slice_like(np_data, np_shape_like, axis=[]):
+    dtype = "float32"
+    np_data = np_data.astype(dtype)
+    np_shape_like = np_shape_like.astype(dtype)
+    np_result = np_slice_like(np_data, np_shape_like, axis)
+    data1 = sym.Variable("data1")
+    data2 = sym.Variable("data2")
+    net = sym.slice_like(data=data1, slice_like=data2, axis=axis)
+    for target, ctx in ctx_list():
+        graph, lib, _ = nnvm.compiler.build(net, target,
+                                            {"data1": np_data.shape,
+                                             "data2": np_shape_like.shape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input(**{"data1": np_data, "data2": np_shape_like})
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(np_result.shape, dtype))
+        np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
+
+def test_slice_like():
+    np_data = np.random.uniform(size=(3, 4, 5))
+    np_shape_like = np.random.uniform(size=(1, 2, 3))
+    verify_slice_like(np_data, np_shape_like)
+    np_data = np.random.uniform(size=(3, 4, 5))
+    np_shape_like = np.random.uniform(size=(1, 2))
+    verify_slice_like(np_data, np_shape_like)
+    np_data = np.random.uniform(size=(3, 4, 5))
+    np_shape_like = np.random.uniform(size=(1, 2, 3))
+    axis = (1, 2)
+    verify_slice_like(np_data, np_shape_like, axis)
+    np_data = np.random.uniform(size=(3, 4, 5))
+    np_shape_like = np.random.uniform(size=(1, 2, 3))
+    axis = (-1, -3)
+    verify_slice_like(np_data, np_shape_like, axis)
+    np_data = np.random.uniform(size=(1, 3, 224, 224))
+    np_shape_like = np.random.uniform(size=(1, 3, 112, 112))
+    axis = (2, 3)
+    verify_slice_like(np_data, np_shape_like, axis)
+
 if __name__ == "__main__":
 
@@ -561,4 +615,5 @@ if __name__ == "__main__":
     test_multibox_prior()
     test_multibox_transform_loc()
     test_nms()
+    test_slice_like()
     print(nnvm.compiler.engine.dump())
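One case worth calling out: `axis=(-1, -3)` exercises negative-axis normalization, which adds the input rank, so on the 3-D input above it resolves to axes `(2, 0)`. A quick editorial check:

```python
ndim = 3
axis = (-1, -3)
assert tuple(a if a >= 0 else ndim + a for a in axis) == (2, 0)
```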
topi/python/topi/generic/nn.py

@@ -55,26 +55,37 @@ def schedule_conv2d_nhwc(outs):
 
 @tvm.target.generic_func
-def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding, outs):
+def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding,
+                          layout, out_layout, outs):
     """Schedule for conv2d_NCHW[x]c
 
     Parameters
     ----------
-    num_filter: int
-        The number of filter, i.e., the output channel.
-    kernel_size: tuple of int
-        (kernel_height, kernel_width)
-    strides: tuple of int
-        (stride_of_height, stride_of_width)
-    padding: tuple of int
-        (pad_of_height, pad_of_width)
-    outs: Array of Tensor
-        The computation graph description of conv2d_NCHWc
-        in the format of an array of tensors.
+    num_filter : int
+        The number of filter, i.e., the output channel.
+    kernel_size : tuple of int
+        (kernel_height, kernel_width)
+    strides : tuple of int
+        (stride_of_height, stride_of_width)
+    padding : tuple of int
+        (pad_of_height, pad_of_width)
+    layout : str
+        Input data layout
+    out_layout : str
+        Output data layout
+    outs : Array of Tensor
+        The computation graph description of conv2d_NCHWc
+        in the format of an array of tensors.
 
     Returns
     -------
-    sch: Schedule
+    sch : Schedule
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
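`schedule_conv2d_NCHWc` is a `tvm.target.generic_func`, so the body above is only the fallback; targets override it with `.register`, as the x86 backend does below. A minimal dispatch sketch with a hypothetical function name:

```python
import tvm

@tvm.target.generic_func
def my_schedule(outs):          # hypothetical example, for illustration
    return "fallback"           # used when no override matches the target

@my_schedule.register("cpu")    # "cpu" is a key carried by llvm targets
def my_schedule_cpu(outs):
    return "cpu override"

with tvm.target.create("llvm"):
    assert my_schedule(None) == "cpu override"
```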
topi/python/topi/nn/conv2d.py

@@ -145,6 +145,17 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
 
+@tvm.target.generic_func
+def _get_alter_layout_schedule(wkl):
+    # pylint: disable=unreachable
+    """ Get the platform specific schedule for conv2d_alter_layout. """
+    target = tvm.target.current_target()
+    raise RuntimeError(
+        "No schedule for current target:{}".format(target))
+    # This return has no use, merely to supress pylint warning
+    return wkl
+
 @tvm.target.generic_func
 def _get_schedule(wkl):
     # pylint: disable=unreachable
     """ Get the platform specific schedule. """
 
@@ -155,6 +166,17 @@ def _get_schedule(wkl):
     return wkl
 
+@tvm.target.generic_func
+def _get_schedule_NCHWc(wkl, layout, out_layout):
+    # pylint: disable=unreachable
+    """ Get the platform specific schedule. """
+    target = tvm.target.current_target()
+    raise RuntimeError(
+        "No schedule for current target:{}".format(target))
+    # This return has no use, merely to supress pylint warning
+    return wkl
+
 def _spatial_pack(data, kernel, stride, padding, out_dtype=None):
     """ Compute convolution with pack on spatial axes. """
     if out_dtype is None:
 
@@ -443,7 +465,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
     return Output
 
 @tvm.target.generic_func
-def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
+def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding,
+                 layout, out_layout, out_dtype='float32'):
     """Conv2D operator for nChw[x]c layout.
 
     Parameters
 
@@ -468,6 +491,12 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dty
     padding : int or a list/tuple of two ints
         padding size, or [pad_height, pad_width]
 
+    layout : str
+        Input data layout
+
+    out_layout : str
+        Output data layout
+
     out_dtype : str
         output data type
topi/python/topi/x86/conv2d.py

-# pylint: disable=invalid-name,unused-variable,invalid-name
+# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
 """Conv2D schedule on x86"""
 import tvm
 
 from .. import generic, tag
 from .. import nn
 from ..nn.util import infer_pad, infer_stride
 from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
-        _get_workload, _get_schedule, Workload
+        _get_workload, _get_schedule, _get_schedule_NCHWc, \
+        _get_alter_layout_schedule, Workload
 
 from . import conv2d_avx_1x1, conv2d_avx_common
 from .conv2d_avx_common import AVXConvCommonFwd
 
@@ -99,6 +100,13 @@ def _get_schedule_conv(wkl):
     sch = _SCHEDULES_AVX[idx]
     return sch
 
+@_get_schedule_NCHWc.register("cpu")
+def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
+    return _get_schedule_conv(wkl)
+
+@_get_alter_layout_schedule.register("cpu")
+def _get_alter_layout_schedule_x86(wkl):
+    return _get_schedule_conv(wkl)
+
 @conv2d.register("cpu")
 def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
 
@@ -139,7 +147,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     stride = ast.literal_eval(attrs['strides'])
 
     wkl = _get_workload(data, kernel, stride, padding, data.dtype)
-    sch = _get_schedule_conv(wkl)
+    sch = _get_alter_layout_schedule(wkl)
     is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd)
     ic_bn, oc_bn = sch.ic_bn, sch.oc_bn
 
@@ -157,7 +165,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
 @conv2d_NCHWc.register("cpu")
-def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype):
+def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
+                            padding, layout, out_layout, out_dtype):
     _AVX_SCH_TO_DECL_FUNC = {
         AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
         AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
 
@@ -168,7 +177,7 @@ def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, paddi
     wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
                         tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
                         stride, padding, out_dtype)
-    sch = _get_schedule(wkl)
+    sch = _get_schedule_NCHWc(wkl, layout, out_layout)
     return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
 
@@ -311,7 +320,8 @@ def schedule_conv2d_nhwc(outs):
 @generic.schedule_conv2d_NCHWc.register(["cpu"])
-def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
+def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
+                          layout, out_layout, outs):
     """Create schedule for tensors"""
     _AVX_SCH_TO_SCH_FUNC = {
         AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
 
@@ -348,7 +358,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
     original_kernel = tvm.placeholder((num_filter, ic, kh, kw),
                                       dtype=conv_out.dtype)
     wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
-    sch = _get_schedule(wkl)
+    sch = _get_schedule_NCHWc(wkl, layout, out_layout)
     _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec, kernel, conv_out, outs[0])
topi/tests/python_cpp/test_topi_transform.py

@@ -271,6 +271,7 @@ def verify_concatenate_broadcast(shapes, axis, rhs_shape):
     for device in ["llvm", "cuda", "opencl", "metal", "rocm"]:
         check_device(device)
 
 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
+    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)