Commit 98e761f8 by Jian Weng Committed by Tianqi Chen

allow const_range allocation; preprove if-then-else (#2419)

parent 7e5966a0
...@@ -339,6 +339,9 @@ class HybridParser(ast.NodeVisitor): ...@@ -339,6 +339,9 @@ class HybridParser(ast.NodeVisitor):
else_body = visit_list_to_block(self.visit, node.orelse) else_body = visit_list_to_block(self.visit, node.orelse)
else: else:
else_body = util.make_nop() else_body = util.make_nop()
# Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm):
return if_body if cond.value else else_body
return _make.IfThenElse(cond, if_body, else_body) return _make.IfThenElse(cond, if_body, else_body)
...@@ -429,7 +432,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -429,7 +432,7 @@ class HybridParser(ast.NodeVisitor):
for i in range(low, low + ext): for i in range(low, low + ext):
self.symbols[_name] = Symbol.ConstLoopVar, i self.symbols[_name] = Symbol.ConstLoopVar, i
bodies.append(visit_list_to_block(self.visit, node.body)) bodies.append(visit_list_to_block(self.visit, node.body))
return pack_list_to_block(bodies) _body = pack_list_to_block(bodies)
elif iter_var is None: elif iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop bind function parse error!")
...@@ -449,6 +452,9 @@ class HybridParser(ast.NodeVisitor): ...@@ -449,6 +452,9 @@ class HybridParser(ast.NodeVisitor):
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
elif not isinstance(for_type, tuple): elif not isinstance(for_type, tuple):
res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body) res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
else:
res = _body
self.symbols.pop(_name) self.symbols.pop(_name)
return res return res
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment