Unverified commit f49fc366 by Samuel, committed by GitHub

[RELAY][PYTORCH] GroupNorm op support added (#5358)

parent 4b5f324a
......@@ -959,6 +959,30 @@ struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
}; // struct LayerNormAttrs
/*! \brief Attributes used in group_norm operator */
struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
  int num_groups;
  int axis;
  double epsilon;
  bool center;
  bool scale;

  TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") {
    TVM_ATTR_FIELD(num_groups).set_default(0)
        .describe("Specify number of groups to separate the channels into.");
    TVM_ATTR_FIELD(axis).set_default(1)
        .describe("Specify which shape axis denotes the channel.");
    TVM_ATTR_FIELD(epsilon).set_default(1e-5)
        .describe("Small float added to variance to avoid dividing by zero.");
    TVM_ATTR_FIELD(center).set_default(true)
        .describe("If true, add offset of beta to normalized tensor; "
                  "otherwise, beta is ignored.");
    TVM_ATTR_FIELD(scale).set_default(true)
        .describe("If true, multiply by gamma; otherwise, gamma is ignored.");
  }
};  // struct GroupNormAttrs
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
int size;
......
......@@ -831,6 +831,26 @@ def _layer_norm():
                                 scale=True)
    return _impl


def _group_norm():
    def _impl(inputs, input_types):
        # aten::group_norm(input, num_groups, weight, bias, eps, cudnn_enabled)
        data = inputs[0]
        gamma = inputs[2]
        beta = inputs[3]
        num_groups = inputs[1]
        epsilon = float(inputs[4])

        return _op.nn.group_norm(data,
                                 gamma=gamma,
                                 beta=beta,
                                 num_groups=num_groups,
                                 axis=1,
                                 epsilon=epsilon,
                                 center=True,
                                 scale=True)
    return _impl
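With this converter registered, aten::group_norm nodes in a traced TorchScript module map onto relay.nn.group_norm. A minimal end-to-end sketch of how the converter is exercised (the input name "input0" and the shapes are assumptions for illustration; the tests below do the same thing via verify_model):

import torch
from tvm import relay

model = torch.nn.GroupNorm(3, 6).eval()   # 6 channels split into 3 groups
inp = torch.rand(1, 6, 5, 5)
scripted = torch.jit.trace(model, inp)
# The shape list pairs an assumed input name with the traced input shape.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(inp.shape))])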
def _transpose(prelude):
    def _impl(inputs, input_types):
        data = inputs[0]
......@@ -1630,6 +1650,7 @@ def _get_convert_map(prelude):
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
"aten::layer_norm" : _layer_norm(),
"aten::group_norm" : _group_norm(),
"aten::transpose" : _transpose(prelude),
"aten::transpose_" : _transpose(prelude),
"aten::t" : _transpose(prelude),
......
......@@ -1708,6 +1708,75 @@ def layer_norm(data,
    return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)


def group_norm(data,
               gamma,
               beta,
               num_groups,
               axis=1,
               epsilon=1e-5,
               center=True,
               scale=True):
    r"""
    Group normalization normalizes over groups of channels for each training example.
    Group normalization sits between instance normalization and layer normalization:
    when all channels are put into a single group it becomes layer normalization,
    and when each channel is put into its own group it becomes instance normalization.
    https://arxiv.org/pdf/1803.08494.pdf

    Applies group normalization to the n-dimensional input array by separating the input
    channels into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
    The mean and standard deviation are calculated separately over each group. gamma and
    beta are learnable per-channel affine transform parameter vectors of size num_channels.

    .. math::

        out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
            * gamma + beta

    Unlike batch normalization, the mean and variance are computed over a group of channels.

    If the input has size k on axis 1, then both gamma and beta have shape (k,).

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    data : tvm.relay.Expr
        Input to which group_norm will be applied.

    gamma : tvm.relay.Expr
        The gamma scale factor.

    beta : tvm.relay.Expr
        The beta offset factor.

    num_groups : int
        The number of groups to separate the channels into.

    axis : int, optional, default=1
        The axis of the channels.

    epsilon : double, optional, default=1e-5
        Small float added to variance to avoid dividing by zero.

    center : boolean, optional, default=True
        If True, add offset of beta to the normalized tensor; if False,
        beta is ignored.

    scale : boolean, optional, default=True
        If True, multiply by gamma; if False, gamma is not used.

    Returns
    -------
    result : tvm.relay.Expr
        The normalized data.
    """
    return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale)
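For illustration, a minimal usage sketch of this API (the variable names and shapes here are assumptions chosen for the example): a 4-channel NCHW tensor normalized with two groups and per-channel gamma/beta.

from tvm import relay

data = relay.var("data", shape=(1, 4, 8, 8), dtype="float32")
gamma = relay.var("gamma", shape=(4,), dtype="float32")   # one scale per channel
beta = relay.var("beta", shape=(4,), dtype="float32")     # one offset per channel
out = relay.nn.group_norm(data, gamma, beta, num_groups=2, axis=1, epsilon=1e-5)
func = relay.Function([data, gamma, beta], out)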
def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
......
......@@ -852,6 +852,80 @@ RELAY_REGISTER_OP("nn.layer_norm")
.set_support_level(1)
.add_type_rel("LayerNorm", LayerNormRel);
// group_norm
TVM_REGISTER_NODE_TYPE(GroupNormAttrs);

bool GroupNormRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;
  const GroupNormAttrs* param = attrs.as<GroupNormAttrs>();
  int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
  CHECK(axis >= 0 && axis < static_cast<int>(data->shape.size()));
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
  return true;
}
Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups,
                   int axis, double epsilon, bool center, bool scale) {
  auto attrs = make_object<GroupNormAttrs>();
  attrs->num_groups = num_groups;
  attrs->axis = axis;
  attrs->epsilon = epsilon;
  attrs->center = center;
  attrs->scale = scale;
  static const Op& op = Op::Get("nn.group_norm");
  return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 8>(MakeGroupNorm, args, rv);
});
RELAY_REGISTER_OP("nn.group_norm")
.describe(R"code(
Group normalization normalizes over groups of channels for each training example.
Group normalization sits between instance normalization and layer normalization:
when all channels are put into a single group it becomes layer normalization,
and when each channel is put into its own group it becomes instance normalization.
https://arxiv.org/pdf/1803.08494.pdf

Applies group normalization to the n-dimensional input array by separating the input
channels into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
The mean and standard deviation are calculated separately over each group. gamma and
beta are learnable per-channel affine transform parameter vectors of size num_channels.

.. math::

    out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
        * gamma + beta

Unlike batch normalization, the mean and variance are computed over a group of channels.

If the input has size k on axis 1, then both gamma and beta have shape (k,).

.. note::

    This operator can be optimized away for inference.

)code" TVM_ADD_FILELINE)
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which group_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("GroupNorm", GroupNormRel);
// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
int num_inputs,
......
......@@ -64,6 +64,66 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
return out;
}
Expr GroupNormToInferUnpack(const Attrs attrs,
                            Expr data,
                            Expr gamma,
                            Expr beta,
                            Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  CHECK(ttype);
  const auto param = attrs.as<GroupNormAttrs>();
  CHECK(param);

  int ndim = ttype->shape.size();
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  Array<Integer> reduced_axes;
  Array<Integer> new_shape;
  Array<Integer> old_shape;
  int num_groups = param->num_groups;
  int channel = ttype->shape[axis].as<IntImmNode>()->value;

  // old_shape    = (N, C, H, W)
  // new_shape    = (N, num_groups, C/num_groups, H, W)
  // reduced_axes = axes of (C/num_groups, H, W) in the new shape
  for (int i = 0; i < ndim; ++i) {
    auto val = ttype->shape[i].as<IntImmNode>()->value;
    // Save the old shape to reshape back later
    old_shape.push_back(val);
    if (i == axis) {
      new_shape.push_back(num_groups);
      new_shape.push_back(channel / num_groups);
      reduced_axes.push_back(i + 1);
      continue;
    }
    if (i >= axis) {
      reduced_axes.push_back(i + 1);
    }
    new_shape.push_back(val);
  }

  data = Reshape(data, new_shape);
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, {reduced_axes}, true, false);
  Expr var = Variance(data, mean, {reduced_axes}, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);
  out = Reshape(out, old_shape);

  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }
  return out;
}
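To make the reshape-and-reduce scheme in the comments above concrete, here is a small NumPy sketch of the same lowering (illustration only; it assumes NCHW input, axis=1, and both scale and center enabled):

import numpy as np

def group_norm_ref(x, gamma, beta, num_groups, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (C,)
    n, c, h, w = x.shape
    # new shape = (N, num_groups, C/num_groups, H, W)
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    # reduce over (C/num_groups, H, W), i.e. axes 2, 3, 4 of the new shape
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    out = (grouped - mean) / np.sqrt(var + eps)
    # reshape back to the old shape, then apply the per-channel affine transform
    out = out.reshape(n, c, h, w)
    return out * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)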
Expr LayerNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
......@@ -143,6 +203,7 @@ class InferenceSimplifier : public ExprMutator {
        dropout_op_(Op::Get("nn.dropout")),
        instance_norm_op_(Op::Get("nn.instance_norm")),
        layer_norm_op_(Op::Get("nn.layer_norm")),
        group_norm_op_(Op::Get("nn.group_norm")),
        l2_norm_op_(Op::Get("nn.l2_normalize")) {}

  Expr VisitExpr_(const TupleGetItemNode* n) final {
......@@ -170,6 +231,10 @@ class InferenceSimplifier : public ExprMutator {
      const auto* call = new_n.as<CallNode>();
      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == group_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == instance_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
......@@ -189,6 +254,7 @@ class InferenceSimplifier : public ExprMutator {
  const Op& dropout_op_;
  const Op& instance_norm_op_;
  const Op& layer_norm_op_;
  const Op& group_norm_op_;
  const Op& l2_norm_op_;
  std::unordered_map<Expr, Type, ObjectHash, ObjectEqual> ty_map_;
};
......
......@@ -717,6 +717,28 @@ def test_forward_layernorm():
    init_weight(ln.eval())
    verify_model(ln.eval(), input_data=inp)


def test_forward_groupnorm():
    input_shape = [10, 6, 5, 5]
    input_data = torch.rand(input_shape).float()

    # Separate 6 channels into 3 groups
    verify_model(torch.nn.GroupNorm(3, 6).eval(), input_data=input_data)

    # Put all 6 channels into a single group (equivalent to LayerNorm)
    verify_model(torch.nn.GroupNorm(1, 6).eval(), input_data=input_data)

    # Separate 6 channels into 6 groups (equivalent to InstanceNorm)
    verify_model(torch.nn.GroupNorm(6, 6).eval(), input_data=input_data)

    input_shape = [1, 10, 4, 7]
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.GroupNorm(1, 10).eval(), input_data=input_data)
    verify_model(torch.nn.GroupNorm(2, 10).eval(), input_data=input_data)
    verify_model(torch.nn.GroupNorm(5, 10).eval(), input_data=input_data)
    verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data)
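The equivalences noted in the comments above can be sanity-checked directly in PyTorch; a quick illustrative sketch (not part of the test suite, shapes chosen arbitrarily, relying on the default affine parameters weight=1 and bias=0):

import torch

x = torch.rand(10, 6, 5, 5)
gn_one_group = torch.nn.GroupNorm(1, 6).eval()    # one group over all channels
ln = torch.nn.LayerNorm([6, 5, 5]).eval()         # LayerNorm over (C, H, W)
assert torch.allclose(gn_one_group(x), ln(x), atol=1e-5)

gn_per_channel = torch.nn.GroupNorm(6, 6).eval()  # one group per channel
inorm = torch.nn.InstanceNorm2d(6).eval()         # per-channel, per-sample statistics
assert torch.allclose(gn_per_channel(x), inorm(x), atol=1e-5)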
def test_forward_reshape():
    torch.set_grad_enabled(False)
    input_shape = [2, 1, 10, 1, 10]
......@@ -1865,6 +1887,7 @@ if __name__ == "__main__":
    test_forward_batchnorm()
    test_forward_instancenorm()
    test_forward_layernorm()
    test_forward_groupnorm()
    test_forward_transpose()
    test_forward_size()
    test_forward_view()
......