Commit 719d6d47 by Alex Gladkov, committed by Yao Wang

Add support for MXNet pad operator. (#3739)

MXNet pad is described at:
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.pad
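
As a rough sketch of what the new mapping covers (shapes and pad widths below are made up for illustration; conversion goes through the standard relay.frontend.from_mxnet entry point):

    import mxnet as mx
    from tvm import relay

    # mx.sym.pad expects a 4-D or 5-D input; pad_width lists (before, after) per axis,
    # and the leading batch/channel axes must use zero padding.
    data = mx.sym.var('data')
    sym = mx.sym.pad(data, mode='reflect', pad_width=(0, 0, 0, 0, 1, 2, 3, 4))

    # The frontend now lowers 'pad'/'Pad' to relay.nn.pad with the matching pad_mode.
    mod, params = relay.frontend.from_mxnet(sym, shape={'data': (1, 1, 3, 5)})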

Add support for parameter 'None' in MXNet slice operator.

MXNet 'slice' is described at
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.slice
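
A minimal sketch of the new behavior (the shape is illustrative only): None entries in begin/end now fall back to the start and full extent of the corresponding axis.

    import mxnet as mx
    from tvm import relay

    data = mx.sym.var('data')
    # Equivalent to slicing only the last axis: data[:, :, 2:]
    sym = mx.sym.slice(data, begin=(None, None, 2), end=(None, None, None))
    mod, params = relay.frontend.from_mxnet(sym, shape={'data': (1, 1, 10)})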

Add support for MXNet cos, sin, arctan

MXNet 'cos' is described at
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.cos

MXNet 'sin' is described at
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.sin

MXNet 'arctan' is described at
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.arctan
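
The mapping for these is direct: 'cos' and 'sin' pass through the identity list, while 'arctan' is renamed to the new relay atan op. A small sketch (the input shape is arbitrary):

    import mxnet as mx
    from tvm import relay

    x = mx.sym.var('x')
    # arctan(sin(cos(x))) exercises all three new unary conversions.
    sym = mx.sym.arctan(mx.sym.sin(mx.sym.cos(x)))
    mod, params = relay.frontend.from_mxnet(sym, shape={'x': (2, 3)})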

Add support for MXNet 1D Convolution and 1D Deconvolution

MXNet convolution is described at:
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.Convolution

MXNet Deconvolution is described at:
https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.Deconvolution
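
Internally the 1D cases are expressed with the existing 2D ops: data and weights are expanded with a dummy spatial axis, conv2d/conv2d_transpose is applied, and the extra axis is squeezed away. A hedged usage sketch (filter count and shapes chosen only for illustration):

    import mxnet as mx
    from tvm import relay

    x = mx.sym.var('x')
    w = mx.sym.var('w')
    # 1D convolution over an NCW input; kernel/stride/pad are single-element tuples.
    sym = mx.sym.Convolution(x, w, kernel=(17,), stride=(2,), pad=(8,),
                             num_filter=4, no_bias=True)
    mod, params = relay.frontend.from_mxnet(
        sym, shape={'x': (1, 1, 16384), 'w': (4, 1, 17)})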
parent 0840b064
......@@ -521,6 +521,7 @@ TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(atan);
// Implementation details after this
inline bool is_const(const Expr& x) {
......
......@@ -405,13 +405,18 @@ struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
double pad_value;
Array<Array<IndexExpr> > pad_width;
std::string pad_mode;
TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
TVM_ATTR_FIELD(pad_value).set_default(0.0)
.describe("Specifies the strides of the convolution.");
.describe("The value used for padding when mode is 'constant'.");
TVM_ATTR_FIELD(pad_width)
.describe("Number of values padded to the edges of each axis, "
"in the format of ((before_1, after_1), ..., (before_N, after_N))");
TVM_ATTR_FIELD(pad_mode).set_default("constant")
.describe("Padding type to use. \"constant\" pads with constant_value, "
"\"edge\" pads using the edge values of the input array, "
"\"reflect\" pads by reflecting values with respect to the edges.");
}
};
......
......@@ -304,6 +304,21 @@ def sin(x):
"""
return call_pure_intrin(x.dtype, "sin", x)
def atan(x):
"""Take atan of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "atan", x)
def sqrt(x):
"""Take square root of input x.
......
......@@ -138,16 +138,8 @@ class StrAttrsDict(object):
"""
if key in self.attrs:
tshape = self.attrs[key]
ret = []
for x in tshape.strip('()[]').split(','):
x = x.strip()
if not x:
continue
if x == "None":
ret.append(None)
else:
ret.append(int(x))
return tuple(ret)
return tuple(int(x) if x.strip("- ").isdigit() else None
for x in tshape.strip('()[]').split(',') if x)
if isinstance(default, RequiredAttr):
raise AttributeError("Required attribute {} not found.".format(key))
return default
......
......@@ -112,11 +112,55 @@ def _mx_zeros(inputs, attrs):
return _op.zeros(shape=shape, dtype=dtype)
def _mx_conv(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) == 2:
return _mx_conv2d(inputs, attrs)
elif len(kernel_size) == 1:
return _mx_conv1d(inputs, attrs)
else:
raise tvm.error.OpAttributeInvalid(
'Only 1D or 2D kernels are supported for operator Convolution.')
def _mx_conv1d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 1:
raise tvm.error.OpAttributeInvalid(
'Only 1D kernels are supported for operator Convolution.')
data_layout = attrs.get_str("layout", "NCW")
# MXNet Conv1D only supports 'NCW' layout for now.
if data_layout != "NCW":
raise tvm.error.OpAttributeInvalid(
'Only "NCW" data layout is supported for 1D Convolution')
data_layout = "NCHW"
channel_axis = 1
kernel_layout = "OIHW"
new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter")
new_attrs["kernel_size"] = (1,) + kernel_size
new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False)
data = _op.expand_dims(inputs[0], axis=2)
kernel = _op.expand_dims(inputs[1], axis=2)
res = _op.nn.conv2d(data, kernel, **new_attrs)
if use_bias:
assert len(inputs) == 3
res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
res = _op.squeeze(res, axis=[2])
return res
def _mx_conv2d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
raise tvm.error.OpAttributeInvalid(
'Non-2D kernels are not supported for operator Conv2D.')
'Only 2D kernels are supported for operator Convolution.')
data_layout = attrs.get_str("layout", "NCHW")
channel_axis = _get_channel_axis(data_layout, "conv2d")
......@@ -142,6 +186,51 @@ def _mx_conv2d(inputs, attrs):
return res
def _mx_conv_transpose(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) == 2:
return _mx_conv2d_transpose(inputs, attrs)
elif len(kernel_size) == 1:
return _mx_conv1d_transpose(inputs, attrs)
else:
raise tvm.error.OpAttributeInvalid(
'Only 1D or 2D kernels are supported for operator Deconvolution.')
def _mx_conv1d_transpose(inputs, attrs):
if "target_shape" in attrs.attrs:
raise tvm.error.OpAttributeUnImplemented(
'Attribute "target_shape" is not supported for operator Conv2D-transpose.')
data_layout = attrs.get_str("layout", "NCW")
if data_layout != "NCW":
raise tvm.error.OpAttributeInvalid(
'Only "NCW" data layout is supported for 1D Convolution')
data_layout = "NCHW"
channel_axis = 1
kernel_layout = "OIHW"
new_attrs = {}
new_attrs["channels"] = attrs.get_int("num_filter")
new_attrs["kernel_size"] = (1,) + attrs.get_int_tuple("kernel")
new_attrs["strides"] = (1,) + attrs.get_int_tuple("stride", (1,))
new_attrs["output_padding"] = (0,) + attrs.get_int_tuple("adj", (0,))
new_attrs["padding"] = (0,) + attrs.get_int_tuple("pad", (0,))
new_attrs["dilation"] = (1,) + attrs.get_int_tuple("dilate", (1,))
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", True)
data = _op.expand_dims(inputs[0], axis=2)
kernel = _op.expand_dims(inputs[1], axis=2)
res = _op.nn.conv2d_transpose(data, kernel, **new_attrs)
if use_bias:
assert len(inputs) == 3
res = _op.nn.bias_add(res, inputs[2], axis=channel_axis)
res = _op.squeeze(res, axis=[2])
return res
def _mx_conv2d_transpose(inputs, attrs):
if "target_shape" in attrs.attrs:
raise tvm.error.OpAttributeUnImplemented(
......@@ -257,13 +346,7 @@ def _mx_slice(inputs, attrs):
if end is None:
raise tvm.error.OpAttributeRequired(
'Attribute "end" not found in operator Slice.')
if None in begin:
data_shape = _infer_type(inputs[0]).checked_type.shape
for i, beg in enumerate(begin):
if beg is None:
assert end[i] is None
begin[i] = 0
end[i] = data_shape[i]
begin = tuple(x if x is not None else 0 for x in begin)
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['strides'] = stride
......@@ -373,6 +456,27 @@ def _mx_expand_dims(inputs, attrs):
axis = attrs.get_int("axis")
return _op.expand_dims(inputs[0], axis=axis)
def _mx_pad(inputs, attrs):
pad_mode = attrs.get_str('mode', None)
if pad_mode is None:
raise tvm.error.OpAttributeRequired(
'Attribute "mode" not found in operator pad.')
if pad_mode not in ['constant', 'edge', 'reflect']:
raise tvm.error.OpAttributeInvalid(
'Value ' + pad_mode + ' in attribute "mode" is not valid')
pad_width = attrs.get_int_tuple('pad_width', None)
if pad_width is None:
raise tvm.error.OpAttributeRequired(
'Attribute "pad_width" not found in operator pad.')
if None in pad_width:
raise tvm.error.OpAttributeInvalid(
'Value None in attribute "pad_width" of operator pad is not valid.')
constant_value = attrs.get_float('constant_value', 0.0)
padding = tuple((b, a) for b, a in zip(pad_width[::2], pad_width[1::2]))
return _op.nn.pad(data=inputs[0],
pad_width=padding,
pad_value=constant_value,
pad_mode=pad_mode)
def _mx_leaky_relu(inputs, attrs):
act_type = attrs.get_str("act_type")
......@@ -931,6 +1035,8 @@ _identity_list = [
"ones_like",
"where",
"gather_nd",
"cos",
"sin"
]
_convert_map = {
......@@ -943,6 +1049,7 @@ _convert_map = {
"broadcast_mod" : _rename(_op.mod),
"broadcast_maximum" : _rename(_op.maximum),
"broadcast_minimum" : _rename(_op.minimum),
"arctan" : _rename(_op.atan),
"broadcast_equal" : _mx_compare(_op.equal, _rename),
"broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
"broadcast_greater" : _mx_compare(_op.greater, _rename),
......@@ -1018,9 +1125,9 @@ _convert_map = {
"_zeros" : _mx_zeros,
"FullyConnected": _mx_fully_connected,
"Activation" : _mx_activations,
"Convolution" : _mx_conv2d,
"Convolution" : _mx_conv,
"Convolution_v1": _mx_conv2d,
"Deconvolution" : _mx_conv2d_transpose,
"Deconvolution" : _mx_conv_transpose,
"Pooling" : _mx_pooling,
"Pooling_v1" : _mx_pooling,
"Dropout" : _mx_dropout,
......@@ -1044,6 +1151,8 @@ _convert_map = {
"_full" : _mx_full,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"pad" : _mx_pad,
"Pad" : _mx_pad,
"take" : _mx_take,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
......
......@@ -29,6 +29,7 @@ register_schedule("log", schedule_broadcast)
register_schedule("log1p", schedule_broadcast)
register_schedule("cos", schedule_broadcast)
register_schedule("sin", schedule_broadcast)
register_schedule("atan", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("erf", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
......
......@@ -60,6 +60,12 @@ def sin_grad(orig, grad):
x = orig.args[0]
return [grad * cos(x)]
@register_gradient("atan")
def atan_grad(orig, grad):
"""Returns [grad * 1 / (1 + x ^ 2)]"""
x = orig.args[0]
a = const(2.0)
return [grad * ones_like(x) / (ones_like(x) + power(x, a))]
@register_gradient("exp")
def exp_grad(orig, grad):
......
......@@ -673,7 +673,8 @@ def prelu(data, alpha, axis=1):
def pad(data,
pad_width,
pad_value=0.0):
pad_value=0.0,
pad_mode='constant'):
r"""Padding
This operator takes in a tensor and pads each axis by the specified
......@@ -688,13 +689,16 @@ def pad(data,
of ((before_1, after_1), ..., (before_N, after_N))
pad_value: float, optional, default=0.0
The value used for padding
pad_mode: 'constant', 'edge', 'reflect'
'constant' pads with the constant value given by pad_value
'edge' pads using the edge values of the input array
'reflect' pads by reflecting values with respect to the edge
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.pad(data, pad_width, pad_value)
return _make.pad(data, pad_width, pad_value, pad_mode)
def mirror_pad(data,
......
......@@ -76,6 +76,21 @@ def sin(data):
"""
return _make.sin(data)
def atan(data):
"""Compute elementwise atan of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.atan(data)
def exp(data):
"""Compute elementwise exp of data.
......
......@@ -46,6 +46,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);
......
......@@ -104,6 +104,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
.set_body(DispatchExtern<CUDAFastMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan")
.set_body(DispatchExtern<CUDAMath>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAMath>);
......
......@@ -88,6 +88,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
.set_body(DispatchExternLibDevice);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan")
.set_body(DispatchExternLibDevice);
} // namespace llvm
} // namespace codegen
} // namespace tvm
......
......@@ -102,14 +102,21 @@ Array<Tensor> PadCompute(const Attrs& attrs,
}
const auto* out_ttype = out_type.as<TensorTypeNode>();
return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
tvm::make_const(out_ttype->dtype, param->pad_value)) };
tvm::make_const(out_ttype->dtype, param->pad_value),
"T_pad",
topi::kElementWise,
param->pad_mode) };
}
// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
Expr MakePad(Expr data,
Array<Array<IndexExpr> > pad_width,
double pad_value,
std::string pad_mode) {
auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
static const Op& op = Op::Get("nn.pad");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
......
......@@ -75,6 +75,17 @@ RELAY_REGISTER_UNARY_OP("sin")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));
RELAY_REGISTER_UNARY_OP("atan")
.describe(R"code(Returns the atan of input array, computed element-wise.
.. math::
Y = atan(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan));
RELAY_REGISTER_UNARY_OP("exp")
.describe(R"code(Returns the exp input array, computed element-wise.
......
......@@ -797,6 +797,87 @@ def test_forward_one_hot():
verify((3, 2, 4, 5), 6, 1, 0, "int32")
verify((3, 2, 4, 5), 6, 1.0, 0.0, "float32")
def test_forward_pad():
def verify(data_shape, out_shape, mode, pad_width, constant_value=0.0):
data = mx.sym.var('data')
mx_sym = mx.sym.pad(data, mode=mode, pad_width=pad_width, constant_value=constant_value)
verify_mxnet_frontend_impl(mx_sym, data_shape=data_shape, out_shape=out_shape)
verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="constant",
pad_width=(0,0,0,0,1,2,3,4))
verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="constant",
pad_width=(0,0,0,0,1,2,3,4), constant_value=3.0)
verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="edge",
pad_width=(0,0,0,0,1,2,3,4))
verify(data_shape=(1,1,3,5), out_shape=(1,1,6,12), mode="reflect",
pad_width=(0,0,0,0,1,2,3,4))
verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="constant",
pad_width=(0,0,0,0,1,2,3,4,5,6))
verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="constant",
pad_width=(0,0,0,0,1,2,3,4,5,6), constant_value=3.0)
verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="edge",
pad_width=(0,0,0,0,1,2,3,4,5,6))
verify(data_shape=(1,1,3,5,7), out_shape=(1,1,6,12,18), mode="reflect",
pad_width=(0,0,0,0,1,2,3,4,5,6))
def test_forward_slice():
def verify(data_shape, out_shape, begin, end):
data = mx.sym.var('data')
mx_sym = mx.sym.slice(data, begin=begin, end=end)
verify_mxnet_frontend_impl(mx_sym, data_shape=data_shape, out_shape=out_shape)
verify(data_shape=(1,1,10), out_shape=(1,1,8), begin=(0, 0, 2), end=(1, 1, 10))
verify(data_shape=(1,1,10), out_shape=(1,1,8), begin=(None, None, 2), end=(None, None, None))
def test_forward_convolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
weight_shape=(num_filter,1,) + kernel_size
x = np.random.uniform(size=data_shape).astype("float32")
weight = np.random.uniform(size=weight_shape).astype("float32")
bias = np.random.uniform(size=num_filter).astype("float32")
ref_res = mx.nd.Convolution(data=mx.nd.array(x), weight=mx.nd.array(weight),
bias=mx.nd.array(bias), kernel=kernel_size, stride=stride,
pad=pad, num_filter=num_filter)
mx_sym = mx.sym.Convolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
kernel=kernel_size, stride=stride,
pad=pad, num_filter=num_filter)
shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, weight, bias)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
def test_forward_deconvolution():
def verify(data_shape, kernel_size, stride, pad, num_filter):
weight_shape=(1, num_filter) + kernel_size
x = np.random.uniform(size=data_shape).astype("float32")
weight = np.random.uniform(size=weight_shape).astype("float32")
bias = np.random.uniform(size=num_filter).astype("float32")
ref_res = mx.nd.Deconvolution(data=mx.nd.array(x), weight=mx.nd.array(weight), bias=mx.nd.array(bias),
kernel=kernel_size, stride=stride,
pad=pad, num_filter=num_filter, no_bias=False)
mx_sym = mx.sym.Deconvolution(mx.sym.var("x"), mx.sym.var("weight"), mx.sym.var("bias"),
kernel=kernel_size, stride=stride,
pad=pad, num_filter=num_filter, no_bias=False)
shape_dict = {"x": x.shape, "weight": weight.shape, "bias": bias.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, weight, bias)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify(data_shape=(1,1,1024*16), kernel_size=(17,), stride=(2,), pad=(8,), num_filter=4)
verify(data_shape=(1, 1, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -810,6 +891,8 @@ if __name__ == '__main__':
test_forward_split()
test_forward_split_squeeze()
test_forward_expand_dims()
test_forward_pad()
test_forward_slice()
test_forward_pooling()
test_forward_adaptive_pooling()
test_forward_lrn()
......@@ -845,3 +928,5 @@ if __name__ == '__main__':
test_forward_batch_norm()
test_forward_layer_norm()
test_forward_one_hot()
test_forward_convolution()
test_forward_deconvolution()
......@@ -61,7 +61,8 @@ def test_unary_op():
(tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
(relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x))]:
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
check_single_op(opfunc, ref)
......
......@@ -75,7 +75,8 @@ def test_unary_op():
(tvm.relay.tanh, np.tanh),
(relay.nn.relu, relu),
(tvm.relay.cos, np.cos),
(tvm.relay.sin, np.sin)]:
(tvm.relay.sin, np.sin),
(tvm.relay.atan, np.arctan)]:
check_single_op(opfunc, ref)
......
......@@ -57,6 +57,7 @@ TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
/*
* \brief Fast_tanh_float implementation from Eigen
......
......@@ -144,6 +144,10 @@ inline tvm::Tensor prelu(const tvm::Tensor &x,
* \param pad_after An Array of Expr describing the padding after the
* respective iterator
* \param pad_value The value to fill padding elements with
* \param pad_mode Padding type to use.
* "constant" pads with constant_value;
* "edge" pads using the edge values of the input array;
* "reflect" pads by reflecting values with respect to the edges.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
......@@ -173,7 +177,8 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
tvm::Array<tvm::Expr> pad_after = tvm::Array<tvm::Expr>(),
Expr pad_value = Expr(),
std::string name = "T_pad",
std::string tag = kElementWise) {
std::string tag = kElementWise,
std::string pad_mode = "constant") {
if (pad_after.size() < pad_before.size()) {
for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
pad_after.push_back(pad_before[i]);
......@@ -202,10 +207,10 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
if (!pad_value.defined()) {
pad_value = tvm::make_const(t->dtype, 0);
}
auto l = [&](tvm::Array<tvm::Var> ovars) {
tvm::Array<tvm::Expr> indices;
tvm::Array<tvm::Expr> sel;
tvm::Array<tvm::Expr> pad_idx;
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before_int32.size()) {
indices.push_back(ovars[i]);
......@@ -220,10 +225,30 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) {
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i]));
}
if (pad_mode == "edge") {
pad_idx.push_back(tvm::if_then_else(
ovars[i] < pad_before[i],
0,
tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
t->shape[i] - 1,
ovars[i] - pad_before[i])));
} else if (pad_mode == "reflect") {
pad_idx.push_back(tvm::if_then_else(
ovars[i] < pad_before[i],
pad_before[i] - ovars[i],
tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i],
t->shape[i] * 2 - ovars[i] + pad_before[i] - 2,
ovars[i] - pad_before[i])));
}
}
if (sel.size() != 0) {
return tvm::if_then_else(
detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
if (pad_mode == "constant") {
return tvm::if_then_else(
detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
} else if (pad_mode == "edge" || pad_mode == "reflect") {
return tvm::if_then_else(
detail::Map(sel, tvm::ir::And::make), t(indices), t(pad_idx));
}
}
return t(indices);
};
......
......@@ -144,6 +144,22 @@ def sin(x):
@tvm.tag_scope(tag=tag.ELEMWISE)
def atan(x):
"""Take atan of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.atan(x(*i)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
"""Take floor of input x.
......
......@@ -172,6 +172,11 @@ TVM_REGISTER_GLOBAL("topi.tanh")
*rv = tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sigmoid")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sigmoid(args[0]);
......
......@@ -44,6 +44,7 @@ def test_ewise():
test_apply(topi.rsqrt, "rsqrt")
test_apply(topi.sin, "sin")
test_apply(topi.cos, "cos")
test_apply(topi.atan, "atan")
if __name__ == "__main__":
......