Commit 4bbf96e4 by Jian Weng Committed by Tianqi Chen

[BUGFIX] [Hybrid Script] fix in-correct value index in hybrid script (#2268)

parent 6b405824
...@@ -39,10 +39,11 @@ class HybridParser(ast.NodeVisitor): ...@@ -39,10 +39,11 @@ class HybridParser(ast.NodeVisitor):
ast.Sub : operator.sub, ast.Sub : operator.sub,
ast.Mult : operator.mul, ast.Mult : operator.mul,
ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv, ast.Div : operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.FloorDiv: operator.div if sys.version_info[0] == 2 else operator.truediv,
ast.Mod : operator.mod, ast.Mod : operator.mod,
ast.BitOr : operator.or_, ast.BitOr : operator.or_,
ast.BitAnd: operator.and_, ast.BitAnd : operator.and_,
ast.BitXor: operator.xor, ast.BitXor : operator.xor,
ast.Gt : operator.gt, ast.Gt : operator.gt,
ast.GtE : operator.ge, ast.GtE : operator.ge,
ast.Lt : operator.lt, ast.Lt : operator.lt,
...@@ -237,7 +238,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -237,7 +238,7 @@ class HybridParser(ast.NodeVisitor):
if isinstance(node.value, ast.Name): if isinstance(node.value, ast.Name):
array = node.value.id array = node.value.id
_buf = self._get_buffer_from_id(array) _buf = self._get_buffer_from_id(array)
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index)
_internal_assert(isinstance(node.value, ast.Attribute), \ _internal_assert(isinstance(node.value, ast.Attribute), \
"Only variable and attribute's subscript supported so far") "Only variable and attribute's subscript supported so far")
......
import tvm, inspect, sys, traceback, numpy, nose import tvm, inspect, sys, traceback, numpy, nose, types
from tvm.hybrid import script from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS from tvm.hybrid.intrin import HYBRID_GLOBALS
...@@ -11,6 +11,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -11,6 +11,10 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
return val.value return val.value
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
op = None
outs = func(*args)
op = outs[0].op if isinstance(outs, list) else outs.op
emu_args = [] emu_args = []
nd_args = [] nd_args = []
...@@ -24,8 +28,6 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -24,8 +28,6 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
emu_args.append(tvm_val_2_py_val(i)) emu_args.append(tvm_val_2_py_val(i))
nd_args.append(emu_args[-1]) nd_args.append(emu_args[-1])
outs = func(*args)
op = outs[0].op if isinstance(outs, list) else outs.op
sch = tvm.create_schedule(op) sch = tvm.create_schedule(op)
module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target) module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target)
assert module assert module
...@@ -426,9 +428,11 @@ def test_downstream(): ...@@ -426,9 +428,11 @@ def test_downstream():
b[i] = a[i] * i b[i] = a[i] * i
return b return b
a = tvm.placeholder((20, ), 'float32') a = tvm.placeholder((20, ), 'float32')
b = downstream(a) b = downstream(a)
c = tvm.compute((20, ), lambda x: b[x] + 1.0) c = tvm.compute((20, ), lambda x: b[x] + 1.0)
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c]) module = tvm.build(sch, [a, c])
assert module assert module
...@@ -469,6 +473,40 @@ def test_const_param(): ...@@ -469,6 +473,40 @@ def test_const_param():
tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5) tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
def test_value_index():
@tvm.hybrid.script
def kernel_a(a):
b = output_tensor((16, ), 'int32')
c = output_tensor((4, 4), 'int32')
for i in range(16):
b[i] = a[i] + 2
c[i // 4, i % 4] = a[i] + 1
return b, c
@tvm.hybrid.script
def kernel_b(b, a):
c = output_tensor((4, 4), 'int32')
for i in range(4):
for j in range(4):
c[i, j] = a[i * 4 + j] * b[i, j]
return c
a = tvm.placeholder((16, ), 'int32')
b, c = kernel_a(a)
d = kernel_b(c, b)
sch = tvm.create_schedule(d.op)
module = tvm.build(sch, [a, d])
assert module
np_a = numpy.arange(16).astype('int32')
np_b, np_c = kernel_a(np_a)
ref = kernel_b(np_c, np_b)
res = tvm.ndarray.array(numpy.zeros((4, 4)).astype('int32'))
module(tvm.ndarray.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref)
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
...@@ -479,9 +517,11 @@ if __name__ == "__main__": ...@@ -479,9 +517,11 @@ if __name__ == "__main__":
test_math_intrin() test_math_intrin()
test_non_zero() test_non_zero()
test_allocate() test_allocate()
#test_inplace()
test_upstream() test_upstream()
test_downstream() test_downstream()
test_const_param() test_const_param()
test_value_index()
# TODO:
# test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment