Commit 59cf5735 by Wei Chen Committed by Haichen Shen

[TF][Op] Op where (#4045)

* [TF][Op] Add TF op Where

* improve tests

* add tests for vm
parent 2d537621
......@@ -937,6 +937,8 @@ def _transpose():
def _where():
def _impl(inputs, attr, params):
if len(inputs) == 1:
return AttrCvt(op_name="argwhere")(inputs, attr)
return AttrCvt(op_name="where")(inputs, attr)
return _impl
......@@ -1354,6 +1356,7 @@ _convert_map = {
'Transpose' : _transpose(),
'TruncateMod' : _elemwise('mod'),
'Unpack' : _unpack(),
'Where' : _where(),
'ZerosLike' : AttrCvt('zeros_like'),
}
......
......@@ -46,8 +46,34 @@ def convert_to_list(x):
x = [x]
return x
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.TupleValue):
result = []
for f in o.fields:
result.append(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'nil':
return []
elif isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" % type(o))
def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3):
target='llvm', out_names=None, opt_level=3, mode='graph_runtime'):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
......@@ -63,6 +89,14 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
layout=layout,
shape=shape_dict,
outputs=out_names)
if mode in ['debug', 'vm']:
ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = ex.evaluate()(*inputs)
return vmobj_to_list(result)
else:
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(mod, target, target_host, params)
......@@ -97,7 +131,7 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3):
no_gpu=False, opt_level=3, mode='graph_runtime'):
"""Generic function to generate and compare tensorflow and TVM output"""
def name_without_num(name):
return name.split(':')[0] if ":" in name else name
......@@ -128,7 +162,7 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
target=device, out_names=out_name,
num_output=len(out_name), opt_level=opt_level)
num_output=len(out_name), opt_level=opt_level, mode=mode)
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
......@@ -325,6 +359,22 @@ def test_forward_biasadd():
_test_biasadd([4, 17, 17, 19], 'NHWC')
_test_biasadd([4, 3, 3, 124], 'NHWC')
def _test_forward_where(input_shape):
with tf.Graph().as_default():
dtype = tf.float32
t = tf.constant(np.random.choice([0, 1, -2, 3, -1, 0.1, -0.2],
size=input_shape).astype(dtype.name))
out = tf.where(t)
compare_tf_with_tvm([], [], out.name, mode='debug')
compare_tf_with_tvm([], [], out.name, mode='vm')
def test_forward_argwhere():
_test_forward_where((5,))
_test_forward_where((5, 5))
_test_forward_where((5, 5, 5))
_test_forward_where((5, 5, 5, 5))
_test_forward_where((5, 5, 5, 5, 5))
#######################################################################
# SpaceToBatchND
# --------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment