Unverified commit b1364ebb by Samuel, committed by GitHub

[PYTORCH] Take, Topk op support (#5332)

* [PYTORCH] take, topk op support

* CI failure fix

parent afcf9397
@@ -272,6 +272,39 @@ def _select():
        return _op.transform.take(data, index, axis=dim)
    return _impl

def _take():
    def _impl(inputs, input_types):
        data = inputs[0]
        import torch

        if isinstance(inputs[1], _expr.Var):
            indices = _op.cast(inputs[1], "int32")
        elif isinstance(inputs[1], torch.Tensor):
            indices = _wrap_const(inputs[1].numpy())
        else:
            msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
            raise AssertionError(msg)

        return _op.transform.take(data, indices=indices)
    return _impl

def _topk():
    def _impl(inputs, input_types):
        data = inputs[0]
        k = int(inputs[1])
        axis = int(inputs[2])
        is_ascend = not bool(inputs[3])
        sort = bool(inputs[4])

        if not sort:
            msg = "Currently supports only sorted output for topk operator."
            raise AssertionError(msg)

        outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both")

        return outs[0], outs[1]
    return _impl

def _reciprocal():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -1416,6 +1449,8 @@ def _get_convert_map(prelude):
"aten::split" : _split(), "aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(), "aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(), "aten::select" : _select(),
"aten::take" : _take(),
"aten::topk" : _topk(),
"aten::relu" : _relu(), "aten::relu" : _relu(),
"aten::relu_" : _relu(), "aten::relu_" : _relu(),
"aten::prelu" : _prelu(), "aten::prelu" : _prelu(),
...
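For reference, a minimal standalone sketch of the semantics the two converters above target (it assumes only PyTorch is installed; the tensor values are illustrative). torch.take indexes the input as if it were flattened, which matches Relay's take with no axis, and torch.topk's largest flag is the inverse of Relay's is_ascend:

import torch

# aten::take: flat (row-major) indexing, i.e. Relay take with no axis.
src = torch.tensor([[4.0, 3.0], [5.0, 7.0]])
idx = torch.tensor([0, 2, 3])
print(torch.take(src, idx))  # tensor([4., 5., 7.])

# aten::topk: inputs arrive as (input, k, dim, largest, sorted); the
# converter maps largest=True to is_ascend=False and requires sorted=True.
values, indices = torch.topk(src.flatten(), k=2, dim=-1, largest=True, sorted=True)
print(values, indices)  # tensor([7., 5.]) tensor([3, 2])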
@@ -1545,6 +1545,61 @@ def test_forward_round():
    verify_model(Round1().float().eval(), input_data=input_data)

def test_forward_take():
    torch.set_grad_enabled(False)

    class Take1(Module):
        def forward(self, *args):
            indices = torch.tensor([[0, 0], [1, 0]])
            if torch.cuda.is_available():
                indices = indices.cuda()
            return torch.take(args[0], indices)

    class Take2(Module):
        def forward(self, *args):
            return torch.take(args[0], args[1])

    input_data = torch.tensor([[1, 2], [3, 4]])
    verify_model(Take1().float().eval(), input_data=input_data)
    indices = torch.tensor([[0, 0], [1, 0]])
    verify_model(Take2().float().eval(), input_data=[input_data, indices])

def test_forward_topk():
    torch.set_grad_enabled(False)

    class Topk1(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3)

    class Topk2(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, dim=-2)

    class Topk3(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, dim=3)

    class Topk4(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, largest=True)

    class Topk5(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, largest=False)

    class Topk6(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, sorted=True)

    input_shape = [1, 3, 10, 10]
    input_data = torch.rand(input_shape).float()
    verify_model(Topk1().float().eval(), input_data=input_data)
    verify_model(Topk2().float().eval(), input_data=input_data)
    verify_model(Topk3().float().eval(), input_data=input_data)
    verify_model(Topk4().float().eval(), input_data=input_data)
    verify_model(Topk5().float().eval(), input_data=input_data)
    verify_model(Topk6().float().eval(), input_data=input_data)

if __name__ == "__main__":
    # Single operator tests
    test_forward_add()
@@ -1587,6 +1642,8 @@ if __name__ == "__main__":
    test_forward_size()
    test_forward_view()
    test_forward_select()
    test_forward_take()
    test_forward_topk()
    test_forward_clone()
    test_forward_softplus()
    test_forward_softsign()
...
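The new tests go through verify_model; below is a hedged sketch of the same flow outside the test harness. It assumes TVM and PyTorch are installed, that relay.frontend.from_pytorch takes a traced module plus a list of (name, shape) pairs as the frontend did around this change, and the input name "input0" is illustrative:

import torch
from tvm import relay

class TopkModel(torch.nn.Module):
    def forward(self, x):
        # Exercised by the new aten::topk converter when imported into Relay.
        return torch.topk(x, k=3)

x = torch.rand(1, 3, 10, 10)
traced = torch.jit.trace(TopkModel().eval(), x)
mod, params = relay.frontend.from_pytorch(traced, [("input0", list(x.shape))])
print(mod)  # the printed Relay module should contain a topk call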