Commit 838e7181 by Jian Weng Committed by Tianqi Chen

[Hybrid Script] Inter-function call supported! (#2287)

parent 001ab525
...@@ -24,17 +24,15 @@ def script(pyfunc): ...@@ -24,17 +24,15 @@ def script(pyfunc):
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args): if _is_tvm_arg_types(args):
src = _pruned_source(func) src = _pruned_source(func)
parser = parse_python(src, args) parser = parse_python(src, func.__globals__, args)
input_tensors = [] input_tensors = []
for i in args: for i in args:
if isinstance(i, Tensor): if isinstance(i, Tensor):
input_tensors.append(i) input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors, op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body) parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))] res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func) intersect = _enter_hybrid_runtime(func)
......
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
#pylint: disable=redefined-builtin
# Maps the name of each loop intrinsic usable in hybrid script to the
# For loop type constant that _range() hands back for it; the lookup
# key is the name under which the intrinsic was called.
LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
}
def _range(annotation, args):
    """Turn a loop intrinsic call into the components of a TVM loop.

    Parameters
    ----------
    annotation : str
        The name under which the intrinsic was called. It is the key used
        to select the loop type from LOOP_INTRIN.

    args : list
        The evaluated call arguments, either [extent] or [low, high].

    Returns
    -------
    tuple
        (iter_var, low, ext, for_type). iter_var is always None for these
        intrinsics, low is the lower bound, ext is the extent and for_type
        is the LOOP_INTRIN entry selected by annotation.
    """
    zero = _api.const(0, dtype='int32')
    if len(args) == 1:
        # Single argument form: the loop starts at zero.
        begin, extent = zero, args[0]
    else:
        _internal_assert(len(args) == 2, "A loop intrinsic should only have 1 or 2 arguments!")
        begin, extent = args[0], args[1]
        # The second argument is an upper bound; convert it to an extent
        # unless the lower bound is already known to be zero.
        if not ir_pass.Equal(begin, zero):
            extent = extent - begin
    return None, begin, extent, LOOP_INTRIN[annotation]


range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
def bind(func_id, args):
    """Turn a `bind` call into the components of a thread-bound loop.

    Parameters
    ----------
    func_id : str
        The name under which this handler was reached; it must be "bind".

    args : list
        Two evaluated call arguments: the thread axis tag as a str and the
        loop extent.

    Returns
    -------
    tuple
        (iter_var, low, ext, for_type). iter_var is the thread axis created
        from the tag, low is a zero constant, ext is the second argument and
        for_type is always None.
    """
    _internal_assert(func_id == "bind", "This function cannot be directly invoked!")
    _internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
    thread_tag = args[0]
    _internal_assert(isinstance(thread_tag, str), \
                     "A loop bind's first argument should be a string!")
    # No loop type is returned here: the fourth element is None.
    return _api.thread_axis(thread_tag), _api.const(0), args[1], None
def _math_intrin(func_id, args):
    """Forward a math call to the function named func_id in the intrin module.

    Parameters
    ----------
    func_id : str
        Name of the attribute to look up on the intrin module.

    args : list
        The evaluated call arguments, passed through positionally.

    Returns
    -------
    The value returned by the looked-up intrin function.
    """
    from .. import intrin
    handler = getattr(intrin, func_id)
    return handler(*args)


sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
def _min_max(func_id, args):
    """Build a two-operand min or max node.

    Parameters
    ----------
    func_id : str
        The called name; its title-cased form selects the constructor
        looked up on the make module.

    args : list
        Exactly two evaluated operands.

    Returns
    -------
    The node produced by the selected constructor.
    """
    _internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
    node_maker = getattr(_make, func_id.title())
    return node_maker(args[0], args[1])


min = max = _min_max #pylint: disable=invalid-name
def _allocate_tensor(func_id, args):
    """Validate the arguments of a tensor allocation call.

    Serves both `allocate` and `output_tensor`.

    Parameters
    ----------
    func_id : str
        The called name; 'output_tensor' selects the output behaviour.

    args : list
        The evaluated call arguments: the shape, then optionally the data
        type string, then optionally the scope string.

    Returns
    -------
    tuple
        (shape, dtype, scope). dtype falls back to 'float32' when omitted.
        scope falls back to 'output' for output_tensor and to 'global'
        otherwise.
    """
    argc = len(args)
    shape = args[0]
    _internal_assert(isinstance(_api.convert(shape), Array), \
                     "allocate's first argument should be a tuple of shape!")
    for dim in shape:
        _internal_assert(isinstance(dim, _expr.Expr), "The shape should be an expression")

    dtype = 'float32'
    if argc > 1:
        _internal_assert(isinstance(args[1], str),
                         "The data type should be an str")
        _internal_assert(args[1].startswith(('int', 'float')), \
                         "The data type should be either int or float!")
        dtype = args[1]

    is_output = func_id == 'output_tensor'
    if argc > 2:
        _internal_assert(isinstance(args[2], str), \
                         "The data scope should be an string")
        # An explicit scope is only accepted for non-output allocations.
        _internal_assert(not is_output, "Output tensor cannot specify scope")
        scope = args[2]
    else:
        scope = 'output' if is_output else 'global'
    return (shape, dtype, scope)


output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
"""Intrinsics of TVM-Python Hybrid Script for Python runtime""" """Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""
import numpy import numpy
from ..stmt import For
class _range(object): class _range(object):
"""Base class of the loop ranges in hybrid script""" """Base class of the loop ranges in hybrid script"""
...@@ -102,15 +101,3 @@ HYBRID_GLOBALS = { ...@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
'sigmoid' : sigmoid, 'sigmoid' : sigmoid,
'popcount' : popcount 'popcount' : popcount
} }
LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
'bind' : None
}
MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
...@@ -4,24 +4,24 @@ import ast ...@@ -4,24 +4,24 @@ import ast
import operator import operator
import logging import logging
import sys import sys
from .util import make_nop, halide_imm_types, is_docstring, _internal_assert from .util import _internal_assert
from .intrin import LOOP_INTRIN, MATH_INTRIN from . import calls
from . import util
from .var_decl import determine_variable_usage from .var_decl import determine_variable_usage
from ..api import thread_axis
from ..api import all as _all from ..api import all as _all
from ..api import any as _any from ..api import any as _any
from ..tensor import Tensor, Operation
from .. import expr as _expr from .. import expr as _expr
from .. import make as _make from .. import make as _make
from .. import intrin
from .. import api as _api from .. import api as _api
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
def list_to_block(visit, lst): def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block""" """Convert a list of Python IR nodes to HalideIR Block"""
lst = [visit(stmt) for stmt in lst if not is_docstring(stmt)] lst = [visit(stmt) for stmt in lst if not util.is_docstring(stmt)]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, util.make_nop())]
if not lst: if not lst:
return make_nop() return util.make_nop()
if len(lst) == 1: if len(lst) == 1:
return lst[0] return lst[0]
body = lst[0] body = lst[0]
...@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -62,7 +62,7 @@ class HybridParser(ast.NodeVisitor):
} }
def __init__(self, args, usage, func_name=None): def __init__(self, args, usage, symbols, func_name=None):
""" """
Parameters Parameters
---------- ----------
...@@ -81,32 +81,49 @@ class HybridParser(ast.NodeVisitor): ...@@ -81,32 +81,49 @@ class HybridParser(ast.NodeVisitor):
self.args = list(args) self.args = list(args)
self.usage = usage.copy() self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.alloc_buffers = {} # Buffers formed by allocate instructions self.alloc_buffers = {} # Buffers formed by explicit allocate instructions
self.loops_above = {} # State variable that indicates loop levels above the current node self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage self.variables = {} # The status of defined variables
self.func_name = func_name # The name of the function to be lowered self.func_name = func_name # The name of the function to be lowered
self.outputs = [] # Output tensors' name self.outputs = [] # Output tensors' name
self.side_effect = set() # Tensors with side effects self.side_effect = set() # Tensors with side effects
self.parsed_body = None # The parsed HalideIR body self.parsed_body = None # The parsed HalideIR body
self.returned = False self.returned = False # If this function has a valid return
self.symbols = symbols # The global context
def wrap_up_realize(self, node, body): def wrap_up_realize(self, node, body):
"""Wrap up all the variables which will no longer be used""" """Wrap up all the variables which will no longer be used"""
pop_buf = []
pop_var = []
for key, val in self.usage.items(): for key, val in self.usage.items():
if key in self.var_consts.keys():
continue
_, level, _ = val _, level, _ = val
if level == node: if level != node:
continue
if key in self._args.keys(): if key in self._args.keys():
continue continue
else: if key in self.alloc_buffers.keys():
_buf, _scope = self.alloc_buffers[key] _buf, _scope = self.alloc_buffers[key]
if _scope == 'output':
continue
pop_buf.append(key)
else:
_internal_assert(key in self.variables.keys(),
"Key should be either in one of args, buffers, and vars")
if not isinstance(self.variables[key], tuple):
continue
_buf, _scope = self.variables[key]
pop_var.append(key)
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape] _domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype _dtype = _buf.dtype
_true = _api.convert(True) _true = _api.convert(True)
body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body) body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body) body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
for elem in pop_buf:
self.alloc_buffers.pop(elem)
for elem in pop_var:
self.variables.pop(elem)
return body return body
...@@ -121,7 +138,6 @@ class HybridParser(ast.NodeVisitor): ...@@ -121,7 +138,6 @@ class HybridParser(ast.NodeVisitor):
return self.alloc_buffers[s][0] return self.alloc_buffers[s][0]
#pylint: disable=invalid-name, missing-docstring #pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node): def visit_Module(self, node):
_internal_assert(len(node.body) == 1, \ _internal_assert(len(node.body) == 1, \
...@@ -133,13 +149,13 @@ class HybridParser(ast.NodeVisitor): ...@@ -133,13 +149,13 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(len(node.args.args) == len(self.args), \ _internal_assert(len(node.args.args) == len(self.args), \
"The number of arguments passed to the \ "The number of arguments passed to the \
function should be the same as it is defined!") function should be the same as it is defined!")
if self.func_name is None:
self.func_name = node.name
for idx, arg in enumerate(node.args.args): for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible _attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx] self._args[getattr(arg, _attr)] = self.args[idx]
res = list_to_block(self.visit, node.body) res = list_to_block(self.visit, node.body)
res = self.wrap_up_realize(node, res) res = self.wrap_up_realize(node, res)
if self.func_name is None:
self.func_name = node.name
return res return res
...@@ -148,23 +164,22 @@ class HybridParser(ast.NodeVisitor): ...@@ -148,23 +164,22 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
_id = node.id name = node.id
if _id in self._args.keys() and isinstance(self._args[_id], (_expr.Var, _expr.ConstExpr)): if name in self.loops_above.keys():
return self._args[_id] return self.loops_above[name]
elif _id in self.loops_above.keys(): elif name in self.variables.keys():
return self.loops_above[_id] res = self.variables[name]
_internal_assert(_id not in self._args.keys(), \ if isinstance(res, tuple):
"This id %s should be handled in visit_Subscript!" % _id) buf = res[0]
_internal_assert(_id in self.usage.keys(), \ if isinstance(node.ctx, ast.Load):
"This id %s is expected to be a defined variable!" % _id) return _make.Call(buf.dtype, buf.name, [_api.const(0)], \
# Buffer _expr.Call.Halide, buf.op, buf.value_index)
if _id in self.alloc_buffers.keys(): return buf, [_api.const(0)]
_buf, _ = self.alloc_buffers[_id] if isinstance(node.ctx, ast.Load):
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) return res
# Compilation time constant return None
_internal_assert(_id in self.var_consts.keys(), buf = self._get_buffer_from_id(name)
"This id %s is expected to a compilation time constant!" % _id) return buf
return self.var_consts[_id]
def visit_Num(self, node): def visit_Num(self, node):
...@@ -172,18 +187,36 @@ class HybridParser(ast.NodeVisitor): ...@@ -172,18 +187,36 @@ class HybridParser(ast.NodeVisitor):
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
lhs = self.visit(node.target) buf = self.visit(node.target)
rhs = self.visit(node.value) rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs) if isinstance(buf, tuple):
_internal_assert(isinstance(lhs, _expr.Call), \ _internal_assert(len(buf) == 2, "LHS is supposed to be (buf, args)!")
"The LHS of an AugAssign is supposed to be a call!") buf, args = buf
return _make.Provide(lhs.func, 0, rhs, lhs.args) else:
args = [_api.const(0)]
_internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!")
read = _make.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index)
value = HybridParser._binop_maker[type(node.op)](read, rhs)
return _make.Provide(buf.op, 0, value, args)
def visit_Assign(self, node): def visit_Assign(self, node):
rhs = self.visit(node.value)
if isinstance(rhs, Operation):
rmap = {}
_internal_assert(len(node.targets) == rhs.num_outputs, \
"Unable to detuple the outs to targets")
for i in range(rhs.num_outputs):
_internal_assert(isinstance(node.targets[i], ast.Name),
"You should bind a pure name to the tensors")
self.alloc_buffers[node.targets[i].id] = (rhs.output(i), 'global')
rmap[rhs.outputs[i].op] = rhs.output(i)
return util.replace_io(rhs.body, rmap)
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0] lhs = node.targets[0]
rhs = self.visit(node.value)
if isinstance(rhs, _expr.Expr): if isinstance(rhs, _expr.Expr):
rhs = _ir_pass.Simplify(rhs) rhs = _ir_pass.Simplify(rhs)
if isinstance(lhs, ast.Name): if isinstance(lhs, ast.Name):
...@@ -194,65 +227,63 @@ class HybridParser(ast.NodeVisitor): ...@@ -194,65 +227,63 @@ class HybridParser(ast.NodeVisitor):
"Loop variable cannot be overwritten!") "Loop variable cannot be overwritten!")
decl, _, rw = self.usage[lhs] decl, _, rw = self.usage[lhs]
if decl == lhs_: if decl == lhs_:
_internal_assert(lhs not in self.var_consts.keys(), \ _internal_assert(lhs not in self.variables.keys() and
"A constant cannot be overwritten!") lhs not in self.alloc_buffers.keys(), \
_internal_assert(lhs not in self.alloc_buffers.keys(), \
"This value should not be defined before this point!") "This value should not be defined before this point!")
if isinstance(rhs, tuple): if isinstance(rhs, tuple):
shape, dtype, scope = rhs shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs) ph = _api.placeholder(shape, dtype=dtype, name=lhs)
if scope != 'output':
self.alloc_buffers[lhs] = (ph, scope) self.alloc_buffers[lhs] = (ph, scope)
else: if scope == 'output':
self._args[lhs] = ph
self.outputs.append(lhs) self.outputs.append(lhs)
return make_nop() return util.make_nop()
if isinstance(rhs, halide_imm_types) and ast.Store not in rw: if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs self.variables[lhs] = rhs
else: else:
ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) ph = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, 'global') self.variables[lhs] = (ph, 'global')
if lhs in self.var_consts.keys(): lhs = self.visit(lhs_)
return make_nop() if lhs is not None:
_internal_assert(lhs in self.alloc_buffers.keys(), \ buf, args = lhs
"This variable should be defined before!") return _make.Provide(buf.op, 0, rhs, args)
tgt, _ = self.alloc_buffers[lhs] return util.make_nop()
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else: else:
lhs = self.visit(lhs) lhs, args = self.visit(lhs)
_internal_assert(isinstance(lhs, _expr.Call), \ _internal_assert(isinstance(lhs, Tensor), \
"An array access's LHS is expected to be a expr.Call!") "An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later res = _make.Provide(lhs.op, lhs.value_index, rhs, args)
buf = self._get_buffer_from_id(lhs.name, for_provide=True) return res
return _make.Provide(buf.op, 0, rhs, lhs.args)
def visit_Index(self, node): def visit_Index(self, node):
if isinstance(node.value, ast.Tuple): if isinstance(node.value, ast.Tuple):
return [self.visit(i) for i in node.value.elts] return self.visit(node.value)
return [self.visit(node.value)] return [self.visit(node.value)]
def visit_Attribute(self, node):
_internal_assert(isinstance(node.value, ast.Name), \
"For atrribute access, only both names are supported so far!")
buf = self._get_buffer_from_id(node.value.id)
return getattr(buf, node.attr)
def visit_Subscript(self, node): def visit_Subscript(self, node):
args = self.visit(node.slice) args = self.visit(node.slice)
if isinstance(node.value, ast.Name): if isinstance(node.value, ast.Name):
array = node.value.id buf = self.visit(node.value)
_buf = self._get_buffer_from_id(array) if isinstance(node.ctx, ast.Load):
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, _buf.value_index) return _make.Call(buf.dtype, buf.name, args, \
_expr.Call.Halide, buf.op, buf.value_index)
_internal_assert(isinstance(node.value, ast.Attribute), \ return buf, args
"Only variable and attribute's subscript supported so far")
_internal_assert(isinstance(node.value.value, ast.Name), \ shape = self.visit(node.value)
"The root of array access is expect to be a id!")
_internal_assert(node.value.attr == "shape", \
"Attribute access so far only 'shape' is supported!")
_internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!") _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
args = args[0] args = args[0]
#TODO: maybe support non-constant value later? #TODO: maybe support non-constant value later?
_internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \ _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
"So far only constant shape access supported!") "So far only constant shape access supported!")
buf = self._get_buffer_from_id(node.value.value.id) return shape[args.value]
return buf.shape[args.value]
def visit_With(self, node): def visit_With(self, node):
...@@ -275,7 +306,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -275,7 +306,7 @@ class HybridParser(ast.NodeVisitor):
if node.orelse: if node.orelse:
else_body = list_to_block(self.visit, node.orelse) else_body = list_to_block(self.visit, node.orelse)
else: else:
else_body = make_nop() else_body = util.make_nop()
return _make.IfThenElse(cond, if_body, else_body) return _make.IfThenElse(cond, if_body, else_body)
...@@ -305,13 +336,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -305,13 +336,10 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(isinstance(node.op, ast.Not), \ _internal_assert(isinstance(node.op, ast.Not), \
"Unary is supposed to be not!") "Unary is supposed to be not!")
return operator.not_(self.visit(node.values[0])) return operator.not_(self.visit(node.values[0]))
elif n == 2:
_internal_assert(isinstance(node.op, (ast.And, ast.Or)), \ _internal_assert(isinstance(node.op, (ast.And, ast.Or)), \
"Binary is supposed to be and/or!") "Binary is supposed to be and/or!")
values = [self.visit(i) for i in node.values] values = [self.visit(i) for i in node.values]
return HybridParser._binop_maker[type(node.op)](*values) return HybridParser._binop_maker[type(node.op)](*values)
else:
raise ValueError("This Bool Op is not supported yet!")
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
...@@ -329,67 +357,17 @@ class HybridParser(ast.NodeVisitor): ...@@ -329,67 +357,17 @@ class HybridParser(ast.NodeVisitor):
# Yet, no function pointer supported # Yet, no function pointer supported
_internal_assert(isinstance(node.func, ast.Name), \ _internal_assert(isinstance(node.func, ast.Name), \
"Only id-function function call is supported so far!") "Only id-function function call is supported so far!")
func_id = node.func.id func_id = node.func.id
n = len(node.args) args = [self.visit(i) for i in node.args]
if func_id in LOOP_INTRIN.keys() and func_id != 'bind': try:
if n == 1: return getattr(calls, func_id)(func_id, args)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0]) except AttributeError:
else: _internal_assert(func_id in self.symbols.keys(), \
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!") "The function called is not in the context either!")
low, ext = self.visit(node.args[0]), self.visit(node.args[1]) outs = self.symbols[func_id](*args)
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): op = outs.op if isinstance(outs, Tensor) else outs[0].op
ext = ext - low return op
for_type = LOOP_INTRIN[func_id]
iter_var = None
return iter_var, low, ext, for_type
elif func_id == 'bind':
_internal_assert(n == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(node.args[0], ast.Str), \
"A loop bind's first argument should be a string!")
_vn = node.args[0].s
iter_var = thread_axis(node.args[0].s)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
for_type = None
return iter_var, low, ext, for_type
elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id in ['allocate', 'output_tensor']:
_internal_assert(isinstance(node.args[0], ast.Tuple), \
"allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts)
if func_id == 'output_tensor':
_internal_assert(not self.loops_above, \
"Are you sure to allocate a output buffer multiple times?")
for i in shape:
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1:
if isinstance(node.args[1], ast.Str):
dtype = node.args[1].s
else:
_internal_assert(isinstance(node.args[1], ast.Attribute), \
"Unable to evaluate to get data type")
to_eval = node.args[1]
_internal_assert(isinstance(to_eval.value, ast.Name), \
"Unable to evaluate the attribute to get data type")
_internal_assert(to_eval.attr == 'dtype', \
"Only dtype attribute is supported so far")
dtype = self._get_buffer_from_id(to_eval.value.id).dtype
else:
dtype = 'float32'
if n > 2:
_internal_assert(isinstance(node.args[2], ast.Str), \
"The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = node.args[2].s
else:
scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope)
elif func_id == 'max' or func_id == 'min':
_internal_assert(n == 2, "Max/Min function should have 2 elements")
a, b = self.visit(node.args[0]), self.visit(node.args[1])
return getattr(_make, func_id.title())(a, b)
else:
raise ValueError("Function call not supported yet!")
def visit_For(self, node): def visit_For(self, node):
...@@ -400,7 +378,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -400,7 +378,7 @@ class HybridParser(ast.NodeVisitor):
if iter_var is None: if iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop bind function parse error!")
offset = iter_var = _api.var(_name) offset = iter_var = _api.var(_name)
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')): if not _ir_pass.Equal(low, _api.const(0)):
offset = iter_var + low offset = iter_var + low
self.loops_above[_name] = offset self.loops_above[_name] = offset
else: else:
...@@ -411,7 +389,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -411,7 +389,7 @@ class HybridParser(ast.NodeVisitor):
if for_type is None: if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else: else:
res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body) res = _make.For(iter_var, _api.const(0), ext, for_type, 0, _body)
self.loops_above.pop(_name) self.loops_above.pop(_name)
return res return res
...@@ -428,14 +406,22 @@ class HybridParser(ast.NodeVisitor): ...@@ -428,14 +406,22 @@ class HybridParser(ast.NodeVisitor):
_internal_assert(isinstance(i, ast.Name), "What do you return?") _internal_assert(isinstance(i, ast.Name), "What do you return?")
ids.append(i.id) ids.append(i.id)
_internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples") _internal_assert(len(set(ids)) == len(ids), "Duplicated tensors in the return tuples")
if len(ids) != len(self.outputs): if len(ids) < len(self.outputs):
logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!') logging.log(logging.CRITICAL, '[Warning] Not all the output buffers returned!')
self.outputs = [self._args[i] for i in ids] self.outputs = [self.alloc_buffers[i][0] for i in ids]
self.returned = True self.returned = True
return make_nop() return util.make_nop()
def visit_Tuple(self, node):
return tuple(self.visit(i) for i in node.elts)
def parse_python(src, args): def visit_Str(self, node):
return node.s
def parse_python(src, symbols, args):
"""The helper function of calling the AST visitor """The helper function of calling the AST visitor
Parameters Parameters
...@@ -443,6 +429,9 @@ def parse_python(src, args): ...@@ -443,6 +429,9 @@ def parse_python(src, args):
src : str src : str
The source code of the function to be parsed. The source code of the function to be parsed.
src : str
The symbol list of the global context of the function.
args : list of Tensors or Vars args : list of Tensors or Vars
The argument lists to the function. The argument lists to the function.
It is NOT encouraged to write a function without arguments. It is NOT encouraged to write a function without arguments.
...@@ -454,8 +443,8 @@ def parse_python(src, args): ...@@ -454,8 +443,8 @@ def parse_python(src, args):
The result Halide IR and the parser class instance. The result Halide IR and the parser class instance.
""" """
root = ast.parse(src) root = ast.parse(src)
var_usage = determine_variable_usage(root, args) var_usage = determine_variable_usage(root, args, symbols)
parser = HybridParser(args, var_usage) parser = HybridParser(args, var_usage, symbols)
parser.parsed_body = parser.visit(root) parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!') _internal_assert(parser.returned, 'No valid return found in the function body!')
return parser return parser
...@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types ...@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
from .. import api as _api from .. import api as _api
from .. import make as _make from .. import make as _make
from .. import expr as _expr from .. import expr as _expr
from .. import stmt as _stmt
from ..tensor import Tensor from ..tensor import Tensor
...@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect): ...@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
_globals.pop(elem) _globals.pop(elem)
for k, v in intersect: for k, v in intersect:
_globals[k] = v _globals[k] = v
def replace_io(body, rmap):
    """Rewrite the tensors referenced in an IR body according to a mapping.

    Parameters
    ----------
    body : Stmt
        The IR to transform.

    rmap : dict
        Maps the `func` of a Provide or Call node to the tensor that should
        be referenced instead.

    Returns
    -------
    The result of running IRTransform over body with the rewrite applied to
    Provide and Call nodes.
    """
    from .. import ir_pass

    def _rewrite(node):
        """Return the substituted node, or None to leave it untouched."""
        is_provide = isinstance(node, _stmt.Provide)
        is_call = isinstance(node, _expr.Call)
        if not (is_provide or is_call) or node.func not in rmap.keys():
            return None
        target = rmap[node.func]
        if is_provide:
            # NOTE(review): the store keeps the original node's value_index
            # while the load below takes the replacement tensor's one;
            # confirm this asymmetry is intended.
            return _make.Provide(target.op, node.value_index, node.value, node.args)
        return _make.Call(target.dtype, target.name, node.args, \
                          _expr.Call.Halide, target.op, target.value_index)

    return ir_pass.IRTransform(body, None, _rewrite, ['Provide', 'Call'])
...@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable""" """The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name #pylint: disable=invalid-name
#pylint: disable=missing-docstring #pylint: disable=missing-docstring
def __init__(self, args): def __init__(self, args, symbols):
self.status = {} self.status = {}
self.scope_level = [] self.scope_level = []
self._args = {} self._args = {}
self.args = args self.args = args
self.aug_assign_ = False self.aug_assign_ = False
self.symbols = symbols
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
...@@ -43,7 +44,9 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -43,7 +44,9 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far #No function pointer supported so far
_internal_assert(isinstance(node.func, ast.Name), "Function call should be an id") _internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
func_id = node.func.id func_id = node.func.id
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \ _internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min'] + \
list(self.symbols.keys()), \
"Function call id not in intrinsics' list") "Function call id not in intrinsics' list")
for elem in node.args: for elem in node.args:
self.visit(elem) self.visit(elem)
...@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
else: else:
decl, loop, usage = self.status[node.id] decl, loop, usage = self.status[node.id]
usage.add(type(node.ctx)) usage.add(type(node.ctx))
_internal_assert(loop in self.scope_level,
"%s is used out of the scope it is defined!" % node.id)
self.status[node.id] = (decl, loop, usage) self.status[node.id] = (decl, loop, usage)
def determine_variable_usage(root, args): def determine_variable_usage(root, args, symbols):
"""The helper function for calling the dedicated visitor.""" """The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args) visitor = PyVariableUsage(args, symbols)
visitor.visit(root) visitor.visit(root)
return visitor.status return visitor.status
...@@ -270,7 +270,7 @@ def test_bind(): ...@@ -270,7 +270,7 @@ def test_bind():
return return
@script @script
def vec_add(a, b): def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32') c = output_tensor((1000, ), 'float32')
for tx in bind('threadIdx.x', 1000): for tx in bind('threadIdx.x', 1000):
c[tx] = a[tx] + b[tx] c[tx] = a[tx] + b[tx]
return c return c
...@@ -506,7 +506,37 @@ def test_value_index(): ...@@ -506,7 +506,37 @@ def test_value_index():
module(tvm.ndarray.array(np_a), res) module(tvm.ndarray.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref) tvm.testing.assert_allclose(res.asnumpy(), ref)
def test_func_call():
    """Check a hybrid script whose body calls outer_product.

    The script fills both inputs, calls outer_product on them and combines
    the result into a new output tensor. It is then passed to run_and_check
    with two 10-element placeholders.
    """
    @tvm.hybrid.script
    def foo(a, b):
        for i in range(10):
            a[i] = i + 1.0
        for i in range(10):
            b[i] = i + 1.0
        c = outer_product(10, 10, a, b)
        d = output_tensor(c.shape, c.dtype)
        for i in range(10):
            for j in range(10):
                d[i, j] = c[i, j] + i * j
        return d

    placeholders = [tvm.placeholder((10, ), name=tag) for tag in ('a', 'b')]
    run_and_check(foo, placeholders)
def test_bool():
    """Check a hybrid script whose condition chains three comparisons with `or`.

    The script is passed to run_and_check with one 10-element placeholder.
    """
    @tvm.hybrid.script
    def foo(a):
        b = output_tensor(a.shape, a.dtype)
        b[0] = 1.2
        for i in range(1, a.shape[0] - 1):
            if a[i] * a[i - 1] < a[i] or a[i] * a[i - 1] < a[i - 1] or i * a[i] == a[i]:
                b[i] = a[i]
            else:
                b[i] = 0.0
        return b

    run_and_check(foo, [tvm.placeholder((10, ), name='a')])
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
...@@ -521,7 +551,7 @@ if __name__ == "__main__": ...@@ -521,7 +551,7 @@ if __name__ == "__main__":
test_downstream() test_downstream()
test_const_param() test_const_param()
test_value_index() test_value_index()
test_func_call()
test_bool()
# TODO: # TODO:
# test_inplace() # test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment