Commit 90db723d by Jian Weng Committed by Tianqi Chen

[FRONTEND] A Python hybrid frontend (#1251)

parent a55bc290
tvm.hybrid
----------
.. automodule:: tvm.hybrid
.. autosummary::
tvm.hybrid.parse
tvm.hybrid.script
tvm.hybrid.popcount
tvm.hybrid.sigmoid
.. autofunction:: tvm.hybrid.parse
.. autofunction:: tvm.hybrid.script
.. autofunction:: tvm.hybrid.popcount
.. autofunction:: tvm.hybrid.sigmoid
......@@ -21,3 +21,4 @@ Python API
dev
topi
nnvm/index
hybrid
Hybrid Frontend Developer Guide
===============================
If you are a developer:
1. who is trying writing some preliminary patterns that have not been supported by TVM yet,
maybe :ref:`hybrid-langref-label` is a better place for you.
2. who wants to know the implementing details of this module, you are right here!
Features
--------
Software emulation
~~~~~~~~~~~~~~~~~~
In software emulation, the most intresting thing is the decorator ``tvm.hybrid.script``.
This decorator helps 2 things:
1. Importing runtime variables
2. Overload the function according to the arguments passed
Correct me if I am wrong: I believe that how 1. is implemented is dangerous, but I have no
choice. What I did is add those names into python dict ``func.__global__`` and after
the call to ``func`` is done, those names will be cleaned up.
Overload is simple: the decorator checks the arguments' types and determines which function
should be actually called.
Backend Compilation
~~~~~~~~~~~~~~~~~~~
Compilation is a large module, you can see ``python/tvm/hybrid/var_decl.py`` and
``python/tvm/hybrid/parser.py`` for more details. The first stage determines the
usage, or more accurately the declaration of each variable and the second stage does
the actual IR generation.
Attributes
~~~~~~~~~~
So far, ONLY tensors' `shape` attribute is supported. You can see ``visit_Subscript``
in ``python/tvm/hybrid/parser.py`` for more details. This is a hacky solution, I just
check the attributes when subscript.
Loops
~~~~~
In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.
.. note::
Unlike what that is in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b``
is the trip count of iterations. Here ``loop_type(a, b)`` indicates ``[a, b)``. Thus, when lowering it
to HalideIR, we need to do ``start, extent = a, b - a``
.. note::
In HalideIR those are enums, they are in passive form.
Here we use active form to annotate loops, because they are ready to run.
Variables
~~~~~~~~~
Because there is no variables in ``HalideIR``, all the mutatable variables will be lowered to an array with size 1.
It takes the first store of a variable as its declaration.
Math intrinsics
~~~~~~~~~~~~~~~
So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``, are supported.
Math intrinsics will be imported by the decorator. Most of the intrinsics are borrowed by library implementation
except ``popcount`` and ``sigmoid``. I implemented them manually.
......@@ -10,3 +10,4 @@ In this part of documentation, we share the rationale for the specific choices m
runtime
nnvm_json_spec
nnvm_overview
hybrid_script
.. _hybrid-langref-label:
Hybrid Frontend Language Reference
==================================
Overview
--------
This hybrid frontend allows users to write preliminary versions of some idioms that yet have
been supported by TVM officially.
Features
--------
Software Emulation
~~~~~~~~~~~~~~~~~~
Both software emulation and compilation are supported. To define a function,
you need to use ``tvm.hybrid.script`` decorator to indicate this is a hybrid function:
.. code-block:: python
@tvm.hybrid.script
def outer_product(a, b, c):
for i in range(a.shape[0]):
for j in range(b.shape[0]):
c[i, j] = a[i] * b[j]
a = numpy.random.rand(100)
b = numpy.random.rand(99)
c = numpy.zeros((100, 99))
outer_product(a, b, c)
This decorator will import `Keywords`_ required spontaneously when software emulation.
After software emulation is done, the imported keywords will be cleaned up. Users do not need
worry about keyword conflict and pollution.
Every element passed for software emulation in the argument list is either a python variable
or ``numpy`` numeric type.
Backend Compilation
~~~~~~~~~~~~~~~~~~~
The current parse interface looks like:
.. code-block:: python
a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c')
tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
If we pass these tvm tensors to this function, it returns a op node:
**Under construction, we are still deciding what kind of node should be returned.**
.. code-block:: python
a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c')
op = outer_product(a, b, c) # return the corresponding op node
Tuning
~~~~~~
**Under construction, not truly supported yet.**
Follow up the example above, you can use some tvm like interfaces to tune the code:
.. code-block:: python
sch = tvm.create_schedule(op)
jo, ji = sch.split(j, 4)
sch.vectorize(ji)
``split``, ``reorder``, and loop_annotation will be supported!
Loops
~~~~~
In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.
Here we use ``range`` aka ``serial``, ``unroll``, ``parallel``, and ``vectorize``,
these **4** keywords to annotate the corresponding types of for loops.
The the usage is roughly the same as Python standard ``range``.
Variables
~~~~~~~~~
All the mutatable variables will be lowered to an array with size 1.
It regards the first store of a variable as its declaration.
.. note::
Unlike conventional Python, in hybrid script, the declared variable
can only be used in the scope level it is declared.
.. note::
Currently, you can ONLY use basic-typed variables, i.e. the type of the
variable should be either ``float32``, or ``int32``.
.. code-block:: python
for i in range(5):
s = 0 # declaration, this s will be a 1-array in lowered IR
for j in range(5):
s += a[i, j] # do something with sum
b[i] = sum # you can still use sum in this level
a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
b = (1, 2) # this has NOT been supported yet!
Attributes
~~~~~~~~~~
So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` atrribute is essentailly a
tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
.. code-block:: python
x = a.shape[2] # OK!
for i in range(3):
for j in a.shape[i]: # BAD! i is not a constant!
# do something
Conditional Statement and Expression
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python
if condition:
# do something
a = b if condition else c
However, NO ``True`` and ``False`` keyword supported yet.
Math Intrinsics
~~~~~~~~~~~~~~~
So far, these math intrinsics, ``log``, ``exp``, ``sigmoid``,
``tanh``, ``power``, and ``popcount``, are supported.
No import is required, just as it is mentioned in `Software Emulation`_, just use it!
Array Allocation
~~~~~~~~~~~~~~~~
**Under construction, this function will be supported later!**
Use a function call ``allocation(shape, type, share/local)`` to declare an array buffer.
The basic usage is roughly the same as a normal array.
Thread Bind
~~~~~~~~~~~
You can also do loop-thread bind by writing code like this:
.. code-block:: python
for tx in bind("threadIdx.x", 100):
a[tx] = b[tx]
Keywords
~~~~~~~~
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
......@@ -2,3 +2,8 @@ Language Reference
==================
This document provide references to
embedded languages in TVM stack.
.. toctree::
:maxdepth: 2
hybrid_script
......@@ -332,12 +332,20 @@ def lower(sch,
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
# normalize schedule first
sch = sch.normalize()
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
if isinstance(sch, schedule.Schedule):
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
else:
#So far there is no op for hybrid script, so a plain ir body is given
if not isinstance(sch, _stmt.Stmt):
raise ValueError("sch should be either a Schedule or a Stmt")
stmt = sch
for f in lower_phase0:
stmt = f(stmt)
# Phase 1
......
"""Hybrid Programming APIs of TVM Python Package.
This package maps a subset of python to HalideIR so that:
1. Users can write some preliminary versions of the computation patterns
have not been supported yet and verify it across the real execution and
python semantic emulation.
2. Developers can build HalideIR by writing Python code.
"""
from .api import script, parse
"""APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs
import types
import decorator
from .parser import parse_python
@decorator.decorator
def script(func, *args):
"""If the arguments are tvm types, compile it to HalideIR.
O.W. return the python emulated result"""
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
return parse(func, args)
else:
intersect = _enter_hybrid_runtime(func)
func(*args)
_restore_runtime(func, intersect)
return func
def parse(func, args):
"""Parse a subset of Python to HalideIR
Parameters
----------
func : str or types.FunctionType
If it is a string, parse the source code
If it is a function, parse the function
args : list of Buffer or Tensor or Var
The argument lists to the function.
Leave it None if no buffer is related to the function to be parsed
Returns
-------
root : Stmt
The result Halide IR and the parser class instance.
"""
from .util import _pruned_source
if isinstance(func, str):
src = func
else:
assert isinstance(func, types.FunctionType)
src = _pruned_source(func)
return parse_python(src, args)
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
import numpy
from ..stmt import For
class _range(object):
"""Base class of the loop ranges in hybrid script"""
def __init__(self, a, b=None):
if b is None:
self.low = 0
self.ext = a
else:
self.low = a
self.ext = b
def __iter__(self):
i = 0
while i < self.ext:
yield i + self.low
i += 1
class bind(_range): #pylint: disable=invalid-name
def __init__(self, tag, ext):
super(bind, self).__init__(ext)
self.tag = tag
unroll = vectorize = parallel = _range #pylint: disable=invalid-name
def allocate(shape, dtype='float32'):
"""Allocate a buffer with given shape
Parameters
----------
shape: Tuple
The shape of the tensor to be allocated
dtype: string
The data type of the tensor
Returns
-------
tensor: numpy.array
The tensor allocated
"""
return numpy.zeros(shape).astype(dtype)
def popcount(x):
"""
Count ones in the binary representation of number x
Parameters
----------
x: Integer
The number to be counted
Returns
-------
cnt: Integer
The number of ones in the binary representation of number x
"""
cnt = 0
while x:
x -= x & -x
cnt += 1
return cnt
def sigmoid(x):
"""
Sigmoid function of x, aka 1/(1+exp(-x)).
Parameters
----------
x: a real number
Returns
-------
res: a real number
The result of sigmoid function
"""
return 1 / (1 + numpy.exp(-x))
HYBRID_GLOBALS = {
'unroll' : unroll,
'vectorize' : vectorize,
'parallel' : parallel,
'allocate' : allocate,
'bind' : bind,
'sqrt' : numpy.sqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount
}
LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
'bind' : None
}
MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
"""Hybrid Script Parser"""
import ast
import operator
import sys
from .util import make_nop, make_const_true, make_range_one, halide_imm_types
from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage
from ..api import thread_axis
from .. import expr as _expr
from .. import make as _make
from .. import intrin
from .. import api as _api
from .. import ir_pass as _ir_pass
def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block"""
lst = list(map(visit, lst))
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
if not lst:
return make_nop()
if len(lst) == 1:
return lst[0]
body = lst[0]
for i in lst[1:]:
body = _make.Block(body, i)
return body
class HybridParser(ast.NodeVisitor):
"""Python AST visitor pass which finally lowers it to HalideIR"""
_binop_maker = {
ast.Add : operator.add,
ast.Sub : operator.sub,
ast.Mult : operator.mul,
ast.Div : _make.Div,
ast.Mod : operator.mod,
ast.BitOr : operator.or_,
ast.BitAnd: operator.and_,
ast.BitXor: operator.xor,
ast.Gt : operator.gt,
ast.GtE : operator.ge,
ast.Lt : operator.lt,
ast.LtE : operator.le,
ast.Eq : operator.eq,
ast.NotEq : operator.ne,
}
_unaryop_maker = {
ast.USub : operator.neg,
ast.Invert : operator.invert,
ast.Not : operator.not_
}
def __init__(self, args, usage, func_name=None):
"""
Parameters
----------
args: A list of tvm.placeholder or tvm.var
Provided by the user, the argument list of the function to be lowered.
usage: A dict of variables used in last in this function
Provided by last lower pass, which collects this information
Returns
-------
func_name: str
The name of the function to be lowered; if not provided,
the compiler will use the name in the AST
"""
self.args = args[:]
self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.buffers = {}
self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage
self.func_name = func_name # The name of the function to be lowered
self.iter_axis = []
def wrap_up_realize(self, node, body):
"""Wrap up all the variables which will no longer be used"""
for key, val in self.usage.items():
if key in self.var_consts.keys():
continue
_, scope, _ = val
if scope == node:
_buf = self.buffers[key]
_dtype = _buf.dtype
_one = make_range_one()
_true = make_const_true()
body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body)
return body
def _check_id_a_buffer(self, s):
if s not in self._args.keys():
raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s)
#pylint: disable=invalid-name, missing-docstring
def visit_Module(self, node):
if len(node.body) != 1:
raise ValueError("Only one-function source code can be fed to this parser!")
return self.visit(node.body[0])
def visit_FunctionDef(self, node):
if len(node.args.args) != len(self.args):
raise ValueError("The number of arguments passed to the function\
should be the same as it is defined!")
for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx]
res = list_to_block(self.visit, node.body)
res = self.wrap_up_realize(node, res)
if self.func_name is None:
self.func_name = node.name
return res
def visit_Expr(self, node):
return self.visit(node.value)
def visit_Name(self, node):
_id = node.id
if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var):
return self._args[_id]
elif _id in self.loops_above.keys():
return self.loops_above[_id]
if _id in self._args.keys():
raise ValueError("This id %s should be handled in visit_Subscript!" % _id)
if _id not in self.usage.keys():
raise ValueError("This id %s is expected to be a defined variable!" % _id)
# Buffer
if _id in self.buffers.keys():
_buf = self.buffers[_id]
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
# Compilation time constant
if _id not in self.var_consts.keys():
raise ValueError("This id %s is expected to a compilation time constant!" % _id)
return self.var_consts[_id]
def visit_Num(self, node):
return _api.const(node.n)
def visit_Assign(self, node):
if len(node.targets) != 1:
raise ValueError("So far only one-valued assignment is supported!")
lhs = node.targets[0]
rhs = _ir_pass.Simplify(self.visit(node.value))
if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later
lhs_ = lhs
lhs = lhs.id
if lhs in self.loops_above.keys():
raise ValueError("You CAN NEVER overwrite a loop variable!")
decl, _, rw = self.usage[lhs]
if decl == lhs_:
if lhs in self.var_consts.keys():
raise ValueError("BUG: A constant cannot be overwritten!")
if lhs in self.buffers.keys():
raise ValueError("BUG: This value should not be defined before this point!")
if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs
else:
self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
if lhs in self.var_consts.keys():
return make_nop()
else:
if lhs not in self.buffers.keys():
raise ValueError("BUG: This value should be defined before!")
return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else:
lhs = self.visit(lhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later
self._check_id_a_buffer(lhs.name)
return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args)
def visit_Index(self, node):
if isinstance(node.value, ast.Tuple):
return [self.visit(i) for i in node.value.elts]
return [self.visit(node.value)]
def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
array = node.value.id
self._check_id_a_buffer(array)
_buf = self._args[array]
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
elif isinstance(node.value, ast.Attribute):
if not isinstance(node.value.value, ast.Name):
raise ValueError("The root of array access is expect to be a id!")
if node.value.attr != "shape":
raise ValueError("Attribute access so far only 'shape' is supported!")
if len(args) != 1:
raise ValueError("For 'shape' access the argument should be only one!")
args = args[0]
#TODO: maybe support non-constant value later?
if not isinstance(args, (_expr.IntImm, _expr.UIntImm)):
raise ValueError("So far only constant shape access supported!")
self._check_id_a_buffer(node.value.value.id)
return self._args[node.value.value.id].shape[args.value]
else:
raise ValueError("Not supported yet!")
def visit_With(self, node):
if sys.version_info[0] < 3:
context = node.context_expr
option = node.optional_vars
else:
if len(node.items) != 1:
raise ValueError("Only one with element is supported so far!")
context = node.items[0].context_expr
option = node.items[0].optional_vars
if not isinstance(context, ast.Call):
raise ValueError("The object must be a Python function call!")
if not isinstance(option, ast.Name):
raise ValueError("The object after 'as' must be an id!")
self.annotation[option.id] = context.func.id
return list_to_block(self.visit, node.body)
def visit_If(self, node):
cond = self.visit(node.test)
if_body = list_to_block(self.visit, node.body)
if node.orelse:
else_body = list_to_block(self.visit, node.orelse)
else:
else_body = make_nop()
return _make.IfThenElse(cond, if_body, else_body)
def visit_IfExp(self, node):
cond = self.visit(node.test)
if_body = self.visit(node.body)
else_body = self.visit(node.orelse)
return _make.Select(cond, if_body, else_body)
def visit_Compare(self, node):
lhs = self.visit(node.left)
if len(node.ops) != 1:
raise ValueError("Only one compare op is supported!")
if len(node.comparators) != 1:
raise ValueError("Only one comparator is supported!")
rhs = self.visit(node.comparators[0])
return HybridParser._binop_maker[type(node.ops[0])](lhs, rhs)
def visit_UnaryOp(self, node):
operand = self.visit(node.operand)
return HybridParser._unaryop_maker[type(node.op)](operand)
def visit_BinOp(self, node):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
return HybridParser._binop_maker[type(node.op)](lhs, rhs)
def visit_Call(self, node):
# Yet, no function pointer supported
if not isinstance(node.func, ast.Name):
raise ValueError("Only id-function function call is supported so far!")
func_id = node.func.id
n = len(node.args)
if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
if n == 1:
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[0])
else:
if n != 2:
raise ValueError("A loop intrinsic should only have 1 or 2 arguments!")
low, ext = self.visit(node.args[0]), self.visit(node.args[1])
if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
ext = ext - low
for_type = LOOP_INTRIN[func_id]
iter_var = None
return iter_var, low, ext, for_type
elif func_id == 'bind':
if n != 2:
raise ValueError("A loop bind should only have 2 arguments!")
if not isinstance(node.args[0], ast.Str):
raise ValueError("A loop bind's first argument should be a string!")
_vn = node.args[0].s
iter_var = thread_axis(node.args[0].s)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
for_type = None
return iter_var, low, ext, for_type
elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id == 'allocate':
#TODO: Support it later!
return make_nop()
else:
raise ValueError("Function call not supported yet!")
def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter)
if not isinstance(node.target, ast.Name):
raise ValueError("The loop iterator should be a variable!")
_name = node.target.id
if iter_var is None:
if for_type is None:
raise ValueError("The loop bind function parse error!")
iter_var = _api.var(_name)
self.loops_above[_name] = iter_var
else:
if for_type is not None:
raise ValueError("The loop iterating function parse error!")
self.loops_above[_name] = iter_var.var
_body = list_to_block(self.visit, node.body)
_body = self.wrap_up_realize(node, _body)
if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else:
res = _make.For(iter_var, low, ext, for_type, 0, _body)
self.loops_above.pop(_name)
return res
def parse_python(src, args):
"""The helper function of calling the AST visitor"""
root = ast.parse(src)
var_usage = determine_variable_usage(root, args)
parser = HybridParser(args, var_usage)
halide_ir = parser.visit(root)
return halide_ir
"""Internal utilities for parsing Python subset to HalideIR"""
import inspect
import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from ..tensor import Tensor
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop():
"""Returns a 'no operation' node in HalideIR."""
return _make.Evaluate(_api.const(0, dtype='int32'))
def make_range_one():
"""Returns a [0, 1] range node in HalideIR."""
return _make.range_by_min_extent(0, 1)
def make_const_true():
"""Returns a constant True node in HalideIR."""
return _api.convert(True)
def _pruned_source(func):
"""Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n')
leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
lines = [line[leading_space:] for line in lines]
return '\n'.join(lines)
def _is_tvm_arg_types(args):
"""Determine a list of element is either a list of tvm arguments of a list of numpy arguments.
If neither is true, raise a value error."""
if isinstance(args[0], tvm_arg_types):
for elem in args[1:]:
if not isinstance(elem, tvm_arg_types):
raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem)))
return True
if not isinstance(args[0], np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(args[0])))
for elem in args[1:]:
if not isinstance(elem, np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(elem)))
return False
def _enter_hybrid_runtime(func):
"""Put hybrid runtime variables into the global scope"""
_globals = func.__globals__
intersect = []
for elem in list(HYBRID_GLOBALS.keys()):
if elem in _globals.keys():
intersect.append((elem, _globals[elem]))
_globals[elem] = HYBRID_GLOBALS[elem]
return intersect
def _restore_runtime(func, intersect):
"""Rollback the modification caused by hybrid runtime"""
_globals = func.__globals__
for elem in list(HYBRID_GLOBALS.keys()):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
"""Determines the declaration, r/w status, and last use of each variable"""
import ast
import sys
from .intrin import HYBRID_GLOBALS
class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def __init__(self, args):
self.status = {}
self.scope_level = []
self._args = {}
self.args = args
def visit_FunctionDef(self, node):
self.scope_level.append(node)
if len(node.args.args) != len(self.args):
raise ValueError('#arguments passed should be the same as #arguments defined')
for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx]
for i in node.body:
self.visit(i)
def visit_For(self, node):
if not isinstance(node.target, ast.Name):
raise ValueError("For's iterator should be an id")
self.visit(node.iter)
self.scope_level.append(node)
for i in node.body:
self.visit(i)
self.scope_level.pop()
def visit_Call(self, node):
#No function pointer supported so far
if not isinstance(node.func, ast.Name):
raise ValueError("Function call should be an id")
if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range':
raise ValueError("Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
return
fors = [loop.target.id for loop in self.scope_level if isinstance(loop, ast.For)]
if node.id in fors:
return
# The loop variable cannot be overwritten when iteration
if isinstance(node.ctx, ast.Store) and node.id in fors:
raise ValueError("Iter var cannot be overwritten")
if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
loop = self.scope_level[-1]
usage.add(type(node.ctx))
self.status[node.id] = (decl, loop, usage)
def determine_variable_usage(root, args):
"""The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args)
visitor.visit(root)
return visitor.status
import tvm, inspect, sys, traceback, numpy
from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS
@script
def outer_product(n, m, a, b, c):
for i in range(n):
for j in range(m):
c[i, j] = a[i] * b[j]
#Test global function
#Test bridge between frontend and backend
def test_outer_product():
n = tvm.var('n')
m = tvm.var('m')
a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((m, ), name='b')
c = tvm.placeholder((n, m), name='c')
ir = outer_product(n, m, a, b, c)
#Check for i in (0, n)
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert ir.extent.name == 'n'
ibody = ir.body
assert isinstance(ibody, tvm.stmt.For)
#Check for j in (0, m)
assert ibody.loop_var.name == 'j'
assert ibody.min.value == 0
assert ibody.extent.name == 'm'
#Check loop body
jbody = ibody.body
assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c'
assert len(jbody.args) == 2
assert jbody.args[0].name == 'i'
assert jbody.args[1].name == 'j'
assert isinstance(jbody.value, tvm.expr.Mul)
mul = jbody.value
assert isinstance(mul.a, tvm.expr.Call)
assert mul.a.name == 'a'
assert mul.b.name == 'b'
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)
_n = 999
_m = 1001
_a = numpy.random.rand(_n).astype('float32')
_b = numpy.random.rand(_m).astype('float32')
c_python = numpy.zeros((_n, _m), dtype='float32')
outer_product(_n, _m, _a, _b, c_python)
tvm_a = tvm.ndarray.array(_a)
tvm_b = tvm.ndarray.array(_b)
tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32'))
func(_n, _m, tvm_a, tvm_b, tvm_c)
numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5)
for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
assert key not in outer_product.__globals__.keys()
#Test local function
#Test allocation of local variable
def test_fanout():
@script
def fanout(n, a, b):
three = 3.0
for i in range(a.shape[0] - 3):
sigma = 0.0
for j in range(3):
sigma = sigma + a[i + j]
sigma = sigma / three
b[i] = sigma
n = tvm.var('n')
a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((n-3, ), name='b')
ir = fanout(n, a, b)
#Check for i in (0, n-3)
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert tvm.ir_pass.Equal(ir.extent, n - 3)
#Check loopbody
ibody = ir.body
assert isinstance(ibody, tvm.stmt.Realize)
assert ibody.bounds[0].min.value == 0
assert ibody.bounds[0].extent.value == 1
assert ibody.func.name == 'sigma'
#Check i loop body
rbody = ibody.body
assert isinstance(rbody.first, tvm.stmt.Provide)
assert rbody.first.func.name == 'sigma'
assert len(rbody.first.args) == 1
assert rbody.first.args[0].value == 0
#Check fanout loop
jloop = rbody.rest.first
assert jloop.loop_var.name == 'j'
assert jloop.min.value == 0
assert jloop.extent.value == 3
jbody = jloop.body
assert isinstance(jbody, tvm.stmt.Provide)
assert len(jbody.args) == 1
assert jbody.args[0].value == 0
assert jbody.func.name == 'sigma'
assert isinstance(jbody.value, tvm.expr.Add)
value = jbody.value
assert isinstance(value.a, tvm.expr.Call)
assert value.a.name == 'sigma'
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
assert value.b.name == 'a'
assert len(value.b.args) == 1
assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
divide= rbody.rest.rest.first
assert isinstance(divide, tvm.stmt.Provide)
assert len(divide.args) == 1
assert divide.args[0].value == 0
value = divide.value
assert isinstance(value, tvm.expr.Mul)
assert value.a.name == 'sigma'
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
assert abs(value.b.value - (1 / 3.0)) < 1e-5
write = rbody.rest.rest.rest
assert isinstance(write, tvm.stmt.Provide)
assert write.func.name == 'b'
assert write.value.name == 'sigma'
assert len(write.value.args) == 1
assert write.value.args[0].value == 0
@script
def failure():
for i in range(1, 100):
i = 0
def test_failure():
try:
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Python2 cannot do the failure case because "%s"' % str(err))
except Exception as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'
def test_looptype():
@script
def looptype(a):
for i in parallel(6):
a[i] = i
for j in vectorize(6):
a[j] = j
for k in unroll(6):
a[k] = k
a = tvm.placeholder((6, ), name='a')
ir = looptype(a)
iloop = ir.first
jloop = ir.rest.first
kloop = ir.rest.rest
assert iloop.for_type == tvm.stmt.For.Parallel
assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled
def test_if():
@script
def if_then_else(a, b):
for i in range(10):
if i % 2 == 0:
a[i] = -1
else:
a[i] = 1
for i in unroll(10):
b[i] = -1 if i % 2 == 0 else 1
a = tvm.placeholder((10, ), dtype='int32', name='a')
b = tvm.placeholder((10, ), dtype='int32', name='b')
ir = if_then_else(a, b)
func = tvm.lower(ir, [a, b])
func = tvm.build(func)
assert func
_a = numpy.zeros((10, ), dtype = 'int32')
_b = numpy.zeros((10, ), dtype = 'int32')
if_then_else(_a, _b)
tvm_a = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32'))
tvm_b = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32'))
func(tvm_a, tvm_b)
numpy.testing.assert_allclose(tvm_a.asnumpy(), _a, rtol=1e-5)
numpy.testing.assert_allclose(tvm_b.asnumpy(), _b, rtol=1e-5)
numpy.testing.assert_allclose(tvm_a.asnumpy(), tvm_b.asnumpy(), rtol=1e-5)
def test_bind():
if not tvm.gpu(0).exist:
print('No GPU found! Skip this test!')
return
@script
def vec_add(a, b, c):
for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx]
a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b')
c = tvm.placeholder((1000, ), dtype='float32', name='c')
ir = vec_add(a, b, c)
func = tvm.lower(ir, [a, b, c])
func = tvm.build(func, target = 'cuda')
_a = numpy.random.rand(1000).astype('float32')
_b = numpy.random.rand(1000).astype('float32')
_c = numpy.zeros((1000, ), dtype = 'float32')
tvm_a = tvm.ndarray.array(_a, tvm.gpu(0))
tvm_b = tvm.ndarray.array(_b, tvm.gpu(0))
tvm_c = tvm.ndarray.array(_c, tvm.gpu(0))
func(tvm_a, tvm_b, tvm_c)
vec_add(_a, _b, _c)
numpy.testing.assert_allclose(_c, tvm_c.asnumpy(), rtol=1e-5)
def test_math_intrin():
@script
def intrin_real(a):
a[0] = sqrt(a[0])
a[1] = log(a[1])
a[2] = exp(a[2])
a[3] = sigmoid(a[3])
a[4] = power(a[4], a[5])
a[5] = tanh(a[5])
a6 = tvm.placeholder((6, ), dtype='float32', name='a')
ir = intrin_real(a6)
func = tvm.build(tvm.lower(ir, [a6]))
assert func
a = numpy.arange(2, 8).astype('float32')
tvm_a = tvm.ndarray.array(a)
func(tvm_a)
intrin_real(a)
numpy.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5)
@script
def intrin_int(a):
a[0] = popcount(a[0])
a1 = tvm.placeholder((1, ), dtype='int32')
ir = intrin_int(a1)
func = tvm.build(tvm.lower(ir, [a1]))
assert func
a = numpy.array([1234567890]).astype('int32')
tvm_a = tvm.ndarray.array(a)
intrin_int(a)
func(tvm_a)
assert tvm_a.asnumpy()[0] == a[0]
def test_allocate_buffer():
def blur(a):
for i in serail(32):
h_blur = allocate((4, 36))
for j in serail(4):
for k in serail(36):
s = allocate((1, ), 'float32')
for dj in serail(4):
s[0] = s[0] + a[i, j + dj]
h_blur[j, k] = s[0] / 4.
for j in serail(32):
s = 0.
for di in serail(4):
s = s + h_blur[di, j]
h_blur[i, j] = s / 4.
if __name__ == "__main__":
test_outer_product()
test_fanout()
test_failure()
test_looptype()
test_if()
test_bind()
test_math_intrin()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment