Commit 6b6e3888 by Haichen Shen, committed by masahi

[Relay/TOPI][Op] Add variance and layer norm op (#3700)

* Add LayerNorm op

* update

* fix

* Add mean_std and mean_variance

* add std and update doc

* add license

* x

* lint

* x

* fix

* fix doc
parent f9e8c116
......@@ -150,6 +150,10 @@ This level enables additional math and transform operators.
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.variance
tvm.relay.std
tvm.relay.mean_variance
tvm.relay.mean_std
tvm.relay.prod
tvm.relay.strided_slice
tvm.relay.broadcast_to
......@@ -297,6 +301,10 @@ Level 4 Definitions
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.variance
.. autofunction:: tvm.relay.std
.. autofunction:: tvm.relay.mean_variance
.. autofunction:: tvm.relay.mean_std
.. autofunction:: tvm.relay.prod
.. autofunction:: tvm.relay.strided_slice
.. autofunction:: tvm.relay.broadcast_to
......
......@@ -470,6 +470,27 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
}; // struct BatchNormAttrs
/*! \brief Attributes used in layer_norm operator */
struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
int axis;
double epsilon;
bool center;
bool scale;
TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Specify which shape axis denotes the channel.");
TVM_ATTR_FIELD(epsilon).set_default(1e-5)
.describe("Small float added to variance to avoid dividing by zero");
TVM_ATTR_FIELD(center).set_default(true)
.describe("If true, add offset of beta to normalized tensor; "
"otherwise, beta is ignored.");
TVM_ATTR_FIELD(scale).set_default(true)
.describe("If true, multiply by gamma; otherwise, gamma is ignored.");
}
}; // struct LayerNormAttrs
/*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
int size;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/relay/attrs/reduce.h
* \brief Auxiliary attributes for reduce operators.
*/
#ifndef TVM_RELAY_ATTRS_REDUCE_H_
#define TVM_RELAY_ATTRS_REDUCE_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief Attributes for Reduce operators */
struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
Array<Integer> axis;
bool keepdims;
bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");
TVM_ATTR_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_REDUCE_H_
......@@ -224,10 +224,21 @@ def _mx_batch_norm(inputs, attrs):
new_attrs["axis"] = attrs.get_int("axis", 1)
new_attrs["epsilon"] = attrs.get_float("eps", 0.001)
new_attrs["center"] = True
new_attrs["scale"] = not attrs.get_bool("fix_gamma", False)
new_attrs["scale"] = not attrs.get_bool("fix_gamma", True)
return _op.nn.batch_norm(*inputs, **new_attrs)
def _mx_layer_norm(inputs, attrs):
assert len(inputs) == 3
if attrs.get_bool("output_mean_var", False):
raise tvm.error.OpAttributeUnimplemented(
'Attribute "output_mean_var" is not supported for operator Layer Norm.')
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["epsilon"] = attrs.get_float("eps", 1e-5)
return _op.nn.layer_norm(*inputs, **new_attrs)
def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
......@@ -997,6 +1008,7 @@ _convert_map = {
"Dropout" : _mx_dropout,
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LayerNorm" : _mx_layer_norm,
"LRN" : _mx_lrn,
"L2Normalization" : _mx_l2_normalize,
"slice" : _mx_slice,
......
......@@ -35,3 +35,4 @@ _reg.register_schedule("max", _schedule_reduce)
_reg.register_schedule("min", _schedule_reduce)
_reg.register_schedule("prod", _schedule_reduce)
_reg.register_schedule("mean", _schedule_reduce)
_reg.register_schedule("variance", _schedule_reduce)
......@@ -867,7 +867,7 @@ def batch_norm(data,
Specify along which shape axis the channel is specified.
epsilon : double, optional, default=1e-5
Small float added to variance to avoid diving by zero.
Small float added to variance to avoid dividing by zero.
center : boolean, optional, default=True
If True, add offset of beta to normalized tensor. If False,
......@@ -897,6 +897,64 @@ def batch_norm(data,
return TupleWrapper(result, 3)
def layer_norm(data,
gamma,
beta,
axis=-1,
epsilon=1e-5,
center=True,
scale=True):
r"""
Layer normalization (Lei Ba et al., 2016).
Applies layer normalization to the n-dimensional input array.
This operator takes an n-dimensional input array and normalizes
the input using the given axis:
.. math::
out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
* gamma + beta
Unlike batch normalization, the mean and variance are computed along the normalized (channel) axis rather than across the batch.
If the input has size k on that axis, both gamma and beta have shape (k,).
.. note::
This operator can be optimized away for inference.
Parameters
----------
data : tvm.relay.Expr
Input to which layer_norm will be applied.
gamma : tvm.relay.Expr
The gamma scale factor.
beta : tvm.relay.Expr
The beta offset factor.
axis : int, optional, default=-1
The axis that should be normalized, typically the axis of the channels.
epsilon : double, optional, default=1e-5
Small float added to variance to avoid dividing by zero.
center : boolean, optional, default=True
If True, add offset of beta to the normalized tensor. If False,
beta is ignored.
scale : boolean, optional, default=True
If True, multiply by gamma. If False, gamma is not used.
Returns
-------
result : tvm.relay.Expr
The normalized data.
"""
return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale)
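For reference, a minimal usage sketch of the new Python API; the shapes below are illustrative only, and the debug executor is the same one the tests in this change rely on.

import numpy as np
import tvm
from tvm import relay

# Illustrative shapes: normalize the last axis of a (2, 5) tensor.
x = relay.var("x", shape=(2, 5), dtype="float32")
gamma = relay.var("gamma", shape=(5,), dtype="float32")
beta = relay.var("beta", shape=(5,), dtype="float32")
out = relay.nn.layer_norm(x, gamma, beta, axis=-1, epsilon=1e-5)
func = relay.Function([x, gamma, beta], out)

# Quick numerical check with the debug executor.
intrp = relay.create_executor("debug", ctx=tvm.cpu(), target="llvm")
res = intrp.evaluate(func)(np.random.uniform(size=(2, 5)).astype("float32"),
                           np.ones(5, "float32"),
                           np.zeros(5, "float32"))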
def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
......
......@@ -18,6 +18,9 @@
# pylint: disable=redefined-builtin
from . import _make
from .tensor import sqrt
from .transform import squeeze
from ..expr import Tuple, TupleWrapper
def argmax(data, axis=None, keepdims=False, exclude=False):
"""Returns the indices of the maximum values along an axis.
......@@ -236,8 +239,8 @@ def mean(data, axis=None, keepdims=False, exclude=False):
axis : None or int or tuple of int
Axis or axes along which a mean operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
The default, axis=None, will compute the mean of all elements in the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
......@@ -257,6 +260,140 @@ def mean(data, axis=None, keepdims=False, exclude=False):
return _make.mean(data, axis, keepdims, exclude)
def variance(data, axis=None, keepdims=False, exclude=False):
"""Computes the variance of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a variance operation is performed.
The default, axis=None, will compute the variance of all elements in the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
return _make._variance(data, m, axis, keepdims, exclude)
def std(data, axis=None, keepdims=False, exclude=False):
"""Computes the standard deviation of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a standard deviation operation is performed.
The default, axis=None, will compute the standard deviation of all elements in the
input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
return sqrt(_make._variance(data, m, axis, keepdims, exclude))
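A short sketch of how the new variance and std reducers are meant to mirror their numpy counterparts; the shapes and random data are illustrative only.

import numpy as np
from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")

# Variance over axis 1, keeping the reduced dimension.
v = relay.variance(x, axis=1, keepdims=True)
# Standard deviation over all elements.
s = relay.std(x)

# The intended numpy equivalents, for reference only:
data = np.random.uniform(size=(3, 4)).astype("float32")
np_v = np.var(data, axis=1, keepdims=True)
np_s = np.std(data)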
def mean_variance(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean and variance of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a mean and variance operation is performed.
The default, axis=None, will compute the mean and variance of all elements in
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
var = _make._variance(data, m, axis, keepdims, exclude)
if not keepdims:
m = squeeze(m)
return TupleWrapper(Tuple((m, var)), 2)
def mean_std(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean and standard deviation of data over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a mean and standard deviation operation is performed.
The default, axis=None, will compute the mean and standard deviation of all elements in
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
axis = [axis] if isinstance(axis, int) else axis
m = mean(data, axis, True, exclude)
s = sqrt(_make._variance(data, m, axis, keepdims, exclude))
if not keepdims:
m = squeeze(m)
return TupleWrapper(Tuple((m, s)), 2)
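Because mean_variance and mean_std return a TupleWrapper, a hedged sketch of how the pair can be consumed, either by indexing or via astuple() as the unit tests below do:

from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")

# mean_variance returns a TupleWrapper; index it to get each field,
# or call .astuple() to build a single Tuple expression for a Function body.
mv = relay.mean_variance(x, axis=1, keepdims=True)
mean_expr, var_expr = mv[0], mv[1]
func = relay.Function([x], mv.astuple())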
def prod(data, axis=None, keepdims=False, exclude=False):
"""Computes the products of array elements over given axes.
......
......@@ -52,7 +52,7 @@ def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = layers.batch_norm_infer(net, epsilon=eps, name="%s_batch_norm" % prefix)
net = layers.batch_norm_infer(net, epsilon=eps, scale=False, name="%s_batch_norm" % prefix)
net = relay.nn.relu(net)
return net
......
......@@ -38,7 +38,8 @@ def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None,
padding=pad,
name='%s%s_conv1' % (name, suffix))
bn = layers.batch_norm_infer(data=conv, epsilon=2e-5, name='%s%s_bn' % (name, suffix))
bn = layers.batch_norm_infer(data=conv, epsilon=2e-5, scale=False,
name='%s%s_bn' % (name, suffix))
act = relay.nn.relu(data=bn)
return act
......
......@@ -678,6 +678,53 @@ axis to be the last item in the input shape.
.add_type_rel("BatchNorm", BatchNormRel);
// layer_norm
TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
bool LayerNormRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
CHECK(axis >= 0 && axis < (int)data->shape.size());
reporter->Assign(types[1], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[2], TensorTypeNode::make({data->shape[axis]}, data->dtype));
reporter->Assign(types[3], TensorTypeNode::make(data->shape, data->dtype));
return true;
}
Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
bool center, bool scale) {
auto attrs = make_node<LayerNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.layer_norm");
return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.layer_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeLayerNorm, args, rv);
});
RELAY_REGISTER_OP("nn.layer_norm")
.describe(R"code(
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LayerNormAttrs")
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_support_level(1)
.add_type_rel("LayerNorm", LayerNormRel);
// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
int num_inputs,
......
......@@ -24,6 +24,7 @@
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/reduce.h>
#include <topi/elemwise.h>
#include <topi/reduction.h>
#include <numeric>
......@@ -34,34 +35,7 @@
namespace tvm {
namespace relay {
/*! \brief Attributes for Reduce operators */
struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
Array<Integer> axis;
bool keepdims;
bool exclude;
TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<Integer>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is int, a reduction is performed on a particular axis.
If `axis` is a tuple of ints, a reduction is performed on all the axes
specified in the tuple.
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.)code");
TVM_ATTR_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
TVM_ATTR_FIELD(exclude).set_default(false)
.describe("Whether to perform reduction on axis that are NOT in axis instead.");
}
};
TVM_REGISTER_NODE_TYPE(ReduceAttrs);
/*!
* \brief GetReduceAxes, get the new axis from indim and other arguments
......@@ -498,5 +472,84 @@ Example::
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
bool VarianceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
const auto* mean = types[1].as<TensorTypeNode>();
if (mean == nullptr) return false;
std::vector<IndexExpr> in_shape(data->shape.begin(), data->shape.end());
std::vector<IndexExpr> mean_shape(mean->shape.begin(), mean->shape.end());
CHECK_EQ(in_shape.size(), mean_shape.size());
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}
Array<Tensor> VarianceCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
IndexExpr count = make_const(inputs[0]->dtype, 1);
const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
auto axes = param->axis;
auto data = inputs[0];
auto mean = inputs[1];
for (int64_t i : GetReduceAxes(data->shape.size(),
param->axis,
param->exclude)) {
count *= data->shape[i];
}
std::vector<Integer> expand_shape;
auto sq_diff = topi::power(topi::subtract(data, mean), 2);
auto var = topi::divide(ReduceCompute(attrs, {sq_diff}, out_type, target, topi::sum)[0], count);
return {var};
}
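VarianceCompute receives the data tensor together with a mean that was precomputed with keepdims, and evaluates sum((data - mean)^2) / count over the reduce axes, where count is the product of the reduced dimension sizes. A hedged numpy reference of the same arithmetic; variance_reference is a hypothetical helper for illustration only.

import numpy as np

def variance_reference(data, mean, axis, keepdims=False):
    # mean is precomputed with keepdims=True so it broadcasts against data,
    # matching how relay.variance passes the mean tensor to this op.
    axis = (axis,) if isinstance(axis, int) else tuple(axis)
    count = 1
    for i in axis:
        count *= data.shape[i]
    sq_diff = (data - mean) ** 2
    return np.sum(sq_diff, axis=axis, keepdims=keepdims) / count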
Expr MakeVariance(Expr data,
Expr mean,
Array<Integer> axis,
bool keepdims,
bool exclude) {
auto attrs = make_node<ReduceAttrs>();
attrs->axis = std::move(axis);
attrs->keepdims = keepdims;
attrs->exclude = exclude;
static const Op& op = Op::Get("variance");
return CallNode::make(op, {data, mean}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make._variance")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeVariance, args, rv);
});
RELAY_REGISTER_OP("variance")
.describe(R"code(Computes the variance of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("mean", "Tensor", "The mean tensor.")
.add_type_rel("Variance", VarianceRel)
.set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
} // namespace relay
} // namespace tvm
......@@ -33,7 +33,9 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
#include <string>
#include <utility>
namespace tvm {
......@@ -373,6 +375,25 @@ inline Expr Copy(Expr data) {
}
inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
auto attrs = make_node<ReduceAttrs>();
attrs->axis = std::move(axis);
attrs->keepdims = keepdims;
attrs->exclude = exclude;
static const Op& op = Op::Get("mean");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
auto attrs = make_node<ReduceAttrs>();
attrs->axis = std::move(axis);
attrs->keepdims = keepdims;
attrs->exclude = exclude;
static const Op& op = Op::Get("variance");
return CallNode::make(op, {data, mean}, Attrs(attrs), {});
}
Expr MakeConcatenate(Expr data, int axis);
Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);
......
......@@ -25,6 +25,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op.h>
#include "./pattern_util.h"
namespace tvm {
......@@ -54,8 +55,8 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
shift = Add(shift, beta);
}
int axis = param->axis;
auto ndim = ttype->shape.size();
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
shift = ExpandBiasToMatchAxis(shift, ndim, {axis});
......@@ -64,6 +65,33 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
return out;
}
Expr LayerNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
Expr beta,
Type tdata) {
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
const auto param = attrs.as<LayerNormAttrs>();
CHECK(param);
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr mean = Mean(data, {param->axis}, true, false);
Expr var = Variance(data, mean, {param->axis}, true, false);
Expr denom = Sqrt(Add(var, epsilon));
Expr out = Divide(Subtract(data, mean), denom);
size_t ndim = ttype->shape.size();
int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
if (param->scale) {
out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
}
if (param->center) {
out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
}
return out;
}
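LayerNormToInferUnpack lowers nn.layer_norm into mean, variance, sqrt, and elementwise ops so that inference graphs no longer carry the fused op. A hedged numpy sketch of the resulting computation; layer_norm_unpacked is a hypothetical reference, not part of this change.

import numpy as np

def layer_norm_unpacked(data, gamma, beta, axis=-1, epsilon=1e-5,
                        center=True, scale=True):
    # Mean and variance along `axis` (kept), normalize, then optionally
    # apply gamma and beta broadcast along that axis.
    mean = np.mean(data, axis=axis, keepdims=True)
    var = np.mean((data - mean) ** 2, axis=axis, keepdims=True)
    out = (data - mean) / np.sqrt(var + epsilon)
    shape = [1] * data.ndim
    shape[axis] = data.shape[axis]
    if scale:
        out = out * gamma.reshape(shape)
    if center:
        out = out + beta.reshape(shape)
    return out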
class InferenceSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const TupleGetItemNode* n) final {
......@@ -88,9 +116,14 @@ class InferenceSimplifier : public ExprMutator {
Expr VisitExpr_(const CallNode* n) {
static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& layer_norm = Op::Get("nn.layer_norm");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(batch_norm)) {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
} else if (n->op.same_as(layer_norm)) {
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
}
return new_n;
}
......
......@@ -728,6 +728,56 @@ def test_forward_contrib_div_sqrt_dim():
verify((3, 4))
verify((3, 4, 5))
def test_forward_batch_norm():
def verify(shape, axis=1, fix_gamma=False):
x = np.random.uniform(size=shape).astype("float32")
gamma = np.random.uniform(size=(shape[axis])).astype("float32")
beta = np.random.uniform(size=(shape[axis])).astype("float32")
moving_mean = np.random.uniform(size=(shape[axis])).astype("float32")
moving_var = np.random.uniform(size=(shape[axis])).astype("float32")
ref_res = mx.nd.BatchNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta),
mx.nd.array(moving_mean), mx.nd.array(moving_var),
axis=axis, use_global_stats=True, fix_gamma=fix_gamma)
mx_sym = mx.sym.BatchNorm(mx.sym.var("x"), mx.sym.var("gamma"),
mx.sym.var("beta"), mx.sym.var("mean"),
mx.sym.var("var"), axis=axis, use_global_stats=True,
fix_gamma=fix_gamma)
shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape,
"mean": moving_mean.shape, "var": moving_var.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
#print(mod)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta, moving_mean, moving_var)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify((2, 3, 4, 5))
verify((2, 3, 4, 5), axis=0)
verify((2, 3, 4, 5), axis=-1)
verify((2, 3, 4, 5), fix_gamma=True)
def test_forward_layer_norm():
def verify(shape, axis=-1):
x = np.random.uniform(size=shape).astype("float32")
gamma = np.random.uniform(size=(shape[axis])).astype("float32")
beta = np.random.uniform(size=(shape[axis])).astype("float32")
ref_res = mx.nd.LayerNorm(mx.nd.array(x), mx.nd.array(gamma), mx.nd.array(beta),
axis=axis)
mx_sym = mx.sym.LayerNorm(mx.sym.var("x"), mx.sym.var("gamma"),
mx.sym.var("beta"), axis=axis)
shape_dict = {"x": x.shape, "gamma": gamma.shape, "beta": beta.shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x, gamma, beta)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
verify((2, 5))
verify((2, 5), axis=0)
verify((2, 5, 6))
if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
......@@ -773,3 +823,5 @@ if __name__ == '__main__':
test_forward_topk()
test_forward_sequence_mask()
test_forward_contrib_div_sqrt_dim()
test_forward_batch_norm()
test_forward_layer_norm()
......@@ -203,6 +203,8 @@ def test_reduce_functions():
[relay.max, np.max],
[relay.min, np.min],
[relay.mean, np.mean],
[relay.variance, np.var],
[relay.std, np.std],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.argmin, _with_keepdims(np.argmin)],
......@@ -226,6 +228,43 @@ def test_reduce_functions():
verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))
def verify_mean_var_std(funcs, shape, axis, keepdims):
test_func = funcs[0]
ref_func = funcs[1]
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
z = test_func(x, axis, keepdims)
func = relay.Function([x], z.astuple())
x_data = np.random.uniform(size=shape).astype(dtype)
ref_mean = np.mean(x_data, axis=axis, dtype=dtype, keepdims=keepdims)
ref_res = ref_func(x_data, axis=axis, dtype=dtype, keepdims=keepdims)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res1[0].asnumpy(), ref_mean, rtol=1e-5)
tvm.testing.assert_allclose(op_res1[1].asnumpy(), ref_res, rtol=1e-5)
op_res2 = intrp2.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res2[0].asnumpy(), ref_mean, rtol=1e-5)
tvm.testing.assert_allclose(op_res2[1].asnumpy(), ref_res, rtol=1e-5)
def test_mean_var_std():
for func in [[relay.mean_variance, np.var],
[relay.mean_std, np.std]]:
verify_mean_var_std(func, (2, 3, 4), 1, True)
verify_mean_var_std(func, (2, 3, 4), (1,), True)
verify_mean_var_std(func, (2, 3, 4), -1, True)
verify_mean_var_std(func, (2, 3, 4), (0, 1, 2), False)
verify_mean_var_std(func, (4, 4, 3), None, False)
verify_mean_var_std(func, (4, 4, 3), (0, 2), False)
verify_mean_var_std(func, (128, 24, 128), (0, 1), False)
verify_mean_var_std(func, (128, 24, 128), (0, 2), False)
verify_mean_var_std(func, (128, 24, 128), (0, 1), True)
verify_mean_var_std(func, (128, 24, 128), (0, 2), True)
def test_strided_slice():
def verify(dshape, begin, end, strides, output, test_ref=True):
x = relay.var("x", relay.TensorType(dshape, "float32"))
......@@ -267,3 +306,4 @@ if __name__ == "__main__":
test_binary_int_broadcast()
test_where()
test_reduce_functions()
test_mean_var_std()