Commit 71f88611 by Yao Wang Committed by Tianqi Chen

Improve schedule load, add slice_like (#1299)

parent 84bd230c
......@@ -259,6 +259,16 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
}
};
struct SliceLikeParam : public dmlc::Parameter<SliceLikeParam> {
Tuple<int> axis;
DMLC_DECLARE_PARAMETER(SliceLikeParam) {
DMLC_DECLARE_FIELD(axis).set_default(Tuple<int>())
.describe("List of axes on which input data will be sliced according to the "
"corresponding size of the second input. By default will slice "
"on all axes. Negative axes are supported.");
}
};
} // namespace top
} // namespace nnvm
......
......@@ -240,6 +240,21 @@ def _elemwise_sum(inputs, _):
new_attrs = {'num_args':len(inputs)}
return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)
def _crop_like(inputs, attrs):
new_attrs = {}
offsets = \
tuple([float(x.strip()) for x in attrs.get('offsets').strip('()').split(',')]) \
if attrs.get('offsets') is not None else (0, 0)
if offsets != (0, 0):
raise RuntimeError("Currently only supports offsets to be zero.")
center_crop = _parse_bool_str(attrs, 'center_crop', default="False")
if center_crop:
raise RuntimeError("center crop is not supported.")
if len(inputs) < 2:
raise RuntimeError("Only support crop_like pattern.")
new_attrs["axis"] = [2, 3]
return _get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs)
def _expand_dims(inputs, attrs):
op_name, new_attrs = "expand_dims", {}
......@@ -255,7 +270,8 @@ _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
'relu', 'sigmoid', 'slice_like', 'softmax', 'sum', 'tanh',
'transpose']
_convert_map = {
'_copy' : _rename('copy'),
......@@ -274,6 +290,7 @@ _convert_map = {
'Concat' : _concat,
'Convolution' : _conv2d,
'Convolution_v1': _conv2d,
'Crop' : _crop_like,
'Deconvolution' : _conv2d_transpose,
'Dropout' : _dropout,
'Flatten' : _rename('flatten'),
......
......@@ -155,10 +155,13 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), strides, padding)
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw),
strides, padding, layout, out_layout)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
......@@ -176,9 +179,12 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, outs)
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding,
layout, out_layout, outs)
else:
raise ValueError("not support group number > 1 for now")
......
......@@ -60,3 +60,7 @@ reg.register_schedule("concatenate", _fschedule_injective)
# split
reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)
# slice_like
reg.register_pattern("slice_like", OpPattern.INJECTIVE)
reg.register_schedule("slice_like", _fschedule_injective)
......@@ -320,7 +320,7 @@ inline bool ElemwiseBinaryKeepLeftLayout(const NodeAttrs& attrs,
.set_attr<nnvm::FInferShape>("FInferShape", \
ElementWiseReduceShape) \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutCopyToOut<1, 1>) \
ElemwiseFixedLayoutCopyToOut<-1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElementWiseReduceType) \
.add_argument("args", "Symbol[]", "Positional input arguments")
......
......@@ -15,6 +15,7 @@
#include "../elemwise_op_common.h"
#include "topi/nn/flatten.h"
#include "topi/transform.h"
#include "topi/detail/constant_utils.h"
namespace nnvm {
namespace top {
......@@ -877,5 +878,105 @@ Examples::
return Array<Tensor>{ topi::flip(inputs[0], param.axis) };
});
// SliceLike
DMLC_REGISTER_PARAMETER(SliceLikeParam);
inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
const SliceLikeParam& param = nnvm::get<SliceLikeParam>(attrs.parsed);
const TShape& src_shape = in_attrs->at(0);
const TShape& target_shape = in_attrs->at(1);
Tuple<dim_t> end_idx;
end_idx = Tuple<dim_t>(src_shape);
if (param.axis.ndim() == 0) {
for (size_t i = 0; i < src_shape.ndim(); ++i) {
if (i < target_shape.ndim()) {
end_idx[i] = target_shape[i];
CHECK_LE(end_idx[i], src_shape[i])
<< "End index of axis " << i << " exceeds input shape: "
<< end_idx[i] << " vs " << src_shape[i];
}
}
} else {
for (auto i : param.axis) {
if (i < 0) {
i = src_shape.ndim() + i;
}
CHECK_LT(i, target_shape.ndim())
<< "Axis " << i << " exceeds dimension "
<< target_shape.ndim()<< " of target_shape.";
end_idx[i] = target_shape[i];
CHECK_LE(end_idx[i], src_shape[i])
<< "End index of axis " << i << " exceeds input shape: "
<< end_idx[i] << " vs " << src_shape[i];
}
}
TShape out_shape = TShape(std::move(end_idx));
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, out_shape);
return true;
}
NNVM_REGISTER_OP(slice_like)
.describe(R"code(Slice the first input respect to the second input.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data to be sliced.")
.add_argument("slice_like", "Tensor", "Tensor with target shape")
.set_num_inputs(2)
.set_num_outputs(1)
.add_arguments(SliceLikeParam::__FIELDS__())
.set_attr_parser(ParamParser<SliceLikeParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<SliceLikeParam>)
.set_attr<FInferShape>("FInferShape", SliceLikeShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseBinaryKeepLeftLayout)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const auto& param = nnvm::get<SliceLikeParam>(attrs.parsed);
Array<Expr> src_shape = inputs[0]->shape;
Array<Expr> target_shape = inputs[1]->shape;
Array<Expr> begin_idx, end_idx, strides;
for (size_t i = 0; i < src_shape.size(); ++i) {
begin_idx.push_back(make_const(tvm::Int(32), 0));
strides.push_back(make_const(tvm::Int(32), 1));
}
end_idx = Array<Expr>(src_shape);
if (param.axis.ndim() == 0) {
for (size_t i = 0; i < src_shape.size(); ++i) {
if (i < target_shape.size()) {
end_idx.Set(i, target_shape[i]);
CHECK_LE(topi::GetConstInt(end_idx[i]),
topi::GetConstInt(src_shape[i]))
<< "End index of axis " << i << " exceeds input shape: "
<< topi::GetConstInt(end_idx[i]) << " vs "
<< topi::GetConstInt(src_shape[i]);
}
}
} else {
for (int axis : param.axis) {
if (axis < 0) {
axis = static_cast<int>(src_shape.size()) + axis;
}
end_idx.Set(static_cast<size_t>(axis), target_shape[axis]);
CHECK_LE(topi::GetConstInt(end_idx[axis]),
topi::GetConstInt(src_shape[axis]))
<< "End index of axis " << axis << " exceeds input shape: "
<< topi::GetConstInt(end_idx[axis]) << " vs "
<< topi::GetConstInt(src_shape[axis]);
}
}
return Array<Tensor>{
topi::strided_slice(inputs[0], begin_idx, end_idx, strides)
};
})
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "slice_like"};
})
.set_support_level(4);
} // namespace top
} // namespace nnvm
......@@ -541,6 +541,60 @@ def test_nms():
out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
def np_slice_like(np_data, np_shape_like, axis=[]):
begin_idx = [0 for _ in np_data.shape]
end_idx = list(np_data.shape)
if len(axis) > 0:
for i in axis:
if i < 0:
i = len(np_data.shape) + i
end_idx[i] = np_shape_like.shape[i]
else:
for i in range(len(np_data.shape)):
if i < len(np_shape_like.shape):
end_idx[i] = np_shape_like.shape[i]
slice_idx = []
for b, e in zip(begin_idx, end_idx):
slice_idx.append(slice(b, e))
np_result = np_data[slice_idx]
return np_result
def verify_slice_like(np_data, np_shape_like, axis=[]):
dtype = "float32"
np_data = np_data.astype(dtype)
np_shape_like = np_shape_like.astype(dtype)
np_result = np_slice_like(np_data, np_shape_like, axis)
data1 = sym.Variable("data1")
data2 = sym.Variable("data2")
net = sym.slice_like(data=data1, slice_like=data2, axis=axis)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(net, target, {"data1": np_data.shape,
"data2": np_shape_like.shape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data1": np_data, "data2": np_shape_like})
m.run()
out = m.get_output(0, tvm.nd.empty(np_result.shape, dtype))
np.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
def test_slice_like():
np_data = np.random.uniform(size=(3, 4, 5))
np_shape_like = np.random.uniform(size=(1, 2, 3))
verify_slice_like(np_data, np_shape_like)
np_data = np.random.uniform(size=(3, 4, 5))
np_shape_like = np.random.uniform(size=(1, 2))
verify_slice_like(np_data, np_shape_like)
np_data = np.random.uniform(size=(3, 4, 5))
np_shape_like = np.random.uniform(size=(1, 2, 3))
axis = (1, 2)
verify_slice_like(np_data, np_shape_like, axis)
np_data = np.random.uniform(size=(3, 4, 5))
np_shape_like = np.random.uniform(size=(1, 2, 3))
axis = (-1, -3)
verify_slice_like(np_data, np_shape_like, axis)
np_data = np.random.uniform(size=(1, 3, 224, 224))
np_shape_like = np.random.uniform(size=(1, 3, 112, 112))
axis = (2, 3)
verify_slice_like(np_data, np_shape_like, axis)
if __name__ == "__main__":
......@@ -561,4 +615,5 @@ if __name__ == "__main__":
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_slice_like()
print(nnvm.compiler.engine.dump())
......@@ -55,26 +55,37 @@ def schedule_conv2d_nhwc(outs):
@tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding, outs):
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
padding, layout, out_layout, outs):
"""Schedule for conv2d_NCHW[x]c
Parameters
----------
num_filter: int
The number of filter, i.e., the output channel.
kernel_size: tuple of int
(kernel_height, kernel_width)
strides: tuple of int
(stride_of_height, stride_of_width)
padding: tuple of int
(pad_of_height, pad_of_width)
outs: Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
Returns
-------
sch: Schedule
sch : Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
......
......@@ -145,6 +145,17 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
@tvm.target.generic_func
def _get_alter_layout_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule for conv2d_alter_layout. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
@tvm.target.generic_func
def _get_schedule(wkl):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
......@@ -155,6 +166,17 @@ def _get_schedule(wkl):
return wkl
@tvm.target.generic_func
def _get_schedule_NCHWc(wkl, layout, out_layout):
# pylint: disable=unreachable
""" Get the platform specific schedule. """
target = tvm.target.current_target()
raise RuntimeError(
"No schedule for current target:{}".format(target))
# This return has no use, merely to supress pylint warning
return wkl
def _spatial_pack(data, kernel, stride, padding, out_dtype=None):
""" Compute convolution with pack on spatial axes. """
if out_dtype is None:
......@@ -443,7 +465,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'):
return Output
@tvm.target.generic_func
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype='float32'):
def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride,
padding, layout, out_layout, out_dtype='float32'):
"""Conv2D operator for nChw[x]c layout.
Parameters
......@@ -468,6 +491,12 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dty
padding : int or a list/tuple of two ints
padding size, or [pad_height, pad_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
......
# pylint: disable=invalid-name,unused-variable,invalid-name
# pylint: disable=invalid-name,unused-variable,invalid-name,unused-argument
"""Conv2D schedule on x86"""
import tvm
from .. import generic, tag
from .. import nn
from ..nn.util import infer_pad, infer_stride
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \
_get_workload, _get_schedule, Workload
_get_workload, _get_schedule, _get_schedule_NCHWc, \
_get_alter_layout_schedule, Workload
from . import conv2d_avx_1x1, conv2d_avx_common
from .conv2d_avx_common import AVXConvCommonFwd
......@@ -99,6 +100,13 @@ def _get_schedule_conv(wkl):
sch = _SCHEDULES_AVX[idx]
return sch
@_get_schedule_NCHWc.register("cpu")
def _get_schedule_NCHWc_x86(wkl, layout, out_layout):
return _get_schedule_conv(wkl)
@_get_alter_layout_schedule.register("cpu")
def _get_alter_layout_schedule_x86(wkl):
return _get_schedule_conv(wkl)
@conv2d.register("cpu")
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
......@@ -139,7 +147,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
stride = ast.literal_eval(attrs['strides'])
wkl = _get_workload(data, kernel, stride, padding, data.dtype)
sch = _get_schedule_conv(wkl)
sch = _get_alter_layout_schedule(wkl)
is_kernel_1x1 = isinstance(sch, AVXConv1x1Fwd)
ic_bn, oc_bn = sch.ic_bn, sch.oc_bn
......@@ -157,7 +165,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
@conv2d_NCHWc.register("cpu")
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, padding, out_dtype):
def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride,
padding, layout, out_layout, out_dtype):
_AVX_SCH_TO_DECL_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._declaration_conv_NCHWc,
AVXConv1x1Fwd: conv2d_avx_1x1._declaration_conv_NCHWc
......@@ -168,7 +177,7 @@ def _declaration_conv_NCHWc(data, kernel, num_filter, kernel_size, stride, paddi
wkl = _get_workload(tvm.placeholder((n, ic, h, w), dtype=out_dtype),
tvm.placeholder((num_filter, ic, kh, kw), dtype=out_dtype),
stride, padding, out_dtype)
sch = _get_schedule(wkl)
sch = _get_schedule_NCHWc(wkl, layout, out_layout)
return _AVX_SCH_TO_DECL_FUNC[type(sch)](wkl, sch, data, kernel)
......@@ -311,7 +320,8 @@ def schedule_conv2d_nhwc(outs):
@generic.schedule_conv2d_NCHWc.register(["cpu"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
layout, out_layout, outs):
"""Create schedule for tensors"""
_AVX_SCH_TO_SCH_FUNC = {
AVXConvCommonFwd: conv2d_avx_common._schedule_conv_NCHWc,
......@@ -348,7 +358,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, outs):
original_kernel = tvm.placeholder((num_filter, ic, kh, kw), dtype=conv_out.dtype)
wkl = _get_workload(original_data, original_kernel, stride, padding, conv_out.dtype)
sch = _get_schedule(wkl)
sch = _get_schedule_NCHWc(wkl, layout, out_layout)
_AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
kernel, conv_out, outs[0])
......
......@@ -271,6 +271,7 @@ def verify_concatenate_broadcast(shapes, axis, rhs_shape):
for device in ["llvm", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment