Unverified Commit b236565e by Samuel Committed by GitHub

[PYTORCH]Repeat, Reciprocal & Reshape Op support (#5280)

parent 0d1babce
...@@ -154,6 +154,34 @@ def _select(): ...@@ -154,6 +154,34 @@ def _select():
return _op.transform.take(data, index, axis=dim) return _op.transform.take(data, index, axis=dim)
return _impl return _impl
def _reciprocal():
def _impl(inputs, input_types):
data = inputs[0]
return _expr.const(1.0) / data
return _impl
def _repeat():
def _impl(inputs, input_types):
data = inputs[0]
reps = _get_dims(inputs[1])
return _op.transform.tile(data, reps=reps)
return _impl
def _repeat_interleave():
def _impl(inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], int):
repeats = inputs[1]
axis = inputs[2]
else:
msg = "Only repeat with one value as repeat is currently supported."
raise AssertionError(msg)
if axis is None: # Flatten the data if no axis is given from torch
data = _op.transform.reshape(data, [-1])
axis = 0
return _op.transform.repeat(data, repeats=repeats, axis=axis)
return _impl
def _ones(): def _ones():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
...@@ -675,6 +703,16 @@ def _view(): ...@@ -675,6 +703,16 @@ def _view():
return _op.transform.reshape(data, new_shape) return _op.transform.reshape(data, new_shape)
return _impl return _impl
def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], list):
new_shape = inputs[1]
else:
new_shape = _infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)
return _impl
def _clone(): def _clone():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
...@@ -1082,6 +1120,9 @@ _convert_map = { ...@@ -1082,6 +1120,9 @@ _convert_map = {
"aten::div_" : _elemwise("divide"), "aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(), "aten::ones" : _ones(),
"aten::zeros" : _zeros(), "aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
"aten::repeat" : _repeat(),
"aten::repeat_interleave" : _repeat_interleave(),
"aten::to" : _to(), "aten::to" : _to(),
"aten::squeeze" : _squeeze(), "aten::squeeze" : _squeeze(),
"aten::unsqueeze" : _unsqueeze(), "aten::unsqueeze" : _unsqueeze(),
...@@ -1122,6 +1163,7 @@ _convert_map = { ...@@ -1122,6 +1163,7 @@ _convert_map = {
"aten::addmm" : _dense(), "aten::addmm" : _dense(),
"aten::size" : _size(), "aten::size" : _size(),
"aten::view" : _view(), "aten::view" : _view(),
"aten::reshape" : _reshape(),
"aten::clone" : _clone(), "aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(), "aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(), "aten::sigmoid" : _sigmoid(),
......
...@@ -293,6 +293,61 @@ def test_forward_multiply(): ...@@ -293,6 +293,61 @@ def test_forward_multiply():
verify_model(Multiply3().float().eval(), input_data=input_data) verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data) verify_model(Multiply4().float().eval(), input_data=input_data)
def test_forward_reciprocal():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
class Reciprocal1(Module):
def forward(self, *args):
return args[0].reciprocal()
input_data = torch.rand(input_shape).float()
verify_model(Reciprocal1().float().eval(), input_data=input_data)
def test_forward_repeat():
torch.set_grad_enabled(False)
input_shape = [1, 3]
class Repeat1(Module):
def forward(self, *args):
return args[0].repeat(1, 1)
class Repeat2(Module):
def forward(self, *args):
return args[0].repeat(4, 2)
class Repeat3(Module):
def forward(self, *args):
return args[0].repeat(4, 2, 1)
input_data = torch.rand(input_shape).float()
verify_model(Repeat1().float().eval(), input_data=input_data)
verify_model(Repeat2().float().eval(), input_data=input_data)
verify_model(Repeat3().float().eval(), input_data=input_data)
def test_forward_repeat_interleave():
torch.set_grad_enabled(False)
input_shape = [2, 2, 3]
class RepeatInterleave1(Module):
def forward(self, *args):
return args[0].repeat_interleave(2)
class RepeatInterleave2(Module):
def forward(self, *args):
return args[0].repeat_interleave(3, dim=0)
class RepeatInterleave3(Module):
def forward(self, *args):
return args[0].repeat_interleave(2, dim=1)
class RepeatInterleave4(Module):
def forward(self, *args):
return args[0].repeat_interleave(4, dim=2)
input_data = torch.rand(input_shape).float()
verify_model(RepeatInterleave1().float().eval(), input_data=input_data)
verify_model(RepeatInterleave2().float().eval(), input_data=input_data)
verify_model(RepeatInterleave3().float().eval(), input_data=input_data)
verify_model(RepeatInterleave4().float().eval(), input_data=input_data)
def test_forward_unsqueeze(): def test_forward_unsqueeze():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
input_shape = [10, 10] input_shape = [10, 10]
...@@ -600,6 +655,22 @@ def test_forward_layernorm(): ...@@ -600,6 +655,22 @@ def test_forward_layernorm():
init_weight(ln.eval()) init_weight(ln.eval())
verify_model(ln.eval(), input_data=inp) verify_model(ln.eval(), input_data=inp)
def test_forward_reshape():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
new_shape = [2, 1, 10, 10]
class Reshape1(Module):
def forward(self, *args):
return args[0].reshape(new_shape)
class Reshape2(Module):
def forward(self, *args):
return args[0].reshape([-1])
input_data = torch.rand(input_shape).float()
verify_model(Reshape1().float().eval(), input_data=input_data)
verify_model(Reshape2().float().eval(), input_data=input_data)
def test_forward_transpose(): def test_forward_transpose():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10] input_shape = [1, 3, 10, 10]
...@@ -1151,6 +1222,10 @@ if __name__ == "__main__": ...@@ -1151,6 +1222,10 @@ if __name__ == "__main__":
test_forward_add() test_forward_add()
test_forward_subtract() test_forward_subtract()
test_forward_multiply() test_forward_multiply()
test_forward_reshape()
test_forward_reciprocal()
test_forward_repeat()
test_forward_repeat_interleave()
test_forward_squeeze() test_forward_squeeze()
test_forward_unsqueeze() test_forward_unsqueeze()
test_forward_concatenate() test_forward_concatenate()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment