Commit 2fbc82e6 by Jian Weng Committed by Tianqi Chen

[Hybrid Script] allow const_range allocation; allow const_range lazy compilation (#2423)

parent a61d3b41
...@@ -334,14 +334,21 @@ class HybridParser(ast.NodeVisitor): ...@@ -334,14 +334,21 @@ class HybridParser(ast.NodeVisitor):
def visit_If(self, node): def visit_If(self, node):
cond = self.visit(node.test) cond = self.visit(node.test)
# Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm):
if cond.value:
return visit_list_to_block(self.visit, node.body)
elif node.orelse:
return visit_list_to_block(self.visit, node.orelse)
return util.make_nop()
if_body = visit_list_to_block(self.visit, node.body) if_body = visit_list_to_block(self.visit, node.body)
if node.orelse: if node.orelse:
else_body = visit_list_to_block(self.visit, node.orelse) else_body = visit_list_to_block(self.visit, node.orelse)
else: else:
else_body = util.make_nop() else_body = util.make_nop()
# Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm):
return if_body if cond.value else else_body
return _make.IfThenElse(cond, if_body, else_body) return _make.IfThenElse(cond, if_body, else_body)
...@@ -431,8 +438,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -431,8 +438,10 @@ class HybridParser(ast.NodeVisitor):
bodies = [] bodies = []
for i in range(low, low + ext): for i in range(low, low + ext):
self.symbols[_name] = Symbol.ConstLoopVar, i self.symbols[_name] = Symbol.ConstLoopVar, i
bodies.append(visit_list_to_block(self.visit, node.body)) body = visit_list_to_block(self.visit, node.body)
_body = pack_list_to_block(bodies) body = self.wrap_up_realize(node, body)
bodies.append(body)
return pack_list_to_block(bodies)
elif iter_var is None: elif iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop bind function parse error!")
...@@ -450,10 +459,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -450,10 +459,10 @@ class HybridParser(ast.NodeVisitor):
if for_type is None: if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
elif not isinstance(for_type, tuple):
res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
else: else:
res = _body _internal_assert(not isinstance(for_type, tuple), \
"Micro expansion should be handled before!")
res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
self.symbols.pop(_name) self.symbols.pop(_name)
return res return res
......
...@@ -563,6 +563,37 @@ def test_const_range(): ...@@ -563,6 +563,37 @@ def test_const_range():
b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]] b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
run_and_check(foo, [a, b]) run_and_check(foo, [a, b])
@tvm.hybrid.script
def goo(a, b):
c = output_tensor(a.shape, a.dtype)
len_b = len(b)
for i in const_range(len_b * 2):
if i < len_b:
c[i] = a[i] + b[i]
else:
c[i - len_b] = a[i - len_b] + b[i - len_b]
return c
a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5]
c = goo(a, tvm.convert(b))
sch = tvm.create_schedule(c.op)
run_and_check(goo, [a, b])
@tvm.hybrid.script
def hoo(a, b):
c = output_tensor(a.shape, a.dtype)
len_b = len(b)
for i in range(a.shape[0]):
for j in const_range(len(b)):
d = a[i] * b[j]
d += a[i] + b[j]
c[i] = d
return c
a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5]
run_and_check(hoo, [a, b])
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
test_fanout() test_fanout()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment