Commit 2e260938 by Zhao Wu Committed by Yao Wang

Fix PRelu layout in Relay (#3013)

* Fix PRelu layout in Relay

* Fix cpplint

* Add PRelu test case
parent 5af25722
...@@ -525,7 +525,6 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs, ...@@ -525,7 +525,6 @@ inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape); NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
// The case of parametric relu // The case of parametric relu
CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D, but got " << dshape.ndim();
CHECK(size_t(param.axis) < dshape.Size()) CHECK(size_t(param.axis) < dshape.Size())
<< "Wrong axis (" << param.axis << ")value."; << "Wrong axis (" << param.axis << ")value.";
......
...@@ -238,6 +238,23 @@ bool PReluRel(const Array<Type>& types, ...@@ -238,6 +238,23 @@ bool PReluRel(const Array<Type>& types,
return true; return true;
} }
template<typename T>
Array<Array<Layout> > PReluInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
CHECK_EQ(old_in_layouts.size(), 2U);
CHECK_EQ(old_in_shapes.size(), 2U);
Layout data_layout = old_in_layouts[0];
if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 2U);
}
return Array<Array<Layout> >{{data_layout, Layout("C")},
{data_layout}};
}
// Positional relay function to create prelu operator used by frontend FFI. // Positional relay function to create prelu operator used by frontend FFI.
Expr MakePRelu(Expr data, Expr MakePRelu(Expr data,
Expr alpha, Expr alpha,
...@@ -265,7 +282,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. ...@@ -265,7 +282,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
.add_argument("alpha", "Tensor", "Input channelwise alpha.") .add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("PRelu", PReluRel) .add_type_rel("PRelu", PReluRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
.set_attr<FTVMCompute>( .set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs, "FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -513,6 +513,52 @@ def test_alter_layout_strided_slice(): ...@@ -513,6 +513,52 @@ def test_alter_layout_strided_slice():
assert alpha_equal(a, b), "Actual = \n" + str(a) assert alpha_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_prelu():
"""Test PRelu operator"""
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight")
alpha = relay.var("alpha", relay.IncompleteType())
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.prelu(y, alpha)
y = relay.Function(free_vars(y), y)
return y
@register_alter_op_layout("nn.conv2d", level=110)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)
def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("weight")
alpha = relay.var("alpha", relay.IncompleteType())
y = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(y, w,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW16c")
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.nn.prelu(y, alpha)
y = relay.Function(free_vars(y), y)
return y
a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)
a = alter_op_layout(a)
a = infer_type(a)
b = expected()
b = infer_type(b)
assert(alpha_equal(a, b))
if __name__ == "__main__": if __name__ == "__main__":
test_alter_op() test_alter_op()
...@@ -525,3 +571,4 @@ if __name__ == "__main__": ...@@ -525,3 +571,4 @@ if __name__ == "__main__":
test_alter_layout_concatenate() test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op() test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice() test_alter_layout_strided_slice()
test_alter_layout_prelu()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment