Unverified Commit 0cc26614 by Samuel Committed by GitHub

[PYTORCH]LayerNorm support added (#5249)

parent 5e50f476
......@@ -503,6 +503,34 @@ def _instance_norm():
scale=scale)
return _impl
def _get_dims(data):
import torch
if isinstance(data, _expr.Expr):
dims = _infer_shape(data)
elif isinstance(data, list):
dims = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
dims = data.shape
else:
msg = "Data type %s could not be parsed" % type(data)
raise AssertionError(msg)
return dims
def _layer_norm():
def _impl(inputs, input_types):
data = inputs[0]
ndims = len(_get_dims(inputs[1]))
assert ndims == 1, "Support only normalization over last one dimension."
return _op.nn.layer_norm(data,
gamma=inputs[1],
beta=inputs[2],
axis=-1,
epsilon=float(inputs[4]),
center=False,
scale=False)
return _impl
def _transpose():
def _impl(inputs, input_types):
data = inputs[0]
......@@ -1050,6 +1078,7 @@ _convert_map = {
"aten::contiguous" : _contiguous(),
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
"aten::layer_norm" : _layer_norm(),
"aten::transpose" : _transpose(),
"aten::transpose_" : _transpose(),
"aten::t" : _transpose(),
......
......@@ -561,6 +561,9 @@ def test_forward_instancenorm():
(torch.nn.InstanceNorm3d(16), inp_3d)]:
verify_model(ins_norm.eval(), input_data=inp)
def test_forward_layernorm():
inp = torch.rand((20, 5, 10, 10))
verify_model(torch.nn.LayerNorm(10).eval(), input_data=inp)
def test_forward_transpose():
torch.set_grad_enabled(False)
......@@ -1132,6 +1135,7 @@ if __name__ == "__main__":
test_forward_contiguous()
test_forward_batchnorm()
test_forward_instancenorm()
test_forward_layernorm()
test_forward_transpose()
test_forward_size()
test_forward_view()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment