Commit 76812dea by Jian Weng Committed by Tianqi Chen

fix lint (#2649)

parent dee8cf9b
...@@ -45,8 +45,8 @@ def bind(func_id, args): ...@@ -45,8 +45,8 @@ def bind(func_id, args):
_internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!") _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(args[0], str), \ _internal_assert(isinstance(args[0], str), \
"A loop bind's first argument should be a string!") "A loop bind's first argument should be a string!")
iter_var = _api.thread_axis(args[0])
low, ext = _api.const(0, "int32"), args[1] low, ext = _api.const(0, "int32"), args[1]
iter_var = _api.thread_axis((low, ext), args[0])
for_type = None for_type = None
return iter_var, low, ext, for_type return iter_var, low, ext, for_type
......
...@@ -12,7 +12,7 @@ from enum import Enum ...@@ -12,7 +12,7 @@ from enum import Enum
from .util import _internal_assert from .util import _internal_assert
from . import calls from . import calls
from . import util from . import util
from .var_decl import determine_variable_usage from .preprocessor import determine_variable_usage
from ..api import all as _all from ..api import all as _all
from ..api import any as _any from ..api import any as _any
from ..container import Array from ..container import Array
...@@ -61,6 +61,7 @@ class Symbol(Enum): ...@@ -61,6 +61,7 @@ class Symbol(Enum):
BufferVar = 7 BufferVar = 7
LoopVar = 8 LoopVar = 8
ConstLoopVar = 9 ConstLoopVar = 9
ThreadBind = 10
class HybridParser(ast.NodeVisitor): class HybridParser(ast.NodeVisitor):
...@@ -117,7 +118,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -117,7 +118,10 @@ class HybridParser(ast.NodeVisitor):
self.symbols = {} # Symbol table self.symbols = {} # Symbol table
for k, v in symbols.items(): for k, v in symbols.items():
if isinstance(v, types.FunctionType): if isinstance(v, types.FunctionType):
self.symbols[k] = Symbol.Callable, v self.add_symbol(k, Symbol.Callable, v)
self.binds = {} # Thread binds
self.device = 0 # Is it generating device
self.func_name = func_name # The name of the function to be lowered self.func_name = func_name # The name of the function to be lowered
self.outputs = [] # Output tensors' name self.outputs = [] # Output tensors' name
...@@ -126,6 +130,25 @@ class HybridParser(ast.NodeVisitor): ...@@ -126,6 +130,25 @@ class HybridParser(ast.NodeVisitor):
self.returned = False # If this function has a valid return self.returned = False # If this function has a valid return
def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
    """Register `key` in the symbol table as the pair `(ty, val)`.

    A key that is already present is reported as a name conflict through
    `_internal_assert`. When `ty` is `Symbol.ThreadBind`, the value is also
    tracked in `self.binds` under `val.var.name`: the first value seen for
    a name is remembered, and any later value with the same name must have
    an equal extent, in which case the symbol is pointed at the remembered
    value instead of the new one.
    """
    if key in self.symbols:
        previous = str(self.symbols[key])
        incoming = str((ty, val))
        _internal_assert(False,
                         "Name conflict in symbol table! [%s] %s -> %s" % (key, previous, incoming))

    self.symbols[key] = ty, val

    if ty == Symbol.ThreadBind:
        bind_name = val.var.name
        if bind_name not in self.binds:
            # First bind seen under this name: remember it as the canonical one.
            self.binds[bind_name] = val
            return
        canonical = self.binds[bind_name]
        _internal_assert(_ir_pass.Equal(canonical.dom.extent, val.dom.extent),
                         "Thread extents should be uniform!")
        # Reuse the remembered value so every symbol with this name shares it.
        self.symbols[key] = ty, canonical
def wrap_up_realize(self, node, body): def wrap_up_realize(self, node, body):
"""Wrap up all the variables which will no longer be used""" """Wrap up all the variables which will no longer be used"""
...@@ -141,11 +164,14 @@ class HybridParser(ast.NodeVisitor): ...@@ -141,11 +164,14 @@ class HybridParser(ast.NodeVisitor):
continue continue
elif 'Buffer' in ty.name: elif 'Buffer' in ty.name:
_buf = entry _buf = entry
_scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global' _scope = 'global' if ty is Symbol.BufferVar else ty.name[:-6].lower()
to_pop.append(key) to_pop.append(key)
else: else:
continue continue
if _scope == 'global':
body = self.wrap_up_binds(body)
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype _dtype = _buf.dtype
_true = _api.convert(True) _true = _api.convert(True)
...@@ -158,6 +184,14 @@ class HybridParser(ast.NodeVisitor): ...@@ -158,6 +184,14 @@ class HybridParser(ast.NodeVisitor):
return body return body
def wrap_up_binds(self, body):
    """Wrap `body` in one 'thread_extent' AttrStmt per entry of `self.binds`.

    Each recorded value contributes an attribute statement carrying its
    `dom.extent`, nested around the statements built so far. `self.binds`
    is then reset to an empty dict and the wrapped statement is returned.
    """
    wrapped = body
    for bound_iter in self.binds.values():
        extent = bound_iter.dom.extent
        wrapped = _make.AttrStmt(bound_iter, 'thread_extent', extent, wrapped)
    self.binds = {}
    return wrapped
#pylint: disable=invalid-name, missing-docstring #pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node): def visit_Module(self, node):
_internal_assert(len(node.body) == 1, \ _internal_assert(len(node.body) == 1, \
...@@ -173,10 +207,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -173,10 +207,10 @@ class HybridParser(ast.NodeVisitor):
self.func_name = node.name self.func_name = node.name
for idx, arg in enumerate(node.args.args): for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx]) self.add_symbol(getattr(arg, _attr), Symbol.Input, self.args[idx])
res = visit_list_to_block(self.visit, node.body) res = visit_list_to_block(self.visit, node.body)
res = self.wrap_up_realize(node, res) res = self.wrap_up_realize(node, res)
return res return self.wrap_up_binds(res)
def visit_Expr(self, node): def visit_Expr(self, node):
...@@ -189,6 +223,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -189,6 +223,8 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name) _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]: if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
return entry return entry
if ty is Symbol.ThreadBind:
return entry.var
if ty is Symbol.ConstVar: if ty is Symbol.ConstVar:
return entry if isinstance(node.ctx, ast.Load) else None return entry if isinstance(node.ctx, ast.Load) else None
if ty is Symbol.BufferVar: if ty is Symbol.BufferVar:
...@@ -237,7 +273,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -237,7 +273,7 @@ class HybridParser(ast.NodeVisitor):
for i in range(rhs.num_outputs): for i in range(rhs.num_outputs):
_internal_assert(isinstance(node.targets[i], ast.Name), _internal_assert(isinstance(node.targets[i], ast.Name),
"You should bind a pure name to the tensors") "You should bind a pure name to the tensors")
self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i) self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
rmap[rhs.outputs[i].op] = rhs.output(i) rmap[rhs.outputs[i].op] = rhs.output(i)
return util.replace_io(rhs.body, rmap) return util.replace_io(rhs.body, rmap)
...@@ -260,15 +296,19 @@ class HybridParser(ast.NodeVisitor): ...@@ -260,15 +296,19 @@ class HybridParser(ast.NodeVisitor):
if isinstance(rhs, tuple): if isinstance(rhs, tuple):
shape, dtype, scope = rhs shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs) ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
if scope == 'output': if scope == 'output':
self.outputs.append(lhs) self.outputs.append(lhs)
return util.make_nop() return util.make_nop()
if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw: if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
self.symbols[lhs] = Symbol.ConstVar, rhs self.add_symbol(lhs, Symbol.ConstVar, rhs)
else: else:
_internal_assert(self.device == 0,
"Single variable not supported in devices' side!\n" + \
"If you are using GPU, please allocate a 'local' spad " + \
"outside the bind body")
ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.symbols[lhs] = Symbol.BufferVar, ph self.add_symbol(lhs, Symbol.BufferVar, ph)
lhs = self.visit(lhs_) lhs = self.visit(lhs_)
if lhs is not None: if lhs is not None:
buf, args = lhs buf, args = lhs
...@@ -356,7 +396,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -356,7 +396,7 @@ class HybridParser(ast.NodeVisitor):
if node.orelse: if node.orelse:
else_body = visit_list_to_block(self.visit, node.orelse) else_body = visit_list_to_block(self.visit, node.orelse)
else: else:
else_body = util.make_nop() else_body = None
return _make.IfThenElse(cond, if_body, else_body) return _make.IfThenElse(cond, if_body, else_body)
...@@ -445,28 +485,31 @@ class HybridParser(ast.NodeVisitor): ...@@ -445,28 +485,31 @@ class HybridParser(ast.NodeVisitor):
bodies = [] bodies = []
for i in range(low, low + ext): for i in range(low, low + ext):
self.symbols[_name] = Symbol.ConstLoopVar, i self.add_symbol(_name, Symbol.ConstLoopVar, i)
body = visit_list_to_block(self.visit, node.body) body = visit_list_to_block(self.visit, node.body)
body = self.wrap_up_realize(node, body) body = self.wrap_up_realize(node, body)
bodies.append(body) bodies.append(body)
self.symbols.pop(_name)
return concat_list_to_block(bodies) return concat_list_to_block(bodies)
if iter_var is None: if iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop iterating function parse error!")
offset = iter_var = _api.var(_name) offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, 'int32')): if not _ir_pass.Equal(low, _api.const(0, 'int32')):
offset = iter_var + low offset = iter_var + low
self.symbols[_name] = Symbol.LoopVar, offset self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body) _body = visit_list_to_block(self.visit, node.body)
else: else:
_internal_assert(for_type is None, "The loop iterating function parse error!") _internal_assert(for_type is None, "The loop bind function parse error!")
self.symbols[_name] = Symbol.LoopVar, iter_var.var self.add_symbol(_name, Symbol.ThreadBind, iter_var)
self.device += 1
_body = visit_list_to_block(self.visit, node.body) _body = visit_list_to_block(self.visit, node.body)
self.device -= 1
_body = self.wrap_up_realize(node, _body) _body = self.wrap_up_realize(node, _body)
if for_type is None: if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) res = _body
else: else:
_internal_assert(not isinstance(for_type, tuple), \ _internal_assert(not isinstance(for_type, tuple), \
"Micro expansion should be handled before!") "Micro expansion should be handled before!")
......
...@@ -300,6 +300,7 @@ def test_bind(): ...@@ -300,6 +300,7 @@ def test_bind():
if not tvm.gpu(0).exist: if not tvm.gpu(0).exist:
print('[Warning] No GPU found! Skip bind test!') print('[Warning] No GPU found! Skip bind test!')
return return
@script @script
def vec_add(a, b): def vec_add(a, b):
c = output_tensor((1000, ), 'float32') c = output_tensor((1000, ), 'float32')
...@@ -326,23 +327,29 @@ def test_bind(): ...@@ -326,23 +327,29 @@ def test_bind():
func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda') func, ins, outs = run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda') run_and_check(func, ins, outs=outs, target='cuda')
# Test loop binds
@tvm.hybrid.script @tvm.hybrid.script
def goo(a, b): def foo(a):
c = output_tensor(a.shape, a.dtype) c = output_tensor((a.shape[0],), a.dtype)
len_b = len(b) total = allocate((1,), a.dtype, 'local')
for i in const_range(len_b * 2): len_i = a.shape[0]
if i < len_b: len_j = a.shape[1]
c[i] = a[i] + b[i] for i in bind('threadIdx.x', len_i):
else: total[0] = 0.
c[i - len_b] = a[i - len_b] + b[i - len_b] for k in const_range(len_j):
total[0] += a[i, k]
c[i] = total[0]
return c return c
a = tvm.placeholder((5, ), name='a', dtype='int32')
b = [1, 2, 3, 4, 5] a = tvm.placeholder((8, 4), 'float32')
c = goo(a, tvm.convert(b)) c = foo(a)
sch = tvm.create_schedule(c.op) s = tvm.create_schedule(c.op)
func, ins, outs = run_and_check(goo, [a, b], sch=sch, outs=[c]) ir = tvm.lower(s, [a, c], simple_mode=True)
run_and_check(func, ins, outs=outs) assert not isinstance(ir, tvm.stmt.AttrStmt)
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
def test_math_intrin(): def test_math_intrin():
@script @script
...@@ -455,6 +462,7 @@ def test_allocate(): ...@@ -455,6 +462,7 @@ def test_allocate():
a = tvm.placeholder((256, ), dtype='float32', name='a') a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b') b = tvm.placeholder((256, ), dtype='float32', name='b')
c = share_vec_add(a, b)
func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda') func, ins, outs = run_and_check(share_vec_add, [a, b], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda') run_and_check(func, ins, outs=outs, target='cuda')
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment