Commit a42d1e3c by Jian Weng Committed by Tianqi Chen

[Hybrid Script] Unify the symbol tables to one; support `tvm.container.Array` (#2366)

parent 151f550b
...@@ -52,7 +52,8 @@ The current parse interface looks like: ...@@ -52,7 +52,8 @@ The current parse interface looks like:
parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function
If we pass these tvm tensors to this function, it returns a op node: If we pass these tvm data structures, like ``Tensor``, ``Var``, ``Expr.*Imm``,
or ``tvm.container.Array``, to this function, it returns a op node:
.. code-block:: python .. code-block:: python
...@@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns a op node: ...@@ -60,12 +61,14 @@ If we pass these tvm tensors to this function, it returns a op node:
b = tvm.placeholder((99, ), name='b') b = tvm.placeholder((99, ), name='b')
c = outer_product(a, b, c) # return the output tensor(s) of the operator c = outer_product(a, b, c) # return the output tensor(s) of the operator
**Under construction, we are still deciding what kind of node should be returned.** You can use any methods that can be applied on a TVM ``OpNode``, like create_schedule, although
so far, the functionality of schedule is as limited as ``ExternOpNode``. At least, it can be built
to LLVM module.
Tuning Tuning
~~~~~~ ~~~~~~
**Under construction, not truly supported yet.** **Under construction, not supported yet.**
Follow up the example above, you can use some tvm like interfaces to tune the code: Follow up the example above, you can use some tvm like interfaces to tune the code:
...@@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize` ...@@ -86,6 +89,21 @@ Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize`
these **4** keywords to annotate the corresponding types of for loops. these **4** keywords to annotate the corresponding types of for loops.
The the usage is roughly the same as Python standard ``range``. The the usage is roughly the same as Python standard ``range``.
Besides all the loop types supported in Halide, ``const_range`` is supported for some specific conditions.
Sometimes, ``tvm.container.Array`` is desired to pass as an argument, but in TVM-HalideIR, there is no
such support that converts ``tvm.container.Array`` to an ``Expr``. Thus, a limited feature is supported.
Users can access containers by either constants or constants loops annotated.
.. code-block:: python
@tvm.hybrid.script
def foo(a, b): # b is a tvm.container.Array
c = output_tensor(a.shape, a.dtype)
for i in const_range(len(a)): # because you have b access, i should be explicitly annotated as const_range
c[i] = a[i] + b[i]
return c
Variables Variables
~~~~~~~~~ ~~~~~~~~~
...@@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration. ...@@ -111,14 +129,14 @@ It regards the first store of a variable as its declaration.
s += a[i, j] # do something with sum s += a[i, j] # do something with sum
b[i] = sum # you can still use sum in this level b[i] = sum # you can still use sum in this level
a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
b = (1, 2) # this has NOT been supported yet!
Attributes Attributes
~~~~~~~~~~ ~~~~~~~~~~
So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a So far, ONLY tensors' ``shape`` and ``dtype`` attribute are supported!
tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported. The ``shape`` atrribute is essentailly a tuple, so you MUST access it as an array.
Currently, only constant-indexed access is supported.
.. code-block:: python .. code-block:: python
...@@ -133,8 +151,11 @@ Conditional Statement and Expression ...@@ -133,8 +151,11 @@ Conditional Statement and Expression
.. code-block:: python .. code-block:: python
if condition: if condition1 and condition2 and condition3:
# do something # do something
else:
# do something else
# Select
a = b if condition else c a = b if condition else c
However, NO ``True`` and ``False`` keyword supported yet. However, NO ``True`` and ``False`` keyword supported yet.
...@@ -153,7 +174,9 @@ Array Allocation ...@@ -153,7 +174,9 @@ Array Allocation
**Under construction, this function will be supported later!** **Under construction, this function will be supported later!**
Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer. Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer.
The basic usage is roughly the same as a normal array. The basic usage is roughly the same as a normal ``numpy.array``, and you should access
high-dim array in ``a[i, j, k]`` fashion instead of ``a[i][j][k]``,
even for ``tvm.container.Array`` for compilation.
Thread Bind Thread Bind
...@@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this: ...@@ -170,5 +193,5 @@ You can also do loop-thread bind by writing code like this:
Keywords Keywords
~~~~~~~~ ~~~~~~~~
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind`` - For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount`` - Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
...@@ -15,12 +15,14 @@ LOOP_INTRIN = { ...@@ -15,12 +15,14 @@ LOOP_INTRIN = {
'range' : For.Serial, 'range' : For.Serial,
'unroll' : For.Unrolled, 'unroll' : For.Unrolled,
'parallel' : For.Parallel, 'parallel' : For.Parallel,
'vectorize': For.Vectorized, 'vectorize' : For.Vectorized,
'const_range' : (For.Unrolled, ),
} }
def _range(annotation, args): def _range(annotation, args):
"""Handling TVM loop types""" """Handling TVM loop types"""
n = len(args) n = args.__len__()
if n == 1: if n == 1:
low, ext = _api.const(0, dtype='int32'), args[0] low, ext = _api.const(0, dtype='int32'), args[0]
else: else:
...@@ -33,13 +35,13 @@ def _range(annotation, args): ...@@ -33,13 +35,13 @@ def _range(annotation, args):
return iter_var, low, ext, for_type return iter_var, low, ext, for_type
range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name range = unroll = vectorize = parallel = const_range = _range #pylint: disable=invalid-name
def bind(func_id, args): def bind(func_id, args):
"""Handling TVM thread binding""" """Handling TVM thread binding"""
_internal_assert(func_id == "bind", "This function cannot be directly invoked!") _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
_internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!") _internal_assert(args.__len__() == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(args[0], str), \ _internal_assert(isinstance(args[0], str), \
"A loop bind's first argument should be a string!") "A loop bind's first argument should be a string!")
iter_var = _api.thread_axis(args[0]) iter_var = _api.thread_axis(args[0])
...@@ -56,7 +58,7 @@ sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: dis ...@@ -56,7 +58,7 @@ sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: dis
def _min_max(func_id, args): def _min_max(func_id, args):
_internal_assert(len(args) == 2, "Max/Min function should have 2 elements") _internal_assert(args.__len__() == 2, "Max/Min function should have 2 elements")
return getattr(_make, func_id.title())(args[0], args[1]) return getattr(_make, func_id.title())(args[0], args[1])
...@@ -66,7 +68,7 @@ min = max = _min_max #pylint: disable=invalid-name ...@@ -66,7 +68,7 @@ min = max = _min_max #pylint: disable=invalid-name
def _allocate_tensor(func_id, args): def _allocate_tensor(func_id, args):
"""Handling TVM tensor allocation. """Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details.""" You may refer hybrid.intrin.allocate for more details."""
n = len(args) n = args.__len__()
_internal_assert(isinstance(_api.convert(args[0]), Array), \ _internal_assert(isinstance(_api.convert(args[0]), Array), \
"allocate's first argument should be a tuple of shape!") "allocate's first argument should be a tuple of shape!")
shape = args[0] shape = args[0]
...@@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args): ...@@ -89,4 +91,16 @@ def _allocate_tensor(func_id, args):
scope = 'global' if func_id != 'output_tensor' else 'output' scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope) return (shape, dtype, scope)
output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
def len(func_id, args):
"""Iterpret the len function"""
_internal_assert(args.__len__() == 1, "Only 1 argument is expected!")
_internal_assert(func_id == "len", "This function cannot be directly invoked!")
try:
return _api.convert(args[0].__len__())
except: #pylint: disable=bare-except
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
return _api.convert(args[0].shape[0])
...@@ -2,32 +2,19 @@ ...@@ -2,32 +2,19 @@
import numpy import numpy
class _range(object):
"""Base class of the loop ranges in hybrid script""" class bind(object): #pylint: disable=invalid-name
def __init__(self, a, b=None): """GPU bind software emulataion runtime."""
if b is None: def __init__(self, _, ext):
self.low = 0 self.ext = ext
self.ext = a
else:
self.low = a
self.ext = b
def __iter__(self): def __iter__(self):
i = 0 i = 0
while i < self.ext: while i < self.ext:
yield i + self.low yield i
i += 1 i += 1
class bind(_range): #pylint: disable=invalid-name
def __init__(self, tag, ext):
super(bind, self).__init__(ext)
self.tag = tag
unroll = vectorize = parallel = _range #pylint: disable=invalid-name
def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
"""Allocate a buffer with given shape """Allocate a buffer with given shape
...@@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar ...@@ -47,7 +34,6 @@ def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-ar
""" """
return numpy.zeros(shape).astype(dtype) return numpy.zeros(shape).astype(dtype)
output_tensor = allocate #pylint: disable=invalid-name
def popcount(x): def popcount(x):
""" """
...@@ -87,17 +73,19 @@ def sigmoid(x): ...@@ -87,17 +73,19 @@ def sigmoid(x):
HYBRID_GLOBALS = { HYBRID_GLOBALS = {
'unroll' : unroll, 'len' : len,
'vectorize' : vectorize, 'unroll' : range,
'parallel' : parallel, 'vectorize' : range,
'allocate' : allocate, 'parallel' : range,
'output_tensor': output_tensor, 'const_range' : range,
'bind' : bind, 'bind' : bind,
'allocate' : allocate,
'output_tensor': allocate,
'sqrt' : numpy.sqrt, 'sqrt' : numpy.sqrt,
'log' : numpy.log, 'log' : numpy.log,
'tanh' : numpy.tanh, 'tanh' : numpy.tanh,
'power' : numpy.power, 'power' : numpy.power,
'exp' : numpy.exp, 'exp' : numpy.exp,
'sigmoid' : sigmoid, 'sigmoid' : sigmoid,
'popcount' : popcount 'popcount' : popcount,
} }
...@@ -4,7 +4,10 @@ import ast ...@@ -4,7 +4,10 @@ import ast
import operator import operator
import logging import logging
import sys import sys
from numbers import Integral import types
import numbers
from enum import Enum
from .util import _internal_assert from .util import _internal_assert
from . import calls from . import calls
...@@ -12,18 +15,15 @@ from . import util ...@@ -12,18 +15,15 @@ from . import util
from .var_decl import determine_variable_usage from .var_decl import determine_variable_usage
from ..api import all as _all from ..api import all as _all
from ..api import any as _any from ..api import any as _any
from ..container import Array
from ..tensor import Tensor, Operation from ..tensor import Tensor, Operation
from .. import expr as _expr from .. import expr as _expr
from .. import make as _make from .. import make as _make
from .. import api as _api from .. import api as _api
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block""" def pack_list_to_block(lst):
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
if not lst:
return util.make_nop()
if len(lst) == 1: if len(lst) == 1:
return lst[0] return lst[0]
body = lst[0] body = lst[0]
...@@ -32,6 +32,29 @@ def list_to_block(visit, lst): ...@@ -32,6 +32,29 @@ def list_to_block(visit, lst):
return body return body
def visit_list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
if not lst:
return util.make_nop()
return pack_list_to_block(lst)
class Symbol(Enum):
"""Enumerates types in the symbol table"""
Callable = 0
Input = 1
OutputBuffer = 2
GlobalBuffer = 3
LocalBuffer = 4
SharedBuffer = 5
ConstVar = 6
BufferVar = 7
LoopVar = 8
ConstLoopVar = 9
class HybridParser(ast.NodeVisitor): class HybridParser(ast.NodeVisitor):
"""Python AST visitor pass which finally lowers it to HalideIR""" """Python AST visitor pass which finally lowers it to HalideIR"""
...@@ -82,77 +105,55 @@ class HybridParser(ast.NodeVisitor): ...@@ -82,77 +105,55 @@ class HybridParser(ast.NodeVisitor):
""" """
self.args = list(args) self.args = list(args)
self.usage = usage.copy() self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.alloc_buffers = {} # Buffers formed by explicit allocate instructions self.symbols = {} # Symbol table
self.loops_above = {} # State variable that indicates loop levels above the current node for k, v in symbols.items():
self.variables = {} # The status of defined variables if isinstance(v, types.FunctionType):
self.symbols[k] = Symbol.Callable, v
self.func_name = func_name # The name of the function to be lowered self.func_name = func_name # The name of the function to be lowered
self.outputs = [] # Output tensors' name self.outputs = [] # Output tensors' name
self.side_effect = set() # Tensors with side effects self.side_effect = set() # Tensors with side effects
self.parsed_body = None # The parsed HalideIR body self.parsed_body = None # The parsed HalideIR body
self.returned = False # If this function has a valid return self.returned = False # If this function has a valid return
self.symbols = symbols # The global context
def wrap_up_realize(self, node, body): def wrap_up_realize(self, node, body):
"""Wrap up all the variables which will no longer be used""" """Wrap up all the variables which will no longer be used"""
pop_buf = [] to_pop = []
pop_var = []
for key, val in self.usage.items(): for key, val in self.usage.items():
_, level, _ = val _, level, _ = val
if level != node: if level != node:
continue continue
if key in self._args.keys(): _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
continue
if key in self.alloc_buffers.keys(): ty, entry = self.symbols[key] #pylint: disable=invalid-name
_buf, _scope = self.alloc_buffers[key] if ty in [Symbol.Input, Symbol.OutputBuffer]:
if _scope == 'output':
continue continue
pop_buf.append(key) elif 'Buffer' in ty.name:
_buf = entry
_scope = ty.name[:-6].lower() if ty is not Symbol.BufferVar else 'global'
to_pop.append(key)
else: else:
_internal_assert(key in self.variables.keys(),
"Key should be either in one of args, buffers, and vars")
if not isinstance(self.variables[key], tuple):
continue continue
_buf, _scope = self.variables[key]
pop_var.append(key)
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype _dtype = _buf.dtype
_true = _api.convert(True) _true = _api.convert(True)
body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body) body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body) body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
for elem in pop_buf: for elem in to_pop:
self.alloc_buffers.pop(elem) self.symbols.pop(elem)
for elem in pop_var:
self.variables.pop(elem)
return body
return body
def _get_buffer_from_id(self, s, for_provide=False):
_internal_assert((s in self._args.keys()) + (s in self.alloc_buffers.keys()) == 1,
"This %s is expected to be in either \
argument list or allocated buffer!" % s)
if s in self._args.keys():
if for_provide:
self.side_effect.add(self._args[s])
return self._args[s]
return self.alloc_buffers[s][0]
def _const(self, value, dtype=None):
if dtype is None:
if isinstance(value, bool):
dtype = "bool"
elif isinstance(value, Integral):
dtype = "int32"
else:
dtype = "float32"
return _api.const(value, dtype)
#pylint: disable=invalid-name, missing-docstring #pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node): def visit_Module(self, node):
_internal_assert(len(node.body) == 1, \ _internal_assert(len(node.body) == 1, \
"Only one-function source code can be fed to this parser!") "Only one-function source code will be fed to this parser!")
return self.visit(node.body[0]) return self.visit(node.body[0])
...@@ -164,8 +165,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -164,8 +165,8 @@ class HybridParser(ast.NodeVisitor):
self.func_name = node.name self.func_name = node.name
for idx, arg in enumerate(node.args.args): for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx] self.symbols[getattr(arg, _attr)] = (Symbol.Input, self.args[idx])
res = list_to_block(self.visit, node.body) res = visit_list_to_block(self.visit, node.body)
res = self.wrap_up_realize(node, res) res = self.wrap_up_realize(node, res)
return res return res
...@@ -176,25 +177,31 @@ class HybridParser(ast.NodeVisitor): ...@@ -176,25 +177,31 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
if name in self.loops_above.keys(): ty, entry = self.symbols[name]
return self.loops_above[name] _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
elif name in self.variables.keys(): if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
res = self.variables[name] return entry
if isinstance(res, tuple): elif ty is Symbol.ConstVar:
buf = res[0] return entry if isinstance(node.ctx, ast.Load) else None
if isinstance(node.ctx, ast.Load): elif ty is Symbol.BufferVar:
return _make.Call(buf.dtype, buf.name, [self._const(0)], \
_expr.Call.Halide, buf.op, buf.value_index)
return buf, [self._const(0)]
if isinstance(node.ctx, ast.Load): if isinstance(node.ctx, ast.Load):
return res return _make.Call(entry.dtype, entry.name, [_api.const(0, 'int32')], \
return None _expr.Call.Halide, entry.op, entry.value_index)
buf = self._get_buffer_from_id(name) return entry, [_api.const(0, 'int32')]
return buf # Do I need any assertion here?
return entry
def visit_Num(self, node): def visit_Num(self, node):
return self._const(node.n) if isinstance(node.n, numbers.Integral):
dtype = "int32"
elif isinstance(node.n, float):
dtype = "float32"
else:
_internal_assert(isinstance(node.n, bool),
"The data type should be one of (int, float, bool)")
dtype = "bool"
return _api.const(node.n, dtype)
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
...@@ -204,7 +211,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -204,7 +211,7 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!") _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
buf, args = buf buf, args = buf
else: else:
args = [self._const(0)] args = [_api.const(0, 'int32')]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!") _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index) read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
...@@ -222,7 +229,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -222,7 +229,7 @@ class HybridParser(ast.NodeVisitor):
for i in range(rhs.num_outputs): for i in range(rhs.num_outputs):
_internal_assert(isinstance(node.targets[i], ast.Name), _internal_assert(isinstance(node.targets[i], ast.Name),
"You should bind a pure name to the tensors") "You should bind a pure name to the tensors")
self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global') self.symbols[node.targets[i].id] = Symbol.GlobalBuffer, rhs.output(i)
rmap[rhs.outputs[i].op] = rhs.output(i) rmap[rhs.outputs[i].op] = rhs.output(i)
return util.replace_io(rhs.body, rmap) return util.replace_io(rhs.body, rmap)
...@@ -234,25 +241,26 @@ class HybridParser(ast.NodeVisitor): ...@@ -234,25 +241,26 @@ class HybridParser(ast.NodeVisitor):
#TODO: support defined intermediate buffer later #TODO: support defined intermediate buffer later
lhs_ = lhs lhs_ = lhs
lhs = lhs.id lhs = lhs.id
_internal_assert(lhs not in self.loops_above.keys(), \ if lhs in self.symbols.keys():
ty, _ = self.symbols[lhs]
_internal_assert(ty != Symbol.LoopVar, \
"Loop variable cannot be overwritten!") "Loop variable cannot be overwritten!")
decl, _, rw = self.usage[lhs] decl, _, rw = self.usage[lhs]
if decl == lhs_: if decl == lhs_:
_internal_assert(lhs not in self.variables.keys() and _internal_assert(lhs not in self.symbols.keys(),
lhs not in self.alloc_buffers.keys(), \
"This value should not be defined before this point!") "This value should not be defined before this point!")
if isinstance(rhs, tuple): if isinstance(rhs, tuple):
shape, dtype, scope = rhs shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs) ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, scope) self.symbols[lhs] = getattr(Symbol, scope.title() + "Buffer"), ph
if scope == 'output': if scope == 'output':
self.outputs.append(lhs) self.outputs.append(lhs)
return util.make_nop() return util.make_nop()
if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw: if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
self.variables[lhs] = rhs self.symbols[lhs] = Symbol.ConstVar, rhs
else: else:
ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.variables[lhs] = (ph, 'global') self.symbols[lhs] = Symbol.BufferVar, ph
lhs = self.visit(lhs_) lhs = self.visit(lhs_)
if lhs is not None: if lhs is not None:
buf, args = lhs buf, args = lhs
...@@ -275,17 +283,30 @@ class HybridParser(ast.NodeVisitor): ...@@ -275,17 +283,30 @@ class HybridParser(ast.NodeVisitor):
def visit_Attribute(self, node): def visit_Attribute(self, node):
_internal_assert(isinstance(node.value, ast.Name), \ _internal_assert(isinstance(node.value, ast.Name), \
"For atrribute access, only both names are supported so far!") "For atrribute access, only both names are supported so far!")
buf = self._get_buffer_from_id(node.value.id) buf = self.visit(node.value)
return getattr(buf, node.attr) return getattr(buf, node.attr)
def visit_Subscript(self, node): def visit_Subscript(self, node):
args = self.visit(node.slice) args = self.visit(node.slice)
if isinstance(node.value, ast.Name): if isinstance(node.value, ast.Name):
buf = self.visit(node.value) buf = self.visit(node.value)
if isinstance(buf, Array):
for i in args:
if isinstance(i, numbers.Integral):
buf = buf[i]
else:
_internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
"All indices are supposed to be constants")
buf = buf[i.value]
return buf
if isinstance(node.ctx, ast.Load): if isinstance(node.ctx, ast.Load):
return _make.Call(buf.dtype, buf.name, args, \ return _make.Call(buf.dtype, buf.name, args, \
_expr.Call.Halide, buf.op, buf.value_index) _expr.Call.Halide, buf.op, buf.value_index)
return buf, args return buf, args
shape = self.visit(node.value) shape = self.visit(node.value)
...@@ -308,14 +329,14 @@ class HybridParser(ast.NodeVisitor): ...@@ -308,14 +329,14 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!") _internal_assert(isinstance(context, ast.Call), "The object must be a Python func call!")
_internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!") _internal_assert(isinstance(option, ast.Name), "The object after 'as' must be an id!")
self.annotation[option.id] = context.func.id self.annotation[option.id] = context.func.id
return list_to_block(self.visit, node.body) return visit_list_to_block(self.visit, node.body)
def visit_If(self, node): def visit_If(self, node):
cond = self.visit(node.test) cond = self.visit(node.test)
if_body = list_to_block(self.visit, node.body) if_body = visit_list_to_block(self.visit, node.body)
if node.orelse: if node.orelse:
else_body = list_to_block(self.visit, node.orelse) else_body = visit_list_to_block(self.visit, node.orelse)
else: else:
else_body = util.make_nop() else_body = util.make_nop()
return _make.IfThenElse(cond, if_body, else_body) return _make.IfThenElse(cond, if_body, else_body)
...@@ -376,7 +397,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -376,7 +397,10 @@ class HybridParser(ast.NodeVisitor):
except AttributeError: except AttributeError:
_internal_assert(func_id in self.symbols.keys(), \ _internal_assert(func_id in self.symbols.keys(), \
"The function called is not in the context either!") "The function called is not in the context either!")
outs = self.symbols[func_id](*args) ty, entry = self.symbols[func_id]
_internal_assert(ty is Symbol.Callable, \
"Are you sure what you call is a function?!")
outs = entry(*args)
op = outs.op if isinstance(outs, Tensor) else outs[0].op op = outs.op if isinstance(outs, Tensor) else outs[0].op
return op return op
...@@ -385,41 +409,66 @@ class HybridParser(ast.NodeVisitor): ...@@ -385,41 +409,66 @@ class HybridParser(ast.NodeVisitor):
iter_var, low, ext, for_type = self.visit(node.iter) iter_var, low, ext, for_type = self.visit(node.iter)
_internal_assert(isinstance(node.target, ast.Name), \ _internal_assert(isinstance(node.target, ast.Name), \
"The loop iterator should be a variable!") "The loop iterator should be a variable!")
_name = node.target.id _name = node.target.id
if iter_var is None:
if isinstance(for_type, tuple):
low = _ir_pass.Simplify(low)
ext = _ir_pass.Simplify(ext)
_internal_assert(isinstance(low, _expr.ConstExpr) and
isinstance(ext, _expr.ConstExpr), \
"Const range should start from a const" + \
"and iterate const times")
low, ext = low.value, ext.value
if ext > 114514:
logging.log(logging.CRITICAL, \
'[Warning] Are you sure to unroll a large loop in Python?')
bodies = []
for i in range(low, low + ext):
self.symbols[_name] = Symbol.ConstLoopVar, i
bodies.append(visit_list_to_block(self.visit, node.body))
return pack_list_to_block(bodies)
elif iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop bind function parse error!")
offset = iter_var = _api.var(_name) offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, self._const(0)): if not _ir_pass.Equal(low, _api.const(0, 'int32')):
offset = iter_var + low offset = iter_var + low
self.loops_above[_name] = offset self.symbols[_name] = Symbol.LoopVar, offset
_body = visit_list_to_block(self.visit, node.body)
else: else:
_internal_assert(for_type is None, "The loop iterating function parse error!") _internal_assert(for_type is None, "The loop iterating function parse error!")
self.loops_above[_name] = iter_var.var self.symbols[_name] = Symbol.LoopVar, iter_var.var
_body = list_to_block(self.visit, node.body) _body = visit_list_to_block(self.visit, node.body)
_body = self.wrap_up_realize(node, _body) _body = self.wrap_up_realize(node, _body)
if for_type is None: if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else: elif not isinstance(for_type, tuple):
res = _make.For(iter_var, self._const(0), ext, for_type, 0, _body) res = _make.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
self.loops_above.pop(_name) self.symbols.pop(_name)
return res return res
def visit_Return(self, node): def visit_Return(self, node):
_internal_assert(not self.loops_above, "Return should not be in a loop body!") _internal_assert(all(ty != Symbol.LoopVar for ty, _ in self.symbols.values()), \
"Return should not be in a loop body!")
ids = [] ids = []
if isinstance(node.value, ast.Name): if isinstance(node.value, ast.Name):
ids.append(node.value.id) ids = [node.value.id]
else: else:
_internal_assert(isinstance(node.value, ast.Tuple), \ _internal_assert(isinstance(node.value, ast.Tuple), \
"You should return either a single tensor or a tuple") "You should return either a single tensor or a tuple")
for i in node.value.elts: _internal_assert(all(isinstance(i, ast.Name) for i in node.value.elts), \
_internal_assert(isinstance(i, ast.Name), "What do you return?") "What do you return?")
ids.append(i.id) ids = [i.id for i in node.value.elts]
_internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples") _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
if len(ids) < len(self.outputs): if len(ids) < len(self.outputs):
logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!') logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
self.outputs = [self.alloc_buffers[i][0] for i in ids] self.outputs = [self.symbols[i][1] for i in ids]
self.returned = True self.returned = True
return util.make_nop() return util.make_nop()
......
...@@ -11,12 +11,13 @@ from .. import api as _api ...@@ -11,12 +11,13 @@ from .. import api as _api
from .. import make as _make from .. import make as _make
from .. import expr as _expr from .. import expr as _expr
from .. import stmt as _stmt from .. import stmt as _stmt
from ..container import Array
from ..tensor import Tensor from ..tensor import Tensor
#pylint: disable=invalid-name #pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr) tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
def _internal_assert(cond, err): def _internal_assert(cond, err):
......
...@@ -13,7 +13,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -13,7 +13,7 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
ctx = tvm.context(target, 0) ctx = tvm.context(target, 0)
op = None op = None
outs = func(*args) outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
op = outs[0].op if isinstance(outs, list) else outs.op op = outs[0].op if isinstance(outs, list) else outs.op
emu_args = [] emu_args = []
...@@ -23,13 +23,18 @@ def run_and_check(func, args, var_dict={}, target='llvm'): ...@@ -23,13 +23,18 @@ def run_and_check(func, args, var_dict={}, target='llvm'):
shape = [tvm_val_2_py_val(j) for j in i.shape] shape = [tvm_val_2_py_val(j) for j in i.shape]
emu_args.append(numpy.random.randn(*shape).astype(i.dtype)) emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
nd_args.append(tvm.nd.array(emu_args[-1], ctx)) nd_args.append(tvm.nd.array(emu_args[-1], ctx))
else: elif isinstance(i, tvm.expr.Var):
assert isinstance(i, tvm.expr.Var)
emu_args.append(tvm_val_2_py_val(i)) emu_args.append(tvm_val_2_py_val(i))
nd_args.append(emu_args[-1]) nd_args.append(emu_args[-1])
else:
assert isinstance(i, list)
emu_args.append(numpy.array(i))
sch = tvm.create_schedule(op) sch = tvm.create_schedule(op)
module = tvm.build(sch, args + (outs if isinstance(outs, list) else [outs]), target=target) module = tvm.build(sch,
[i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
(outs if isinstance(outs, list) else [outs]),
target=target)
assert module assert module
out_tensors = [] out_tensors = []
...@@ -192,20 +197,20 @@ def test_fanout(): ...@@ -192,20 +197,20 @@ def test_fanout():
def test_looptype(): def test_looptype():
@script @script
def looptype(a, b, c): def looptype(a, b, c):
d = output_tensor((8, ), 'int32') d = output_tensor((16, ), 'int32')
e = output_tensor((8, ), 'int32') e = output_tensor((16, ), 'int32')
f = output_tensor((8, ), 'int32') f = output_tensor((16, ), 'int32')
for i in parallel(8): for i in parallel(16):
d[i] = a[i] d[i] = a[i]
for j in vectorize(8): for j in vectorize(16):
e[j] = b[j] e[j] = b[j]
for k in unroll(8): for k in unroll(16):
f[k] = c[k] f[k] = c[k]
return d, e, f return d, e, f
a = tvm.placeholder((8, ), name='a', dtype='int32') a = tvm.placeholder((16, ), name='a', dtype='int32')
b = tvm.placeholder((8, ), name='b', dtype='int32') b = tvm.placeholder((16, ), name='b', dtype='int32')
c = tvm.placeholder((8, ), name='c', dtype='int32') c = tvm.placeholder((16, ), name='c', dtype='int32')
try: try:
d, e, f = looptype(a, b, c) d, e, f = looptype(a, b, c)
ir = d.op.body ir = d.op.body
...@@ -509,9 +514,9 @@ def test_value_index(): ...@@ -509,9 +514,9 @@ def test_value_index():
def test_func_call(): def test_func_call():
@tvm.hybrid.script @tvm.hybrid.script
def foo(a, b): def foo(a, b):
for i in range(10): for i in range(len(a)):
a[i] = i + 1.0 a[i] = i + 1.0
for i in range(10): for i in range(len(a)):
b[i] = i + 1.0 b[i] = i + 1.0
c = outer_product(10, 10, a, b) c = outer_product(10, 10, a, b)
d = output_tensor(c.shape, c.dtype) d = output_tensor(c.shape, c.dtype)
...@@ -538,6 +543,26 @@ def test_bool(): ...@@ -538,6 +543,26 @@ def test_bool():
a = tvm.placeholder((10, ), name='a') a = tvm.placeholder((10, ), name='a')
run_and_check(foo, [a]) run_and_check(foo, [a])
def test_const_range():
@tvm.hybrid.script
def foo(a, b):
c = output_tensor(a.shape, a.dtype)
d = output_tensor(a.shape, a.dtype)
for i in const_range(2):
for j in const_range(5):
c[i, j] = a[i, j] + b[i, j]
for i in const_range(len(b)):
for j in const_range(len(b[0])):
d[i, j] = a[i, j] + b[i, j]
return c, d
a = tvm.placeholder((2, 5), name='a', dtype='int32')
b = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]]
run_and_check(foo, [a, b])
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
test_fanout() test_fanout()
...@@ -553,5 +578,6 @@ if __name__ == "__main__": ...@@ -553,5 +578,6 @@ if __name__ == "__main__":
test_value_index() test_value_index()
test_func_call() test_func_call()
test_bool() test_bool()
test_const_range()
# TODO: # TODO:
# test_inplace() # test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment