Commit 7d911f46 by bindog Committed by Wuwei Lin

[Relay][Op] Add instance norm op (#4004)

* [Relay][Op] Add instance norm op

* mend

[Relay][Op] Add instance norm op
parent 36201fe9
...@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> { ...@@ -492,6 +492,29 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
}; // struct BatchNormAttrs }; // struct BatchNormAttrs
/*! \brief Attributes used in instance_norm operator */
struct InstanceNormAttrs : public tvm::AttrsNode<InstanceNormAttrs> {
int axis;  // input axis that denotes the channel (default 1)
double epsilon;  // small constant added to variance for numerical stability
bool center;  // if true, add the beta offset after normalization
bool scale;  // if true, multiply by gamma after normalization
TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") {
TVM_ATTR_FIELD(axis)
.describe("Specify which shape axis denotes the channel.")
.set_default(1);
TVM_ATTR_FIELD(epsilon)
.describe("Small float added to variance to avoid dividing by zero")
.set_default(1e-5);
TVM_ATTR_FIELD(center).set_default(true)
.describe("If true, add offset of beta to normalized tensor; "
"otherwise, beta is ignored.");
TVM_ATTR_FIELD(scale).set_default(true)
.describe("If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct InstanceNormAttrs
/*! \brief Attributes used in layer_norm operator */ /*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> { struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
int axis; int axis;
......
...@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs): ...@@ -324,6 +324,14 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs) return _op.nn.batch_norm(*inputs, **new_attrs)
def _mx_instance_norm(inputs, attrs):
    """Convert an MXNet InstanceNorm node (data, gamma, beta) to Relay."""
    assert len(inputs) == 3
    # MXNet spells the stability constant "eps"; Relay calls it "epsilon".
    new_attrs = {
        "axis": attrs.get_int("axis", 1),
        "epsilon": attrs.get_float("eps", 1e-5),
    }
    return _op.nn.instance_norm(*inputs, **new_attrs)
def _mx_layer_norm(inputs, attrs): def _mx_layer_norm(inputs, attrs):
assert len(inputs) == 3 assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False): if attrs.get_bool("output_mean_var", False):
...@@ -1133,6 +1141,7 @@ _convert_map = { ...@@ -1133,6 +1141,7 @@ _convert_map = {
"Dropout" : _mx_dropout, "Dropout" : _mx_dropout,
"BatchNorm" : _mx_batch_norm, "BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm, "BatchNorm_v1" : _mx_batch_norm,
"InstanceNorm" : _mx_instance_norm,
"LayerNorm" : _mx_layer_norm, "LayerNorm" : _mx_layer_norm,
"LRN" : _mx_lrn, "LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize, "L2Normalization" : _mx_l2_normalize,
......
...@@ -176,6 +176,15 @@ class BatchNorm(OnnxOpConverter): ...@@ -176,6 +176,15 @@ class BatchNorm(OnnxOpConverter):
return out[0] return out[0]
class InstanceNorm(OnnxOpConverter):
    """ Operator converter for InstanceNorm.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Rename the op and forward inputs/attributes (e.g. epsilon) unchanged
        # to relay's instance_norm.
        return AttrCvt(op_name='instance_norm')(inputs, attr, params)
class Conv(OnnxOpConverter): class Conv(OnnxOpConverter):
""" Operator converter for Conv. """ Operator converter for Conv.
""" """
...@@ -999,7 +1008,7 @@ def _get_convert_map(opset): ...@@ -999,7 +1008,7 @@ def _get_convert_map(opset):
'GlobalAveragePool': Renamer('global_avg_pool2d'), 'GlobalAveragePool': Renamer('global_avg_pool2d'),
'GlobalMaxPool': Renamer('global_max_pool2d'), 'GlobalMaxPool': Renamer('global_max_pool2d'),
'BatchNormalization': BatchNorm.get_converter(opset), 'BatchNormalization': BatchNorm.get_converter(opset),
# 'InstanceNormalization' 'InstanceNormalization': InstanceNorm.get_converter(opset),
# 'LpNormalization' # 'LpNormalization'
'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
'Flatten': Flatten.get_converter(opset), 'Flatten': Flatten.get_converter(opset),
......
...@@ -935,6 +935,73 @@ def batch_norm(data, ...@@ -935,6 +935,73 @@ def batch_norm(data,
return TupleWrapper(result, 3) return TupleWrapper(result, 3)
def instance_norm(data,
                  gamma,
                  beta,
                  axis=1,
                  epsilon=1e-5,
                  center=True,
                  scale=True):
    r"""Instance normalization (Ulyanov and et al., 2016).

    Normalizes every object (instance) of the n-dimensional input
    independently:

    .. math::

        out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
            * gamma + beta

    This differs from batch normalization in that the mean and variance are
    computed per instance rather than across the mini-batch, and the same
    computation is applied at both training and test time.

    If the input has size *k* along ``axis``, both ``gamma`` and ``beta``
    must have shape *(k,)*.  ``axis`` selects which input dimension is the
    'channel'; it defaults to 1, and -1 selects the last dimension.

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input to which instance_norm will be applied.

    gamma : tvm.relay.Expr
        The gamma scale factor.

    beta : tvm.relay.Expr
        The beta offset factor.

    axis : int, optional, default=1
        Specify along which shape axis the channel is specified.

    epsilon : double, optional, default=1e-5
        Small float added to variance to avoid dividing by zero.

    center : boolean, optional, default=True
        If True, add offset of beta to normalized tensor, If False,
        beta is ignored.

    scale : boolean, optional, default=True
        If True, multiply by gamma. If False, gamma is not used.

    Returns
    -------
    result : tvm.relay.Expr
        The normalized data.

    .. _`Instance Normalization: The Missing Ingredient for Fast Stylization`:
        https://arxiv.org/abs/1607.08022
    """
    return _make.instance_norm(data, gamma, beta, axis, epsilon, center, scale)
def layer_norm(data, def layer_norm(data,
gamma, gamma,
beta, beta,
...@@ -964,7 +1031,7 @@ def layer_norm(data, ...@@ -964,7 +1031,7 @@ def layer_norm(data,
Parameters Parameters
---------- ----------
data : tvm.relay.Expr data : tvm.relay.Expr
Input to which batch_norm will be applied. Input to which layer_norm will be applied.
gamma : tvm.relay.Expr gamma : tvm.relay.Expr
The gamma scale factor. The gamma scale factor.
......
...@@ -640,6 +640,76 @@ axis to be the last item in the input shape. ...@@ -640,6 +640,76 @@ axis to be the last item in the input shape.
.add_type_rel("BatchNorm", BatchNormRel); .add_type_rel("BatchNorm", BatchNormRel);
// instance_norm
TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);

// Type relation for nn.instance_norm.
// types = {data, gamma, beta, output}.  Given the data type, infers gamma and
// beta as 1-D tensors of length data.shape[axis] and the output as having the
// same shape and dtype as data.
bool InstanceNormRel(const Array<Type>& types,
                     int num_inputs,
                     const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;  // data type not resolved yet; retry later
  const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
  CHECK(param != nullptr);
  // Normalize a negative axis (e.g. -1) to its positive equivalent.
  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
  CHECK(axis >= 0 && axis < static_cast<int>(data->shape.size()));
  reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
  reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
  return true;
}
// Build a nn.instance_norm call node from unpacked attribute values.
Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
bool center, bool scale) {
auto attrs = make_node<InstanceNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.instance_norm");
return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
}

// Expose MakeInstanceNorm to Python as relay.op.nn._make.instance_norm
// (7 packed args: data, gamma, beta, axis, epsilon, center, scale).
TVM_REGISTER_API("relay.op.nn._make.instance_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeInstanceNorm, args, rv);
});
// Register the nn.instance_norm operator: 3 inputs (data, gamma, beta),
// attrs InstanceNormAttrs, type relation InstanceNormRel above.
RELAY_REGISTER_OP("nn.instance_norm")
.describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
Applies instance normalization to the n-dimensional input array.
.. math::
out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
* gamma + beta
The instance normalization is similar to batch normalization, but unlike
batch normalization, the mean and var are calculated per-dimension
separately for each object(instance) in a mini-batch, not over a batch.
And the same normalization is applied both at test and train time.
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.
The parameter ``axis`` specifies which axis of the input shape denotes
the 'channel'. The default is 1. Specifying -1 sets the channel axis
to be the last item in the input shape.
.. note::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InstanceNormAttrs")
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("InstanceNorm", InstanceNormRel);
// layer_norm // layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs); TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
......
...@@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs, ...@@ -92,6 +92,41 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
return out; return out;
} }
// Lower nn.instance_norm into primitive ops for inference:
//   out = (data - mean) / sqrt(var + epsilon) [* gamma] [+ beta]
// where mean/var are reduced over every axis except batch (0) and channel.
Expr InstanceNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
Expr beta,
Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<InstanceNormAttrs>();
CHECK(param);
int ndim = ttype->shape.size();
// Negative axis counts from the end.
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
// Collect the reduction axes: everything except batch (0) and the channel.
// NOTE(review): the loop starts at 1, so axis == 0 silently reduces over all
// non-batch axes — presumably axis >= 1 is assumed here; confirm.
Array<Integer> reduced_axes;
for (int i = 1; i < ndim; ++i) {
if (i != axis)
reduced_axes.push_back(i);
}
// NOTE(review): epsilon is materialized as a float32 scalar regardless of
// the data dtype — TODO confirm this is intended for non-float32 inputs.
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr mean = Mean(data, reduced_axes, true, false);
Expr var = Variance(data, mean, reduced_axes, true, false);
Expr denom = Sqrt(Add(var, epsilon));
Expr out = Divide(Subtract(data, mean), denom);
if (param->scale) {
// Broadcast the per-channel gamma along all other axes before multiplying.
out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
}
if (param->center) {
out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
}
return out;
}
class InferenceSimplifier : public ExprMutator { class InferenceSimplifier : public ExprMutator {
public: public:
Expr VisitExpr_(const TupleGetItemNode* n) final { Expr VisitExpr_(const TupleGetItemNode* n) final {
...@@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator { ...@@ -116,6 +151,7 @@ class InferenceSimplifier : public ExprMutator {
Expr VisitExpr_(const CallNode* n) { Expr VisitExpr_(const CallNode* n) {
static const Op& batch_norm = Op::Get("nn.batch_norm"); static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& instance_norm = Op::Get("nn.instance_norm");
static const Op& layer_norm = Op::Get("nn.layer_norm"); static const Op& layer_norm = Op::Get("nn.layer_norm");
auto new_n = ExprMutator::VisitExpr_(n); auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(batch_norm)) { if (n->op.same_as(batch_norm)) {
...@@ -124,6 +160,10 @@ class InferenceSimplifier : public ExprMutator { ...@@ -124,6 +160,10 @@ class InferenceSimplifier : public ExprMutator {
const auto* call = new_n.as<CallNode>(); const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type()); call->args[2], n->args[0]->checked_type());
} else if (n->op.same_as(instance_norm)) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
} }
return new_n; return new_n;
} }
......
...@@ -758,6 +758,26 @@ def test_forward_batch_norm(): ...@@ -758,6 +758,26 @@ def test_forward_batch_norm():
verify((2, 3, 4, 5), fix_gamma=True) verify((2, 3, 4, 5), fix_gamma=True)
def test_forward_instance_norm():
    """Compare Relay's converted InstanceNorm against MXNet's own output."""
    def verify(shape, axis=1, epsilon=1e-5):
        data = np.random.uniform(size=shape).astype("float32")
        gamma = np.random.uniform(size=(shape[axis])).astype("float32")
        beta = np.random.uniform(size=(shape[axis])).astype("float32")
        # Reference result from MXNet's imperative API.
        ref_res = mx.nd.InstanceNorm(mx.nd.array(data), mx.nd.array(gamma),
                                     mx.nd.array(beta), epsilon)
        mx_sym = mx.sym.InstanceNorm(mx.sym.var("x"), mx.sym.var("gamma"),
                                     mx.sym.var("beta"), epsilon)
        shape_dict = {"x": data.shape, "gamma": gamma.shape, "beta": beta.shape}
        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
                op_res = intrp.evaluate()(data, gamma, beta)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(),
                                            rtol=1e-5, atol=1e-5)
    # Cover 3-D, 4-D and 5-D inputs, small and large.
    for shape in [(2, 3, 4, 5), (32, 64, 80, 64), (8, 6, 5), (8, 7, 6, 5, 4)]:
        verify(shape)
def test_forward_layer_norm(): def test_forward_layer_norm():
def verify(shape, axis=-1): def verify(shape, axis=-1):
x = np.random.uniform(size=shape).astype("float32") x = np.random.uniform(size=shape).astype("float32")
...@@ -938,6 +958,7 @@ if __name__ == '__main__': ...@@ -938,6 +958,7 @@ if __name__ == '__main__':
test_forward_sequence_mask() test_forward_sequence_mask()
test_forward_contrib_div_sqrt_dim() test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm() test_forward_batch_norm()
test_forward_instance_norm()
test_forward_layer_norm() test_forward_layer_norm()
test_forward_one_hot() test_forward_one_hot()
test_forward_convolution() test_forward_convolution()
......
...@@ -416,6 +416,50 @@ def test_lrn(): ...@@ -416,6 +416,50 @@ def test_lrn():
verify_lrn((5, 5, 5, 5), 3, 'float32') verify_lrn((5, 5, 5, 5), 3, 'float32')
verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0) verify_lrn((5, 5, 5, 5), 3, 'float32', alpha=0.0002, beta=0.5, bias=2.0)
def verify_instance_norm(shape, axis=1):
def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
dims_x = len(x.shape)
axis = tuple(range(2, dims_x))
mean = np.mean(x, axis=axis, keepdims=True)
var = np.var(x, axis=axis, keepdims=True)
dim_ones = (1,) * (dims_x - 2)
gamma = gamma.reshape(-1, *dim_ones)
beta = beta.reshape(-1, *dim_ones)
return gamma * (x - mean) / np.sqrt(var + epsilon) + beta
x = np.random.randn(*shape).astype(np.float32)
gamma = np.random.randn(shape[1]).astype(np.float32)
beta = np.random.randn(shape[1]).astype(np.float32)
epsilon = 1e-5
y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
node = onnx.helper.make_node(
'InstanceNormalization',
inputs=['x', 'gamma', 'beta'],
outputs=['y'],
epsilon=epsilon,
)
graph = helper.make_graph([node],
"instance_norm_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(shape)),
helper.make_tensor_value_info("gamma", TensorProto.FLOAT, (shape[1],)),
helper.make_tensor_value_info("beta", TensorProto.FLOAT, (shape[1],))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(shape))])
model = helper.make_model(graph, producer_name='instance_norm_test')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 'float32')
tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
def test_instance_norm():
    """Exercise ONNX InstanceNormalization over 3-D, 4-D and 5-D shapes."""
    for shape in [(2, 3, 4, 5), (32, 64, 80, 64), (8, 6, 5), (8, 7, 6, 5, 4)]:
        verify_instance_norm(shape)
def _test_upsample_nearest(): def _test_upsample_nearest():
scale = 2 scale = 2
in_shape = (1, 1, 3, 3) in_shape = (1, 1, 3, 3)
...@@ -1270,6 +1314,7 @@ if __name__ == '__main__': ...@@ -1270,6 +1314,7 @@ if __name__ == '__main__':
test_matmul() test_matmul()
test_gather() test_gather()
test_lrn() test_lrn()
test_instance_norm()
test_upsample() test_upsample()
test_forward_min() test_forward_min()
test_forward_max() test_forward_max()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment