Commit 19164063 by Wei Chen Committed by Zhi

[Relay][Prelude] Add more dtypes to tensor_t (#4233)

parent aa49e851
...@@ -591,6 +591,14 @@ class Prelude: ...@@ -591,6 +591,14 @@ class Prelude:
for global_def in GLOBAL_DEFS: for global_def in GLOBAL_DEFS:
setattr(self, global_def, self.mod.get_global_var(global_def)) setattr(self, global_def, self.mod.get_global_var(global_def))
for dtype in ['float32', 'int32']: for dtype in ['float32',
'float16',
'float64',
'int32',
'uint8',
'int8',
'int16',
'uint16',
'int64']:
tensor_array_ops = TensorArrayOps(self, dtype) tensor_array_ops = TensorArrayOps(self, dtype)
tensor_array_ops.register() tensor_array_ops.register()
...@@ -48,6 +48,17 @@ def convert_to_list(x): ...@@ -48,6 +48,17 @@ def convert_to_list(x):
x = [x] x = [x]
return x return x
tf_dtypes = {
'float32': tf.float32,
'float16': tf.float16,
'float64': tf.float64,
'int32': tf.int32,
'uint8' : tf.uint8,
'int8': tf.int8,
'int16': tf.int16,
'uint16': tf.uint16,
'int64': tf.int64,
}
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.Tensor): if isinstance(o, tvm.relay.backend.vmobj.Tensor):
...@@ -626,34 +637,24 @@ def test_forward_squeeze(): ...@@ -626,34 +637,24 @@ def test_forward_squeeze():
def test_tensor_array_constructor(): def test_tensor_array_constructor():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = tf_dtypes[dtype_str]
'float32': tf.float32, t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
'int32': tf.int32 t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str), dtype=dtype)
}[dtype_str] ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
t = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
dtype_str), dtype=dtype)
t2 = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]]).astype(
dtype_str), dtype=dtype)
ta1 = tf.TensorArray(dtype=dtype, size=2,
infer_shape=False, dynamic_size=False)
ta2 = ta1.write(0, t) ta2 = ta1.write(0, t)
ta3 = ta2.write(1, t2) ta3 = ta2.write(1, t2)
out = ta3.read(0) out = ta3.read(0)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArrayReadV3:0', mode='debug')
run('float32') for dtype in tf_dtypes.keys():
run('int32') run(dtype)
def test_tensor_array_scatter(): def test_tensor_array_scatter():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = tf_dtypes[dtype_str]
'float32': tf.float32, t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype)
'int32': tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(
dtype_str), dtype=dtype)
indices = tf.constant([2, 1, 0]) indices = tf.constant([2, 1, 0])
ta1 = tf.TensorArray(dtype=dtype, size=3, ta1 = tf.TensorArray(dtype=dtype, size=3,
infer_shape=False, dynamic_size=False) infer_shape=False, dynamic_size=False)
...@@ -663,12 +664,10 @@ def test_tensor_array_scatter(): ...@@ -663,12 +664,10 @@ def test_tensor_array_scatter():
out2 = ta2.read(2) out2 = ta2.read(2)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm( compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
[], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm( for dtype in tf_dtypes.keys():
[], [], ['TensorArrayReadV3_2:0'], mode='debug') run(dtype)
run('float32')
run('int32')
# TODO(wweic): Fix gather issue with PartialEvaluate # TODO(wweic): Fix gather issue with PartialEvaluate
# def test_tensor_array_gather(): # def test_tensor_array_gather():
...@@ -687,12 +686,8 @@ def test_tensor_array_scatter(): ...@@ -687,12 +686,8 @@ def test_tensor_array_scatter():
def test_tensor_array_split(): def test_tensor_array_split():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = tf_dtypes[dtype_str]
'float32': tf.float32, t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
'int32': tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False) infer_shape=False, dynamic_size=False)
...@@ -703,50 +698,38 @@ def test_tensor_array_split(): ...@@ -703,50 +698,38 @@ def test_tensor_array_split():
out3 = ta2.read(3) out3 = ta2.read(3)
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3:0'], mode='debug')
compare_tf_with_tvm( compare_tf_with_tvm([], [], ['TensorArrayReadV3_1:0'], mode='debug')
[], [], ['TensorArrayReadV3_1:0'], mode='debug') compare_tf_with_tvm([], [], ['TensorArrayReadV3_2:0'], mode='debug')
compare_tf_with_tvm( compare_tf_with_tvm([], [], ['TensorArrayReadV3_3:0'], mode='debug')
[], [], ['TensorArrayReadV3_2:0'], mode='debug') for dtype in tf_dtypes.keys():
compare_tf_with_tvm( run(dtype)
[], [], ['TensorArrayReadV3_3:0'], mode='debug')
run('float32')
run('int32')
def test_tensor_array_concat(): def test_tensor_array_concat():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = tf_dtypes[dtype_str]
'float32': tf.float32, t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
'int32': tf.int32
}[dtype_str]
t = tf.constant(np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [
6.0], [7.0], [8.0]]).astype(dtype_str), dtype=dtype)
split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32) split_length = tf.constant([2, 2, 2, 2], dtype=tf.int32)
ta1 = tf.TensorArray(dtype=dtype, size=4, ta1 = tf.TensorArray(dtype=dtype, size=4,
infer_shape=False, dynamic_size=False) infer_shape=False, dynamic_size=False)
ta2 = ta1.split(t, split_length) ta2 = ta1.split(t, split_length)
t = ta2.concat() t = ta2.concat()
compare_tf_with_tvm( compare_tf_with_tvm([], [], ['TensorArrayConcatV3:0'], mode='debug')
[], [], ['TensorArrayConcatV3:0'], mode='debug') for dtype in tf_dtypes.keys():
run('float32') run(dtype)
run('int32')
def test_tensor_array_size(): def test_tensor_array_size():
def run(dtype_str): def run(dtype_str):
with tf.Graph().as_default(): with tf.Graph().as_default():
dtype = { dtype = tf_dtypes[dtype_str]
'float32': tf.float32, ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=False, dynamic_size=False)
'int32': tf.int32
}[dtype_str]
ta1 = tf.TensorArray(dtype=dtype, size=2,
infer_shape=False, dynamic_size=False)
out = ta1.size() out = ta1.size()
g = tf.get_default_graph() g = tf.get_default_graph()
compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug')
run('float32') for dtype in tf_dtypes.keys():
run('int32') run(dtype)
####################################################################### #######################################################################
# ConcatV2 # ConcatV2
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment