Commit 71f88611 by Yao Wang, committed by Tianqi Chen

Improve schedule load, add slice_like (#1299)

parent 84bd230c
@@ -259,6 +259,16 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
  }
};

struct SliceLikeParam : public dmlc::Parameter<SliceLikeParam> {
  Tuple<int> axis;
  DMLC_DECLARE_PARAMETER(SliceLikeParam) {
    DMLC_DECLARE_FIELD(axis).set_default(Tuple<int>())
    .describe("List of axes on which input data will be sliced according to the "
              "corresponding size of the second input. By default will slice "
              "on all axes. Negative axes are supported.");
  }
};
}  // namespace top
}  // namespace nnvm
...
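The new axis parameter selects which axes of the first input are clipped to the second input's extent; an empty tuple (the default) clips every axis the second input has. A minimal symbol-level sketch of the resulting operator (shapes and axis values here are illustrative, not part of the diff):

import nnvm.symbol as sym
import nnvm.compiler

data = sym.Variable("data")   # e.g. shape (1, 3, 224, 224)
ref = sym.Variable("ref")     # e.g. shape (1, 3, 112, 112)

# Clip only the spatial axes of `data` to match `ref`.
out = sym.slice_like(data=data, slice_like=ref, axis=(2, 3))

# Default (empty) axis: clip every axis of `data` that `ref` also has.
out_all = sym.slice_like(data=data, slice_like=ref)

graph, lib, _ = nnvm.compiler.build(
    out, "llvm", {"data": (1, 3, 224, 224), "ref": (1, 3, 112, 112)})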
@@ -240,6 +240,21 @@ def _elemwise_sum(inputs, _):
    new_attrs = {'num_args':len(inputs)}
    return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)

def _crop_like(inputs, attrs):
    new_attrs = {}
    offsets = \
        tuple([float(x.strip()) for x in attrs.get('offsets').strip('()').split(',')]) \
        if attrs.get('offsets') is not None else (0, 0)
    if offsets != (0, 0):
        raise RuntimeError("Currently only supports offsets to be zero.")
    center_crop = _parse_bool_str(attrs, 'center_crop', default="False")
    if center_crop:
        raise RuntimeError("center crop is not supported.")
    if len(inputs) < 2:
        raise RuntimeError("Only support crop_like pattern.")
    new_attrs["axis"] = [2, 3]
    return _get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs)

def _expand_dims(inputs, attrs):
    op_name, new_attrs = "expand_dims", {}
@@ -255,7 +270,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                  'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
                  'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
                  'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
                  'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh',
                  'transpose']

_convert_map = {
    '_copy' : _rename('copy'),
@@ -274,6 +290,7 @@ _convert_map = {
    'Concat' : _concat,
    'Convolution' : _conv2d,
    'Convolution_v1': _conv2d,
    'Crop' : _crop_like,
    'Deconvolution' : _conv2d_transpose,
    'Dropout' : _dropout,
    'Flatten' : _rename('flatten'),
...
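The converter above handles only the two-input Crop pattern with zero offsets, which on NCHW feature maps reduces to clipping the spatial axes (2 and 3) to the reference tensor. A NumPy sketch of that semantics (the shapes are made up for illustration):

import numpy as np

data = np.random.uniform(size=(1, 8, 19, 19))   # NCHW feature map
ref = np.random.uniform(size=(1, 8, 16, 16))    # reference feature map

# slice_like with axis=[2, 3]: keep batch and channel, clip H and W to ref.
cropped = data[:, :, :ref.shape[2], :ref.shape[3]]
assert cropped.shape == (1, 8, 16, 16)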
@@ -155,10 +155,13 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
    kh, kw = attrs.get_int_tuple('kernel_size')
    groups = attrs.get_int("groups")
    channels = attrs.get_int("channels")
    layout = attrs.get_string("layout")
    out_layout = attrs.get_string("out_layout")
    assert dilation == (1, 1), "not support dilate now"
    if groups == 1:
        # pylint: disable=assignment-from-no-return
        out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw),
                                   strides, padding, layout, out_layout)
        # pylint: enable=assignment-from-no-return
    else:
        raise ValueError("not support arbitrary group number > 1 for now")
@@ -176,9 +179,12 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
    oc = attrs.get_int("channels")
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    layout = attrs.get_string("layout")
    out_layout = attrs.get_string("out_layout")
    with tvm.target.create(target):
        if groups == 1:
            return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding,
                                                      layout, out_layout, outs)
        else:
            raise ValueError("not support group number > 1 for now")
...
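The layout and out_layout strings forwarded here name blocked layouts such as NCHW8c, i.e. NCHW with the channel axis split into fixed-size blocks that become the innermost dimension; the block width 8 below is only an assumption for illustration (the real value comes from the chosen schedule). In NumPy terms the packing is:

import numpy as np

n, c, h, w, cb = 1, 16, 28, 28, 8
nchw = np.random.uniform(size=(n, c, h, w)).astype("float32")

# NCHW -> NCHW8c: split C into C//8 blocks of 8 and move the block inside.
nchwc = nchw.reshape(n, c // cb, cb, h, w).transpose(0, 1, 3, 4, 2)
assert nchwc.shape == (n, c // cb, h, w, cb)   # layout string "NCHW8c"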
@@ -60,3 +60,7 @@ reg.register_schedule("concatenate", _fschedule_injective)

# split
reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)

# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)
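Registering slice_like as INJECTIVE means it is scheduled with the generic injective schedule and can be fused with neighbouring elementwise ops. A small sketch of that schedule applied to a stand-alone slice-style compute (shapes are illustrative, not part of the diff):

import tvm
import topi

A = tvm.placeholder((3, 4, 5), name="A")
# A slice-like compute: clip the last two axes to a smaller extent.
B = tvm.compute((3, 2, 3), lambda i, j, k: A[i, j, k], name="B")

with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective([B])
f = tvm.build(s, [A, B], "llvm")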
@@ -320,7 +320,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
  .set_attr<nnvm::FInferShape>("FInferShape",              \
                               ElementWiseReduceShape)     \
  .set_attr<FCorrectLayout>("FCorrectLayout",              \
                            ElemwiseFixedLayoutCopyToOut<-1, 1>) \
  .set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \
  .add_argument("args", "Symbol[]", "Positional input arguments")
...
@@ -15,6 +15,7 @@
#include "../elemwise_op_common.h"
#include "topi/nn/flatten.h"
#include "topi/transform.h"
#include "topi/detail/constant_utils.h"

namespace nnvm {
namespace top {
@@ -877,5 +878,105 @@ Examples::
  return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
});
// SliceLike
DMLC_REGISTER_PARAMETER(SliceLikeParam);
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
                           std::vector<TShape>* in_attrs,
                           std::vector<TShape>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
  const TShape& src_shape = in_attrs->at(0);
  const TShape& target_shape = in_attrs->at(1);
  Tuple<dim_t> end_idx;
  end_idx = Tuple<dim_t>(src_shape);
  if (param.axis.ndim() == 0) {
    for (size_t i = 0; i < src_shape.ndim(); ++i) {
      if (i < target_shape.ndim()) {
        end_idx[i] = target_shape[i];
        CHECK_LE(end_idx[i], src_shape[i])
          << "End index of axis " << i << " exceeds input shape: "
          << end_idx[i] << " vs " << src_shape[i];
      }
    }
  } else {
    for (auto i : param.axis) {
      if (i < 0) {
        i = src_shape.ndim() + i;
      }
      CHECK_LT(i, target_shape.ndim())
        << "Axis " << i << " exceeds dimension "
        << target_shape.ndim() << " of target_shape.";
      end_idx[i] = target_shape[i];
      CHECK_LE(end_idx[i], src_shape[i])
        << "End index of axis " << i << " exceeds input shape: "
        << end_idx[i] << " vs " << src_shape[i];
    }
  }
  TShape out_shape = TShape(std::move(end_idx));
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape);
  return true;
}
NNVM_REGISTER_OP(slice_like)
.describe(R"code(Slice the first input with respect to the second input.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data to be sliced.")
.add_argument("slice_like", "Tensor", "Tensor with target shape")
.set_num_inputs(2)
.set_num_outputs(1)
.add_arguments(SliceLikeParam::__FIELDS__())
.set_attr_parser(ParamParser<SliceLikeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SliceLikeParam>)
.set_attr<FInferShape>("FInferShape", SliceLikeShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseBinaryKeepLeftLayout)
.set_attr<FTVMCompute>(
  "FTVMCompute", [](const NodeAttrs& attrs,
                    const Array<Tensor>& inputs,
                    const Array<Tensor>& out_info) {
    const auto& param = nnvm::get<SliceLikeParam>(attrs.parsed);
    Array<Expr> src_shape = inputs[0]->shape;
    Array<Expr> target_shape = inputs[1]->shape;
    Array<Expr> begin_idx, end_idx, strides;
    for (size_t i = 0; i < src_shape.size(); ++i) {
      begin_idx.push_back(make_const(tvm::Int(32), 0));
      strides.push_back(make_const(tvm::Int(32), 1));
    }
    end_idx = Array<Expr>(src_shape);
    if (param.axis.ndim() == 0) {
      for (size_t i = 0; i < src_shape.size(); ++i) {
        if (i < target_shape.size()) {
          end_idx.Set(i, target_shape[i]);
          CHECK_LE(topi::GetConstInt(end_idx[i]),
                   topi::GetConstInt(src_shape[i]))
            << "End index of axis " << i << " exceeds input shape: "
            << topi::GetConstInt(end_idx[i]) << " vs "
            << topi::GetConstInt(src_shape[i]);
        }
      }
    } else {
      for (int axis : param.axis) {
        if (axis < 0) {
          axis = static_cast<int>(src_shape.size()) + axis;
        }
        end_idx.Set(static_cast<size_t>(axis), target_shape[axis]);
        CHECK_LE(topi::GetConstInt(end_idx[axis]),
                 topi::GetConstInt(src_shape[axis]))
          << "End index of axis " << axis << " exceeds input shape: "
          << topi::GetConstInt(end_idx[axis]) << " vs "
          << topi::GetConstInt(src_shape[axis]);
      }
    }
    return Array<Tensor>{
      topi::strided_slice(inputs[0], begin_idx, end_idx, strides)
    };
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data", "slice_like"};
})
.set_support_level(4);
}  // namespace top
}  // namespace nnvm
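The shape rule implemented by SliceLikeShape can be summarized in a few lines of Python; this is only a reading aid for the C++ above, not part of the diff:

def slice_like_out_shape(src_shape, target_shape, axis=()):
    """Start from src_shape and clip the chosen axes (all of them when
    axis is empty) to target_shape's extents."""
    out = list(src_shape)
    if not axis:
        for i in range(len(src_shape)):
            if i < len(target_shape):
                assert target_shape[i] <= src_shape[i]
                out[i] = target_shape[i]
    else:
        for ax in axis:
            if ax < 0:
                ax += len(src_shape)
            assert ax < len(target_shape)
            assert target_shape[ax] <= src_shape[ax]
            out[ax] = target_shape[ax]
    return tuple(out)

assert slice_like_out_shape((3, 4, 5), (1, 2, 3)) == (1, 2, 3)
assert slice_like_out_shape((3, 4, 5), (1, 2)) == (1, 2, 5)
assert slice_like_out_shape((3, 4, 5), (1, 2, 3), axis=(-1, -3)) == (1, 4, 3)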
@@ -541,6 +541,60 @@ def test_nms():
    out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
    np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)

def np_slice_like(np_data, np_shape_like, axis=[]):
    begin_idx = [0 for _ in np_data.shape]
    end_idx = list(np_data.shape)
    if len(axis) > 0:
        for i in axis:
            if i < 0:
                i = len(np_data.shape) + i
            end_idx[i] = np_shape_like.shape[i]
    else:
        for i in range(len(np_data.shape)):
            if i < len(np_shape_like.shape):
                end_idx[i] = np_shape_like.shape[i]
    slice_idx = []
    for b, e in zip(begin_idx, end_idx):
        slice_idx.append(slice(b, e))
    np_result = np_data[slice_idx]
    return np_result

def verify_slice_like(np_data, np_shape_like, axis=[]):
    dtype = "float32"
    np_data = np_data.astype(dtype)
    np_shape_like = np_shape_like.astype(dtype)
    np_result = np_slice_like(np_data, np_shape_like, axis)
    data1 = sym.Variable("data1")
    data2 = sym.Variable("data2")
    net = sym.slice_like(data=data1, slice_like=data2, axis=axis)
    for target, ctx in ctx_list():
        graph, lib, _ = nnvm.compiler.build(net, target, {"data1": np_data.shape,
                                                          "data2": np_shape_like.shape})
        m = graph_runtime.create(graph, lib, ctx)
        m.set_input(**{"data1": np_data, "data2": np_shape_like})
        m.run()
        out = m.get_output(0, tvm.nd.empty(np_result.shape, dtype))
        np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)

def test_slice_like():
    np_data = np.random.uniform(size=(3, 4, 5))
    np_shape_like = np.random.uniform(size=(1, 2, 3))
    verify_slice_like(np_data, np_shape_like)
    np_data = np.random.uniform(size=(3, 4, 5))
    np_shape_like = np.random.uniform(size=(1, 2))
    verify_slice_like(np_data, np_shape_like)
    np_data = np.random.uniform(size=(3, 4, 5))
    np_shape_like = np.random.uniform(size=(1, 2, 3))
    axis = (1, 2)
    verify_slice_like(np_data, np_shape_like, axis)
    np_data = np.random.uniform(size=(3, 4, 5))
    np_shape_like = np.random.uniform(size=(1, 2, 3))
    axis = (-1, -3)
    verify_slice_like(np_data, np_shape_like, axis)
    np_data = np.random.uniform(size=(1, 3, 224, 224))
    np_shape_like = np.random.uniform(size=(1, 3, 112, 112))
    axis = (2, 3)
    verify_slice_like(np_data, np_shape_like, axis)

if __name__ == "__main__":
@@ -561,4 +615,5 @@ if __name__ == "__main__":
    test_multibox_prior()
    test_multibox_transform_loc()
    test_nms()
    test_slice_like()
    print(nnvm.compiler.engine.dump())
@@ -55,26 +55,37 @@ def schedule_conv2d_nhwc(outs):

@tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
                          padding, layout, out_layout, outs):
    """Schedule for conv2d_NCHW[x]c

    Parameters
    ----------
    num_filter : int
        The number of filter, i.e., the output channel.

    kernel_size : tuple of int
        (kernel_height, kernel_width)

    strides : tuple of int
        (stride_of_height, stride_of_width)

    padding : tuple of int
        (pad_of_height, pad_of_width)

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    outs : Array of Tensor
        The computation graph description of conv2d_NCHWc
        in the format of an array of tensors.

    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)
...
@@ -145,6 +145,17 @@ def _get_workload(data, kernel, stride, padding, out_dtype):

@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
    # pylint: disable=unreachable
    """ Get the platform specific schedule for conv2d_alter_layout. """
    target = tvm.target.current_target()
    raise RuntimeError(
        "No schedule for current target:{}".format(target))
    # This return has no use, merely to supress pylint warning
    return wkl

@tvm.target.generic_func
def _get_schedule(wkl):
    # pylint: disable=unreachable
    """ Get the platform specific schedule. """
@@ -155,6 +166,17 @@ def _get_schedule(wkl):
    return wkl

@tvm.target.generic_func
def _get_schedule_NCHWc(wkl, layout, out_layout):
    # pylint: disable=unreachable
    """ Get the platform specific schedule. """
    target = tvm.target.current_target()
    raise RuntimeError(
        "No schedule for current target:{}".format(target))
    # This return has no use, merely to supress pylint warning
    return wkl

def _spatial_pack(data, kernel, stride, padding, out_dtype=None):
    """ Compute convolution with pack on spatial axes. """
    if out_dtype is None:
@@ -443,7 +465,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
    return Output

@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride,
                 padding, layout, out_layout, out_dtype='float32'):
    """Conv2D operator for nChw[x]c layout.

    Parameters
@@ -468,6 +491,12 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dty
    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    layout : str
        Input data layout

    out_layout : str
        Output data layout

    out_dtype : str
        output data type
...
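Both new helpers are tvm.target.generic_func stubs that raise until a backend registers an implementation; the x86 file below does exactly that. A minimal sketch of the dispatch mechanism, using a hypothetical function name rather than the real ones:

import tvm

@tvm.target.generic_func
def _pick_schedule(wkl, layout, out_layout):
    # Default body: no backend has claimed the current target.
    raise RuntimeError("No schedule for current target")

@_pick_schedule.register("cpu")
def _pick_schedule_cpu(wkl, layout, out_layout):
    # A real backend would look `wkl` up in its tuned-schedule table here.
    return ("cpu", layout, out_layout)

with tvm.target.create("llvm"):
    print(_pick_schedule("dummy-workload", "NCHW8c", "NCHW8c"))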
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Conv2D schedule on x86"""
import tvm
from .. import generic, tag
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
    _get_workload, _get_schedule, _get_schedule_NCHWc, \
    _get_alter_layout_schedule, Workload
from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
@@ -99,6 +100,13 @@ def _get_schedule_conv(wkl):
    sch = _SCHEDULES_AVX[idx]
    return sch

@_get_schedule_NCHWc.register("cpu")
def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
    return _get_schedule_conv(wkl)

@_get_alter_layout_schedule.register("cpu")
def _get_alter_layout_schedule_x86(wkl):
    return _get_schedule_conv(wkl)

@conv2d.register("cpu")
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
@@ -139,7 +147,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
    stride = ast.literal_eval(attrs['strides'])

    wkl = _get_workload(data, kernel, stride, padding, data.dtype)
    sch = _get_alter_layout_schedule(wkl)
    is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd)
    ic_bn, oc_bn = sch.ic_bn, sch.oc_bn
@@ -157,7 +165,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):

@conv2d_NCHWc.register("cpu")
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
                            padding, layout, out_layout, out_dtype):
    _AVX_SCH_TO_DECL_FUNC = {
        AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
        AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
@@ -168,7 +177,7 @@ def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, paddi
    wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
                        tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
                        stride, padding, out_dtype)
    sch = _get_schedule_NCHWc(wkl, layout, out_layout)
    return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
@@ -311,7 +320,8 @@ def schedule_conv2d_nhwc(outs):

@generic.schedule_conv2d_NCHWc.register(["cpu"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
                          layout, out_layout, outs):
    """Create schedule for tensors"""
    _AVX_SCH_TO_SCH_FUNC = {
        AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
@@ -348,7 +358,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
        original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
        wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
        sch = _get_schedule_NCHWc(wkl, layout, out_layout)
        _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
                                        kernel, conv_out, outs[0])
...
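End to end, these registrations are exercised when an x86 target compiles a convolution at opt_level 3: AlterOpLayout calls _alter_conv2d_layout, which now picks its schedule via _get_alter_layout_schedule, and the rewritten NCHW[x]c op is lowered through _get_schedule_NCHWc. A hedged sketch, assuming a workload the x86 schedule table already knows (here the ResNet-18 stem convolution) and an AVX2 CPU target:

import nnvm
import nnvm.symbol as sym
import nnvm.compiler

data = sym.Variable("data")
net = sym.conv2d(data=data, channels=64, kernel_size=(7, 7), strides=(2, 2),
                 padding=(3, 3), use_bias=False, name="conv0")

# opt_level=3 enables AlterOpLayout, which rewrites conv2d into NCHW[x]c
# form on x86 targets before lowering.
with nnvm.compiler.build_config(opt_level=3):
    graph, lib, params = nnvm.compiler.build(
        net, "llvm -mcpu=core-avx2", {"data": (1, 3, 224, 224)})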
@@ -271,6 +271,7 @@ def verify_concatenate_broadcast(shapes, axis, rhs_shape):
    for device in ["llvm", "cuda", "opencl", "metal", "rocm"]:
        check_device(device)


def test_expand_dims():
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
...