Commit 770ac84e by Alexey Romanov Committed by Tianqi Chen

[Relay][Frontend] Simplify parameter handling in Tensorflow frontend (#2993)

parent 5999f7a6
...@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, ...@@ -56,31 +56,23 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout = None layout = None
if target == "cuda": if target == "cuda":
layout = "NCHW" layout = "NCHW"
target_host = 'llvm' target_host = None
if isinstance(input_data, list): shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
shape_dict = {}
dtype_dict = {}
for i, e in enumerate(input_node):
shape_dict[e] = input_data[i].shape
dtype_dict[e] = input_data[i].dtype
else:
shape_dict = {input_node: input_data.shape}
dtype_dict = {input_node: input_data.dtype}
sym, params = relay.frontend.from_tensorflow(graph_def, sym, params = relay.frontend.from_tensorflow(graph_def,
layout=layout, layout=layout,
shape=shape_dict, shape=shape_dict,
outputs=out_names) outputs=out_names)
with relay.build_config(opt_level=opt_level): with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(sym, target, params=params) graph, lib, params = relay.build(sym, target, target_host, params)
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx) m = graph_runtime.create(graph, lib, ctx)
# set inputs # set inputs
for i, e in enumerate(input_node): for e, i in zip(input_node, input_data):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) m.set_input(e, tvm.nd.array(i))
m.set_input(**params) m.set_input(**params)
# execute # execute
...@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, ...@@ -88,10 +80,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
# get outputs # get outputs
assert out_names is None or num_output == len(out_names), ( assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output)) "out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = [] tvm_output_list = [m.get_output(i).asnumpy() for i in range(num_output)]
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list return tvm_output_list
def run_tf_graph(sess, input_data, input_node, output_node): def run_tf_graph(sess, input_data, input_node, output_node):
...@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node): ...@@ -100,13 +89,9 @@ def run_tf_graph(sess, input_data, input_node, output_node):
input_node = convert_to_list(input_node) input_node = convert_to_list(input_node)
output_node = convert_to_list(output_node) output_node = convert_to_list(output_node)
tensor = [0] * len(output_node) tensor = [sess.graph.get_tensor_by_name(output_name) for output_name in output_node]
for i in range(len(output_node)):
tensor[i] = sess.graph.get_tensor_by_name(output_node[i])
input_dict = {} input_dict = {e: input_data[i] for i, e in enumerate(input_node)}
for i, e in enumerate(input_node):
input_dict[e] = input_data[i]
output_data = sess.run(tensor, input_dict) output_data = sess.run(tensor, input_dict)
return output_data return output_data
...@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node): ...@@ -115,17 +100,15 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3): no_gpu=False, opt_level=3):
"""Generic function to generate and compare tensorflow and TVM output""" """Generic function to generate and compare tensorflow and TVM output"""
def name_without_num(name):
return name.split(':')[0] if ":" in name else name
out_name = convert_to_list(out_name) out_name = convert_to_list(out_name)
out_node = [0]*len(out_name) out_node = [name_without_num(name) for name in out_name]
for i in range(len(out_name)):
out_node[i] = out_name[i].split(':')[0] if ":" in out_name[i] else out_name[i]
in_data = convert_to_list(in_data) in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name) in_name = convert_to_list(in_name)
in_node = [0]*len(in_name) in_node = [name_without_num(name) for name in in_name]
for i in range(len(in_name)):
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
with tf.Session() as sess: with tf.Session() as sess:
if init_global_variables: if init_global_variables:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
...@@ -578,6 +561,38 @@ def test_forward_variable(): ...@@ -578,6 +561,38 @@ def test_forward_variable():
####################################################################### #######################################################################
# MatMul
# ------
def _test_matmul(i, j, k, dtype, outer=None):
""" One iteration of matmul """
A_shape_init = [i, j]
B_shape_init = [j, k]
for transpose_a in [False, True]:
for transpose_b in [False, True]:
outer = outer or []
A_shape = outer + (A_shape_init[::-1] if transpose_a else A_shape_init)
B_shape = outer + (B_shape_init[::-1] if transpose_b else B_shape_init)
with tf.Graph().as_default():
A = tf.placeholder(shape=A_shape, dtype=dtype, name='A')
B = tf.placeholder(shape=B_shape, dtype=dtype, name='B')
result = tf.matmul(A, B, transpose_a=transpose_a, transpose_b=transpose_b)
A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
compare_tf_with_tvm([A_np, B_np], [A.name, B.name], result.name)
def test_forward_matmul():
""" Matmul op test"""
_test_matmul(1, 3, 6, 'int32')
_test_matmul(5, 3, 1, 'float64')
# TODO non-empty outer requires BatchMatMul (BatchMatMulV2 for some cases?) support
#######################################################################
# StridedSlice # StridedSlice
# ------------ # ------------
...@@ -1785,3 +1800,6 @@ if __name__ == '__main__': ...@@ -1785,3 +1800,6 @@ if __name__ == '__main__':
test_forward_rel_ops() test_forward_rel_ops()
test_forward_logical() test_forward_logical()
test_where() test_where()
test_forward_matmul()
# TODO missing tests: rank, range
\ No newline at end of file
...@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple): ...@@ -151,11 +151,7 @@ def get_const_tuple(in_tuple):
out_tuple : tuple of int out_tuple : tuple of int
The output. The output.
""" """
out_tuple = () return tuple(get_const_int(elem) for elem in in_tuple)
for elem in in_tuple:
value = get_const_int(elem)
out_tuple = out_tuple + (value, )
return out_tuple
def get_float_tuple(in_tuple): def get_float_tuple(in_tuple):
...@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple): ...@@ -171,11 +167,7 @@ def get_float_tuple(in_tuple):
out_tuple : tuple of float out_tuple : tuple of float
The output. The output.
""" """
out_tuple = () return tuple(get_const_float(elem) for elem in in_tuple)
for elem in in_tuple:
value = get_const_float(elem)
out_tuple = out_tuple + (value, )
return out_tuple
def simplify(expr): def simplify(expr):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment