Commit 4b765c51 by Pariksheet Pinjari, committed by Tianqi Chen

[OP] PReLU Support (#394)

parent 079599c9
...@@ -101,6 +101,13 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
  }
};
struct PReLUParam : public dmlc::Parameter<PReLUParam> {
  int axis;
  DMLC_DECLARE_PARAMETER(PReLUParam) {
    DMLC_DECLARE_FIELD(axis).set_default(1)
      .describe("Specify the axis along which the channelwise alpha is applied.");
  }
};
struct PadParam : public dmlc::Parameter<PadParam> {
  float pad_value;
...
...@@ -18,6 +18,9 @@ reg.register_pattern("relu", OpPattern.ELEMWISE)
reg.register_schedule("leaky_relu", _fschedule_broadcast)
reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
# prelu
reg.register_schedule("prelu", _fschedule_broadcast)
reg.register_pattern("prelu", OpPattern.BROADCAST)
# flatten
reg.register_schedule("flatten", _fschedule_broadcast)
...
...@@ -417,6 +417,53 @@ NNVM_REGISTER_OP(leaky_relu)
  })
.set_support_level(1);
// prelu
DMLC_REGISTER_PARAMETER(PReLUParam);
inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
                            std::vector<TShape> *in_shape,
                            std::vector<TShape> *out_shape) {
  const PReLUParam &param = nnvm::get<PReLUParam>(attrs.parsed);
  TShape dshape = in_shape->at(0);
  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);

  // Parametric ReLU: alpha carries one slope per channel along `axis`.
  CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D, but got " << dshape.ndim();
  CHECK(size_t(param.axis) < dshape.ndim())
    << "Wrong axis (" << param.axis << ") value.";
  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, TShape({dshape[param.axis]}));

  TShape oshape(dshape);
  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
  return true;
}
NNVM_REGISTER_OP(prelu)
.describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``,
and computes the output as :math:`PReLU(x) = x > 0 ? x : alpha * x`,
where :math:`*` is a channelwise multiplication for each sample in the batch.

)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.add_arguments(PReLUParam::__FIELDS__())
.set_attr_parser(ParamParser<PReLUParam>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FInferShape>("FInferShape", PReluInferShape)
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "alpha"};
})
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)};
})
.set_support_level(4);
DMLC_REGISTER_PARAMETER(PadParam);
...
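For orientation, the channelwise semantics registered above can be sketched in plain NumPy. This is an illustrative reference only, not part of the commit; the helper name prelu_ref and the sample shapes and slope values are made up for the example. alpha holds one slope per channel along `axis` and is broadcast against the data.

import numpy as np

def prelu_ref(x, alpha, axis=1):
    # y = x where x > 0, otherwise alpha * x, with alpha broadcast along `axis`
    shape = [1] * x.ndim
    shape[axis] = -1
    return np.where(x > 0, x, alpha.reshape(shape) * x)

x = np.random.randn(1, 3, 32, 32).astype("float32")   # NCHW layout, channel axis = 1
alpha = np.array([0.1, 0.2, 0.3], dtype="float32")     # one slope per channel
y = prelu_ref(x, alpha, axis=1)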
...@@ -64,6 +64,43 @@ def test_relu():
    inputs = [('x', dshape, x)]
    helper(y, inputs, dtype, forward, backward)
def test_prelu_nchw():
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a)
    def forward(x, a):
        return (x < 0) * (x * a.reshape(3, 1, 1)) + (x >= 0) * x
    dtype = "float32"
    dshape_x = (1, 3, 32, 32)
    dshape_w = (3,)
    inputs = [
        ('x', dshape_x, x),
        ('a', dshape_w, a)
    ]
    helper(y, inputs, dtype, forward)
def test_prelu_nhwc():
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a, axis=3)
    def forward(x, a):
        return (x < 0) * (x * a.reshape(1, 1, 3)) + (x >= 0) * x
    dtype = "float32"
    dshape_x = (1, 32, 32, 3)
    dshape_w = (3,)
    inputs = [
        ('x', dshape_x, x),
        ('a', dshape_w, a)
    ]
    helper(y, inputs, dtype, forward)
def test_sym_scalar_pow():
    scalar = 3
...@@ -336,6 +373,8 @@ if __name__ == "__main__":
    test_batchnorm()
    test_dense()
    test_relu()
    test_prelu_nchw()
    test_prelu_nhwc()
    test_sym_scalar_pow()
    test_scalar_sym_pow()
    test_exp()
...
...@@ -250,6 +250,17 @@ def test_reshape():
    check((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
def test_prelu():
    def check(in_shape, axis, out_shape):
        x = sym.Variable("x", shape=in_shape)
        w = sym.Variable("w")
        y = sym.prelu(x, w, axis=axis, name="y")
        sdict = infer_shape(y)
        assert(tuple(sdict["y"][0]) == tuple(out_shape))
    check((1, 3, 2, 2), 1, (1, 3, 2, 2))
    check((1, 2, 2, 3), 3, (1, 2, 2, 3))
# Level 4
def test_transpose():
    def check(in_shape, out_shape, **kwargs):
...@@ -319,3 +330,4 @@ if __name__ == "__main__":
    test_broadcast_binary()
    test_reduce()
    test_transpose()
    test_prelu()
...@@ -16,8 +16,15 @@ def test_leaky_relu():
    y = sym.leaky_relu(x, alpha=0.1)
    assert(y.list_input_names() == ["x"])
def test_prelu():
    x = sym.Variable("x")
    w = sym.Variable("w")
    y = sym.prelu(x, w)
    assert(y.list_input_names()[0] == 'x')
    assert(y.list_input_names()[1] == 'w')
if __name__ == "__main__":
    test_scalar_op()
    test_reshape()
    test_leaky_relu()
    test_prelu()
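As a quick end-to-end usage sketch (not part of the commit; it assumes an NNVM/TVM build of this era with an llvm target available, and the shapes and alpha values below are illustrative), the new operator can be compiled and executed through the standard graph runtime:

import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
from tvm.contrib import graph_runtime

x = sym.Variable("x")
a = sym.Variable("a")
net = sym.prelu(data=x, alpha=a)          # axis defaults to 1 (NCHW channel axis)

shape = {"x": (1, 3, 32, 32), "a": (3,)}  # example shapes for illustration
graph, lib, params = nnvm.compiler.build(net, target="llvm", shape=shape)

module = graph_runtime.create(graph, lib, tvm.cpu(0))
module.set_input("x", np.random.randn(*shape["x"]).astype("float32"))
module.set_input("a", np.array([0.25, 0.1, 0.5], dtype="float32"))
module.run()
out = module.get_output(0, tvm.nd.empty(shape["x"], "float32")).asnumpy()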