Commit 838e7181 by Jian Weng Committed by Tianqi Chen

[Hybrid Script] Inter-function call supported! (#2287)

parent 001ab525
......@@ -24,17 +24,15 @@ def script(pyfunc):
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
parser = parse_python(src, args)
parser = parse_python(src, func.__globals__, args)
input_tensors = []
for i in args:
if isinstance(i, Tensor):
input_tensors.append(i)
op = _tvm_internal._HybridOp(parser.func_name, "HybridOp", None, input_tensors,
parser.outputs, parser.parsed_body)
res = [op.output(i) for i in range(len(parser.outputs))]
return res[0] if len(res) == 1 else res
intersect = _enter_hybrid_runtime(func)
......
"""Intrinsics of TVM-Python Hybrid Script for Python compilation time
semantic support."""
from .. import api as _api
from .. import expr as _expr
from .. import make as _make
from ..container import Array
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
#pylint: disable=redefined-builtin
LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
}
def _range(annotation, args):
"""Handling TVM loop types"""
n = len(args)
if n == 1:
low, ext = _api.const(0, dtype='int32'), args[0]
else:
_internal_assert(n == 2, "A loop intrinsic should only have 1 or 2 arguments!")
low, ext = args[0], args[1]
if not ir_pass.Equal(low, _api.const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
iter_var = None
return iter_var, low, ext, for_type
range = unroll = vectorize = parallel = _range #pylint: disable=invalid-name
def bind(func_id, args):
"""Handling TVM thread binding"""
_internal_assert(func_id == "bind", "This function cannot be directly invoked!")
_internal_assert(len(args) == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(args[0], str), \
"A loop bind's first argument should be a string!")
iter_var = _api.thread_axis(args[0])
low, ext = _api.const(0), args[1]
for_type = None
return iter_var, low, ext, for_type
def _math_intrin(func_id, args):
from .. import intrin
return getattr(intrin, func_id)(*args)
sqrt = log = exp = tanh = sigmoid = power = popcount = _math_intrin #pylint: disable=invalid-name
def _min_max(func_id, args):
_internal_assert(len(args) == 2, "Max/Min function should have 2 elements")
return getattr(_make, func_id.title())(args[0], args[1])
min = max = _min_max #pylint: disable=invalid-name
def _allocate_tensor(func_id, args):
"""Handling TVM tensor allocation.
You may refer hybrid.intrin.allocate for more details."""
n = len(args)
_internal_assert(isinstance(_api.convert(args[0]), Array), \
"allocate's first argument should be a tuple of shape!")
shape = args[0]
for i in shape:
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1:
_internal_assert(isinstance(args[1], str),
"The data type should be an str")
_internal_assert(args[1].startswith('int') or args[1].startswith('float'), \
"The data type should be either int or float!")
dtype = args[1]
else:
dtype = 'float32'
if n > 2:
_internal_assert(isinstance(args[2], str), \
"The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = args[2]
else:
scope = 'global' if func_id != 'output_tensor' else 'output'
return (shape, dtype, scope)
output_tensor = allocate = _allocate_tensor #pylint: disable=invalid-name
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""
import numpy
from ..stmt import For
class _range(object):
"""Base class of the loop ranges in hybrid script"""
......@@ -102,15 +101,3 @@ HYBRID_GLOBALS = {
'sigmoid' : sigmoid,
'popcount' : popcount
}
LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
'bind' : None
}
MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
......@@ -10,6 +10,7 @@ from .._ffi.base import numeric_types
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from .. import stmt as _stmt
from ..tensor import Tensor
......@@ -86,3 +87,20 @@ def _restore_runtime(func, intersect):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
from .. import ir_pass
def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Provide(buf.op, op.value_index, op.value, op.args)
elif isinstance(op, _expr.Call) and op.func in rmap.keys():
buf = rmap[op.func]
return _make.Call(buf.dtype, buf.name, op.args, \
_expr.Call.Halide, buf.op, buf.value_index)
return None
return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
......@@ -10,12 +10,13 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def __init__(self, args):
def __init__(self, args, symbols):
self.status = {}
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
self.symbols = symbols
def visit_FunctionDef(self, node):
......@@ -43,7 +44,9 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far
_internal_assert(isinstance(node.func, ast.Name), "Function call should be an id")
func_id = node.func.id
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min'], \
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min'] + \
list(self.symbols.keys()), \
"Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
......@@ -75,11 +78,13 @@ class PyVariableUsage(ast.NodeVisitor):
else:
decl, loop, usage = self.status[node.id]
usage.add(type(node.ctx))
_internal_assert(loop in self.scope_level,
"%s is used out of the scope it is defined!" % node.id)
self.status[node.id] = (decl, loop, usage)
def determine_variable_usage(root, args):
def determine_variable_usage(root, args, symbols):
"""The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args)
visitor = PyVariableUsage(args, symbols)
visitor.visit(root)
return visitor.status
......@@ -270,7 +270,7 @@ def test_bind():
return
@script
def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32')
c = output_tensor((1000, ), 'float32')
for tx in bind('threadIdx.x', 1000):
c[tx] = a[tx] + b[tx]
return c
......@@ -506,7 +506,37 @@ def test_value_index():
module(tvm.ndarray.array(np_a), res)
tvm.testing.assert_allclose(res.asnumpy(), ref)
def test_func_call():
@tvm.hybrid.script
def foo(a, b):
for i in range(10):
a[i] = i + 1.0
for i in range(10):
b[i] = i + 1.0
c = outer_product(10, 10, a, b)
d = output_tensor(c.shape, c.dtype)
for i in range(10):
for j in range(10):
d[i, j] = c[i, j] + i * j
return d
a = tvm.placeholder((10, ), name='a')
b = tvm.placeholder((10, ), name='b')
run_and_check(foo, [a, b])
def test_bool():
@tvm.hybrid.script
def foo(a):
b = output_tensor(a.shape, a.dtype)
b[0] = 1.2
for i in range(1, a.shape[0] - 1):
if a[i] * a[i - 1] < a[i] or a[i] * a[i - 1] < a[i - 1] or i * a[i] == a[i]:
b[i] = a[i]
else:
b[i] = 0.0
return b
a = tvm.placeholder((10, ), name='a')
run_and_check(foo, [a])
if __name__ == "__main__":
test_outer_product()
......@@ -521,7 +551,7 @@ if __name__ == "__main__":
test_downstream()
test_const_param()
test_value_index()
test_func_call()
test_bool()
# TODO:
# test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment