Unverified Commit 9bbee96f by Samuel Committed by GitHub

[PYTORCH]Tensor creation ops support (#5347)

parent 0075d8ce
...@@ -348,12 +348,25 @@ def _ones(): ...@@ -348,12 +348,25 @@ def _ones():
msg = "Data type %s could not be parsed in ones op" % (type(data)) msg = "Data type %s could not be parsed in ones op" % (type(data))
raise AssertionError(msg) raise AssertionError(msg)
dtype_map = {6: "float32", 3: "int32"} dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
dtype_id = inputs[1]
assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id return _op.full(_expr.const(1), shape, dtype=dtype)
return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id]) return _impl
def _ones_like():
def _impl(inputs, input_types):
data = inputs[0]
out = _op.ones_like(data)
# If the input and the output datatype is different, do a cast
dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
if input_types[0] not in dtype:
out = _op.cast(out, dtype)
return out
return _impl return _impl
def _zeros(): def _zeros():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
...@@ -369,12 +382,88 @@ def _zeros(): ...@@ -369,12 +382,88 @@ def _zeros():
msg = "Data type %s could not be parsed in zeros op" % (type(data)) msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg) raise AssertionError(msg)
dtype_map = {6: "float32", 3: "int32"} dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
dtype_id = inputs[1]
assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id return _op.full(_expr.const(0), shape, dtype=dtype)
return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id]) return _impl
def _zeros_like():
def _impl(inputs, input_types):
data = inputs[0]
out = _op.zeros_like(data)
# If the input and the output datatype is different, do a cast
dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
if input_types[0] not in dtype:
out = _op.cast(out, dtype)
return out
return _impl
def _full():
def _impl(inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
import torch
if isinstance(data, _expr.Expr):
shape = _infer_shape(data)
elif isinstance(data, list):
shape = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
shape = data.shape
else:
msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg)
dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
return _op.full(_expr.const(fill_value), shape, dtype=dtype)
return _impl
def _full_like():
def _impl(inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
out = _op.full_like(data, _expr.const(fill_value))
# If the input and the output datatype is different, do a cast
dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
if input_types[0] not in dtype:
out = _op.cast(out, dtype)
return out
return _impl return _impl
def _linspace():
def _impl(inputs, input_types):
start = inputs[0]
stop = inputs[1]
step = inputs[2]
# Find the spacing between values as step
if step != 1:
step = (stop - start) / (step - 1)
stop = stop + step
else:
stop = start + step
dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
start = _create_typed_const(start, dtype)
stop = _create_typed_const(stop, dtype)
step = _create_typed_const(step, dtype)
return _op.transform.arange(start=start,
stop=stop,
step=step,
dtype=_convert_data_type(dtype))
return _impl
def _relu(): def _relu():
def _impl(inputs, input_types): def _impl(inputs, input_types):
data = inputs[0] data = inputs[0]
...@@ -1497,7 +1586,12 @@ def _get_convert_map(prelude): ...@@ -1497,7 +1586,12 @@ def _get_convert_map(prelude):
"aten::div" : _elemwise("divide"), "aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"), "aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(), "aten::ones" : _ones(),
"aten::ones_like" : _ones_like(),
"aten::zeros" : _zeros(), "aten::zeros" : _zeros(),
"aten::zeros_like" : _zeros_like(),
"aten::full" : _full(),
"aten::full_like" : _full_like(),
"aten::linspace" : _linspace(),
"aten::reciprocal" : _reciprocal(), "aten::reciprocal" : _reciprocal(),
"aten::repeat" : _repeat(), "aten::repeat" : _repeat(),
"aten::repeat_interleave" : _repeat_interleave(), "aten::repeat_interleave" : _repeat_interleave(),
......
...@@ -1545,6 +1545,144 @@ def test_forward_round(): ...@@ -1545,6 +1545,144 @@ def test_forward_round():
verify_model(Round1().float().eval(), input_data=input_data) verify_model(Round1().float().eval(), input_data=input_data)
def test_forward_ones():
torch.set_grad_enabled(False)
class Ones1(Module):
def forward(self, *args):
return torch.ones(2,3)
verify_model(Ones1().float().eval(), input_data=[])
def test_forward_ones_like():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
class OnesLike1(Module):
def forward(self, *args):
return torch.ones_like(args[0])
class OnesLike2(Module):
def forward(self, *args):
return torch.ones_like(args[0], dtype=torch.int8)
class OnesLike3(Module):
def forward(self, *args):
return torch.ones_like(args[0], dtype=torch.float)
input_data = torch.rand(input_shape).float()
verify_model(OnesLike1().float().eval(), input_data=input_data)
verify_model(OnesLike2().float().eval(), input_data=input_data)
verify_model(OnesLike3().float().eval(), input_data=input_data)
def test_forward_zeros():
torch.set_grad_enabled(False)
class Zeros1(Module):
def forward(self, *args):
return torch.zeros(2,3)
verify_model(Zeros1().float().eval(), input_data=[])
def test_forward_zeros_like():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
class ZerosLike1(Module):
def forward(self, *args):
return torch.zeros_like(args[0])
class ZerosLike2(Module):
def forward(self, *args):
return torch.zeros_like(args[0], dtype=torch.int32)
class ZerosLike3(Module):
def forward(self, *args):
return torch.zeros_like(args[0], dtype=torch.float)
input_data = torch.rand(input_shape).float()
verify_model(ZerosLike1().float().eval(), input_data=input_data)
verify_model(ZerosLike2().float().eval(), input_data=input_data)
verify_model(ZerosLike3().float().eval(), input_data=input_data)
def test_forward_full():
torch.set_grad_enabled(False)
class Full1(Module):
def forward(self, *args):
return torch.full((2,3), 3.14)
class Full2(Module):
def forward(self, *args):
return torch.full((1, 2,3), 1.0, dtype=torch.int32)
verify_model(Full1().float().eval(), input_data=[])
verify_model(Full2().float().eval(), input_data=[])
def test_forward_full_like():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
class FullLike1(Module):
def forward(self, *args):
return torch.full_like(args[0], 3.14)
class FullLike2(Module):
def forward(self, *args):
return torch.full_like(args[0], 22.22, dtype=torch.int32)
class FullLike3(Module):
def forward(self, *args):
return torch.full_like(args[0], 1.4, dtype=torch.float)
input_data = torch.rand(input_shape).float()
verify_model(FullLike1().float().eval(), input_data=input_data)
verify_model(FullLike2().float().eval(), input_data=input_data)
verify_model(FullLike3().float().eval(), input_data=input_data)
def test_forward_linspace():
torch.set_grad_enabled(False)
class Linspace1(Module):
def forward(self, *args):
return torch.linspace(5, 10)
class Linspace2(Module):
def forward(self, *args):
return torch.linspace(-10, 10, steps=5)
class Linspace3(Module):
def forward(self, *args):
return torch.linspace(start=-10, end=10, steps=5)
class Linspace4(Module):
def forward(self, *args):
return torch.linspace(start=-10, end=10, steps=1)
class Linspace5(Module):
def forward(self, *args):
return torch.linspace(1, 2, 1, dtype=torch.int32)
class Linspace6(Module):
def forward(self, *args):
return torch.linspace(start=1, end=6, steps=2)
class Linspace7(Module):
def forward(self, *args):
return torch.linspace(1, 4, dtype=torch.float32)
class Linspace8(Module):
def forward(self, *args):
return torch.linspace(1, 2, 1, dtype=torch.int16)
verify_model(Linspace1().float().eval())
verify_model(Linspace2().float().eval())
verify_model(Linspace3().float().eval())
verify_model(Linspace4().float().eval())
verify_model(Linspace5().float().eval())
verify_model(Linspace6().float().eval())
verify_model(Linspace7().float().eval())
verify_model(Linspace8().float().eval())
def test_forward_take(): def test_forward_take():
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
class Take1(Module): class Take1(Module):
...@@ -1759,6 +1897,13 @@ if __name__ == "__main__": ...@@ -1759,6 +1897,13 @@ if __name__ == "__main__":
test_forward_isfinite() test_forward_isfinite()
test_forward_isnan() test_forward_isnan()
test_forward_isinf() test_forward_isinf()
test_forward_ones()
test_forward_ones_like()
test_forward_zeros()
test_forward_zeros_like()
test_forward_full()
test_forward_full_like()
test_forward_linspace()
test_forward_arange() test_forward_arange()
test_forward_chunk() test_forward_chunk()
test_forward_split() test_forward_split()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment