Unverified commit 302e8ee2 by Samuel, committed by GitHub

[PYTORCH] Dropouts and InstanceNorm support added (#5203)

* [PYTORCH] Dropouts and InstanceNorm support added

* Review comments fixed

parent afb8bf06
@@ -442,6 +442,36 @@ def _batch_norm():
                                  scale=scale)[0]
     return _impl

+def _instance_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        data_type = input_types[0]
+        channels = _infer_shape(data)
+
+        if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr):
+            scale = center = True
+            weight = inputs[1]
+            beta = inputs[2]
+            gamma = weight
+        else:
+            scale = center = False
+
+        if not scale:
+            gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)
+
+        if not center:
+            beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
+
+        epsilon = float(inputs[7])
+        return _op.nn.instance_norm(data,
+                                    gamma,
+                                    beta,
+                                    axis=1,
+                                    epsilon=epsilon,
+                                    center=center,
+                                    scale=scale)
+    return _impl
+
 def _transpose():
     def _impl(inputs, input_types):
         data = inputs[0]
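For orientation: the converter indexes the positional inputs of `aten::instance_norm(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled)`, which is why `inputs[1]`/`inputs[2]` hold the affine weight and bias and `inputs[7]` holds `eps`. A minimal sketch (not part of the patch) to inspect that layout from a traced module:

```python
# Sketch only: trace an affine InstanceNorm2d and print the TorchScript graph
# to see the aten::instance_norm call whose positional inputs the converter
# above indexes (inputs[1] = weight, inputs[2] = bias, inputs[7] = eps).
import torch

model = torch.nn.InstanceNorm2d(16, affine=True).eval()
traced = torch.jit.trace(model, torch.rand(1, 16, 10, 10))
print(traced.graph)
```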
@@ -965,6 +995,7 @@ _convert_map = {
     "aten::threshold_"              : _threshold(),
     "aten::contiguous"              : _contiguous(),
     "aten::batch_norm"              : _batch_norm(),
+    "aten::instance_norm"           : _instance_norm(),
     "aten::transpose"               : _transpose(),
     "aten::transpose_"              : _transpose(),
     "aten::t"                       : _transpose(),
@@ -978,6 +1009,8 @@ _convert_map = {
     "aten::avg_pool2d"              : _avg_pool2d(),
     "aten::dropout"                 : _dropout(),
     "aten::dropout_"                : _dropout(),
+    "aten::feature_dropout"         : _dropout(),
+    "aten::alpha_dropout"           : _dropout(),
     "aten::mean"                    : _mean(),
     "aten::chunk"                   : _chunk(),
     "aten::matmul"                  : _matmul(),
......
@@ -512,6 +512,15 @@ def test_forward_batchnorm():
         verify_model(bn.eval(), input_data=inp)

+def test_forward_instancenorm():
+    inp_2d = torch.rand((1, 16, 10, 10))
+    inp_3d = torch.rand((1, 16, 10, 10, 10))
+
+    for ins_norm, inp in [(torch.nn.InstanceNorm2d(16), inp_2d),
+                          (torch.nn.InstanceNorm3d(16), inp_3d)]:
+        verify_model(ins_norm.eval(), input_data=inp)
+
+
 def test_forward_transpose():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
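Note that `torch.nn.InstanceNorm2d(16)` and `torch.nn.InstanceNorm3d(16)` default to `affine=False`, so the instance-norm test above exercises the converter branch that synthesizes constant `gamma`/`beta`; constructing the modules with `affine=True` would cover the `_expr.Expr` branch as well.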
@@ -619,13 +628,11 @@ def test_forward_dense():
 def test_forward_dropout():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
-
-    class Dropout1(Module):
-        def forward(self, *args):
-            return torch.nn.functional.dropout(args[0][0, 0], 0.5, False)
-
     input_data = torch.rand(input_shape).float()
-    verify_model(Dropout1().float().eval(), input_data=input_data)
+    verify_model(torch.nn.Dropout(p=0.5).eval(), input_data=input_data[0, 0])
+    verify_model(torch.nn.Dropout2d(p=0.5).eval(), input_data=input_data[0])
+    verify_model(torch.nn.Dropout3d(p=0.5).eval(), input_data=input_data)
+    verify_model(torch.nn.AlphaDropout(p=0.5).eval(), input_data=input_data[0, 0])

 def test_forward_slice():
     torch.set_grad_enabled(False)
@@ -1080,6 +1087,7 @@ if __name__ == "__main__":
     test_forward_threshold()
     test_forward_contiguous()
     test_forward_batchnorm()
+    test_forward_instancenorm()
     test_forward_transpose()
     test_forward_size()
     test_forward_view()
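For completeness, an end-to-end sketch of how these converters get exercised; the `from_pytorch` signature and the `"input0"` input-name convention here are assumptions based on the frontend of this era and may differ between TVM versions:

```python
# Sketch only: trace an InstanceNorm2d module and import it into Relay,
# which routes through the aten::instance_norm converter added above.
import torch
from tvm import relay

model = torch.nn.InstanceNorm2d(16).eval()
inp = torch.rand(1, 16, 10, 10)
traced = torch.jit.trace(model, inp)
mod, params = relay.frontend.from_pytorch(traced, [("input0", list(inp.shape))])
print(mod)
```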