Commit 123a4077 by Haichen Shen Committed by Yao Wang

[Hybrid][Fix] Fix hybrid script to support array of tensors (#4494)

* [Fix][Hybrid] Fix hybrid script to support array of tensors

* add test case

* clean up

* trigger ci
parent fb12f356
...@@ -647,9 +647,15 @@ def source_to_op(src, args, symbols, closure_vars): ...@@ -647,9 +647,15 @@ def source_to_op(src, args, symbols, closure_vars):
parser = parse_python(src, args, symbols, closure_vars) parser = parse_python(src, args, symbols, closure_vars)
input_tensors = [] input_tensors = []
def get_input_tensors(arg):
if isinstance(arg, Tensor):
input_tensors.append(arg)
elif isinstance(arg, Array):
for i in arg:
get_input_tensors(i)
for i in args: for i in args:
if isinstance(i, Tensor): get_input_tensors(i)
input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors, op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body) parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))] res = [op.output(i) for i in range(len(parser.outputs))]
......
...@@ -789,6 +789,37 @@ def test_capture(): ...@@ -789,6 +789,37 @@ def test_capture():
func, ins, outs = run_and_check(add_something, [a]) func, ins, outs = run_and_check(add_something, [a])
run_and_check(func, ins, outs=outs) run_and_check(func, ins, outs=outs)
def test_array_inputs():
@script
def sum_array(inputs):
out = output_tensor((10,), inputs[0].dtype)
n = len(inputs)
for i in range(10):
for j in const_range(n):
out[i] += inputs[j][i]
return out
n = 5
inputs = []
for i in range(n):
inputs.append(tvm.placeholder((10,), name='t%s' % i, dtype='float32'))
out = sum_array(tvm.convert(inputs))
assert len(out.op.inputs) == n
sch = tvm.create_schedule(out.op)
mod = tvm.build(sch, inputs + [out], target='llvm')
assert mod
input_nd = []
out_ref = numpy.zeros((10,))
for _ in range(n):
arr = numpy.random.uniform(size=(10,)).astype('float32')
input_nd.append(tvm.nd.array(arr))
out_ref += arr
out_nd = tvm.nd.array(numpy.zeros((10,), 'float32'))
mod(*input_nd, out_nd)
tvm.testing.assert_allclose(out_nd.asnumpy(), out_ref)
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
test_fanout() test_fanout()
...@@ -807,5 +838,6 @@ if __name__ == "__main__": ...@@ -807,5 +838,6 @@ if __name__ == "__main__":
test_const_range() test_const_range()
test_schedule() test_schedule()
test_capture() test_capture()
test_array_inputs()
# TODO: # TODO:
# test_inplace() # test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment