Commit e286e637 by Siju Committed by Tianqi Chen

[RELAY]prelu op support (#2016)

parent 2fb1cc6e
......@@ -74,6 +74,7 @@ This level enables additional math and transform operators.
tvm.relay.zeros
tvm.relay.nn.leaky_relu
tvm.relay.nn.prelu
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
......@@ -183,6 +184,7 @@ Level 2 Definitions
Level 3 Definitions
-------------------
.. autofunction:: tvm.relay.nn.leaky_relu
.. autofunction:: tvm.relay.nn.prelu
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.trunc
......
......@@ -278,6 +278,17 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
};
/*! \brief Attributes for prelu operator */
struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
int axis;
TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") {
TVM_ATTR_FIELD(axis).set_default(1)
.describe("Specify which shape axis the channel is specified.");
}
};
/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
......
......@@ -280,6 +280,7 @@ class TypeReporterNode : public Node {
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
* \param cond The condition of operation.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
......
......@@ -528,6 +528,33 @@ def leaky_relu(data, alpha):
return _make.leaky_relu(data, alpha)
def prelu(data, alpha, axis=1):
"""This operator takes data as input and does Leaky version
of a Rectified Linear Unit.
.. math::
`y = x > 0 ? x : alpha * x`
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
alpha : tvm.relay.Expr
Slope coefficient for the negative half axis.
axis : int, optional
Specify which shape axis the channel is specified.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.prelu(data, alpha, axis)
def pad(data,
pad_width,
pad_value=0.0):
......
......@@ -171,6 +171,62 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.add_type_rel("Identity", IdentityRel);
TVM_REGISTER_NODE_TYPE(PReluAttrs);
bool PReluRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const PReluAttrs* param = attrs.as<PReluAttrs>();
CHECK(param != nullptr);
CHECK(param->axis < static_cast<int>(data->shape.size()))
<< "Wrong axis (" << param->axis << ")value.";
// assign alpha type
Array<IndexExpr> alpha_shape({data->shape[param->axis]});
reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
return true;
}
// Positional relay function to create prelu operator used by frontend FFI.
Expr MakePRelu(Expr data,
Expr alpha,
int axis) {
auto attrs = make_node<PReluAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.prelu");
return CallNode::make(op, {data, alpha}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.prelu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
});
RELAY_REGISTER_OP("nn.prelu")
.describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`,
where :math:`*` is an channelwise multiplication for each sample in the batch.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.PReluAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel);
TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
......
......@@ -188,13 +188,39 @@ def test_full_like():
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
def test_infer_type_leaky_relu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.leaky_relu(x, alpha=0.1)
"alpha=0.1" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.leaky_relu(x, alpha=0.1)
"alpha=0.1" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")
def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
x = relay.var("data", relay.TensorType(data, dtype))
if alpha:
y = relay.var("alpha", relay.TensorType(alpha, dtype))
else:
y = relay.var("alpha", relay.IncompleteType())
z = relay.nn.prelu(x, y, axis=axis)
zz = relay.ir_pass.infer_type(z)
if axis != 1:
assert "axis" in z.astext()
assert zz.checked_type == relay.ty.TensorType(output, dtype)
if not alpha:
axis = axis if axis else 1
alpha_shape = (data[axis],)
assert zz.args[1].checked_type == relay.TensorType(alpha_shape, "float32")
def test_infer_type_prelu():
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c))
verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), None, 3, (n, h, w, c))
verify_infer_type_prelu((1, 3, 2, 2), (3,), 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), (3,), 3, (1, 2, 2, 3))
verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))
if __name__ == "__main__":
test_cast()
......@@ -208,6 +234,7 @@ if __name__ == "__main__":
test_full()
test_full_like()
test_infer_type_leaky_relu()
test_infer_type_prelu()
test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type()
test_split_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment