Unverified Commit 83930a3b by Samuel Committed by GitHub

[PYTORCH]Rsub, Embedded, OneHot ops support (#5434)

parent 52bf1b35
...@@ -1477,6 +1477,50 @@ def _tensor_array_stack(prelude): ...@@ -1477,6 +1477,50 @@ def _tensor_array_stack(prelude):
return _impl return _impl
def _rsub():
def _impl(inputs, input_types):
# TODO: Figure out a better way to get typing to work for tensor + scalar
type0 = input_types[0]
if isinstance(inputs[1], _expr.Expr):
type0 = input_types[1]
type1 = input_types[1]
if isinstance(inputs[0], _expr.Expr):
type1 = input_types[0]
data1 = _convert_elemwise_input(inputs[0], type0)
data0 = _convert_elemwise_input(inputs[1], type1)
alpha = _expr.const(float(inputs[2]))
return get_relay_op("subtract")(data0, alpha * data1)
return _impl
def _embedding():
def _impl(inputs, input_types):
weight = inputs[0]
indices = inputs[1]
return _op.take(weight, indices.astype('int32'), axis=0)
return _impl
def _one_hot():
def _impl(inputs, input_types):
indices = inputs[0].astype('int32')
num_classes = inputs[1]
if num_classes == -1:
msg = "Inferring the number of classes is not yet supported."
raise NotImplementedError(msg)
dtype = 'int32'
on_value = tvm.relay.const(1.0, dtype)
off_value = tvm.relay.const(0.0, dtype)
return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype)
return _impl
# Helper functions for operator implementation # Helper functions for operator implementation
def _convert_dtype_value(val): def _convert_dtype_value(val):
convert_torch_dtype_map = {7:"torch.float64", convert_torch_dtype_map = {7:"torch.float64",
...@@ -1690,6 +1734,9 @@ def _get_convert_map(prelude): ...@@ -1690,6 +1734,9 @@ def _get_convert_map(prelude):
"aten::Float" : _Float(), "aten::Float" : _Float(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
"aten::rsub" : _rsub(),
"aten::embedding" : _embedding(),
"aten::one_hot" : _one_hot(),
"aten::mm" : _matmul(), "aten::mm" : _matmul(),
"relay::tensor_array_stack" : _tensor_array_stack(prelude), "relay::tensor_array_stack" : _tensor_array_stack(prelude),
"aten::add" : _add(prelude), "aten::add" : _add(prelude),
......
...@@ -1463,6 +1463,56 @@ def test_forward_variance(): ...@@ -1463,6 +1463,56 @@ def test_forward_variance():
verify_model(Variance5().float().eval(), input_data=input_data) verify_model(Variance5().float().eval(), input_data=input_data)
def test_forward_rsub():
torch.set_grad_enabled(False)
class Rsub1(Module):
def forward(self, *args):
return torch.rsub(args[0], args[1])
class Rsub2(Module):
def forward(self, *args):
return torch.rsub(args[0], args[1], alpha=0.5)
d1 = torch.rand([1, 3]).float()
d2 = torch.rand([1, 3]).float()
d3 = torch.rand([1, 3]).int()
verify_model(Rsub1().float().eval(), input_data=[d1, d2])
verify_model(Rsub1().float().eval(), input_data=[d1, d3])
verify_model(Rsub2().float().eval(), input_data=[d1, d2])
verify_model(Rsub2().float().eval(), input_data=[d1, d3])
def test_forward_embedding():
torch.set_grad_enabled(False)
input_data = torch.randint(0, 10, [2, 4]).long()
verify_model(torch.nn.Embedding(10, 3).float().eval(), input_data=input_data)
input_data = torch.randint(0, 4, [2, 3, 4]).long()
verify_model(torch.nn.Embedding(4, 5, sparse=False).float().eval(), input_data=input_data)
input_data = torch.randint(0, 4, [2, 3, 4]).long()
verify_model(torch.nn.Embedding(4, 5, sparse=True).float().eval(), input_data=input_data)
def test_forward_onehot():
torch.set_grad_enabled(False)
class OneHot1(Module):
def forward(self, *args):
return torch.nn.functional.one_hot(args[0], num_classes=3)
class OneHot2(Module):
def forward(self, *args):
return torch.nn.functional.one_hot(args[0], num_classes=5)
input_data = torch.arange(0, 5) % 3
verify_model(OneHot1().float().eval(), input_data=input_data)
input_data = torch.arange(0, 5) % 4
verify_model(OneHot2().float().eval(), input_data=input_data)
def test_forward_isfinite(): def test_forward_isfinite():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
...@@ -1984,6 +2034,9 @@ if __name__ == "__main__": ...@@ -1984,6 +2034,9 @@ if __name__ == "__main__":
test_forward_add() test_forward_add()
test_forward_subtract() test_forward_subtract()
test_forward_multiply() test_forward_multiply()
test_forward_rsub()
test_forward_onehot()
test_forward_embedding()
test_forward_reshape() test_forward_reshape()
test_forward_reciprocal() test_forward_reciprocal()
test_forward_repeat() test_forward_repeat()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment