Commit 90db723d by Jian Weng Committed by Tianqi Chen

[FRONTEND] A Python hybrid frontend (#1251)

parent a55bc290
tvm.hybrid
----------
.. automodule:: tvm.hybrid

.. autosummary::

   tvm.hybrid.parse
   tvm.hybrid.script
   tvm.hybrid.popcount
   tvm.hybrid.sigmoid

.. autofunction:: tvm.hybrid.parse
.. autofunction:: tvm.hybrid.script
.. autofunction:: tvm.hybrid.popcount
.. autofunction:: tvm.hybrid.sigmoid
...@@ -21,3 +21,4 @@ Python API
dev
topi
nnvm/index
hybrid
Hybrid Frontend Developer Guide
===============================
If you are a developer:

1. who is trying to write some preliminary patterns that have not been supported by TVM yet,
maybe :ref:`hybrid-langref-label` is a better place for you.

2. who wants to know the implementation details of this module, you are in the right place!
Features
--------
Software emulation
~~~~~~~~~~~~~~~~~~
In software emulation, the most interesting thing is the decorator ``tvm.hybrid.script``.
This decorator does two things:

1. Importing runtime variables

2. Overloading the function according to the arguments passed

Correct me if I am wrong: I believe the way 1. is implemented is dangerous, but I have no
choice. What I do is add those names into the Python dict ``func.__globals__``, and after
the call to ``func`` is done, those names are cleaned up.

Overloading is simple: the decorator checks the argument types and determines which function
should actually be called.
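For reference, here is a minimal standalone sketch of the inject-and-clean-up mechanism described
in 1. The helper name and the placeholder values are illustrative only; the actual implementation
lives in ``python/tvm/hybrid/util.py`` as ``_enter_hybrid_runtime`` and ``_restore_runtime``.

.. code-block:: python

    # Illustrative only: inject the hybrid runtime names into the function's
    # globals, run it, then remove them and restore anything that was shadowed.
    HYBRID_GLOBALS = {'unroll': range, 'parallel': range}  # placeholder values

    def run_with_runtime(func, *args):
        saved = [(k, func.__globals__[k]) for k in HYBRID_GLOBALS
                 if k in func.__globals__]
        func.__globals__.update(HYBRID_GLOBALS)       # 1. import runtime variables
        try:
            return func(*args)
        finally:
            for k in HYBRID_GLOBALS:                  # clean up after the call
                func.__globals__.pop(k, None)
            func.__globals__.update(dict(saved))      # restore shadowed names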
Backend Compilation
~~~~~~~~~~~~~~~~~~~
Compilation is a large module; see ``python/tvm/hybrid/var_decl.py`` and
``python/tvm/hybrid/parser.py`` for more details. The first stage determines the
usage, or more accurately the declaration, of each variable, and the second stage does
the actual IR generation.
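Roughly, the flow looks like the sketch below. ``determine_variable_usage`` is the real stage-1
entry point in ``var_decl.py``; the wrapper name ``compile_to_ir`` is an assumption for
illustration only (the real driver is ``parse_python`` in ``parser.py``).

.. code-block:: python

    import ast

    from tvm.hybrid.var_decl import determine_variable_usage  # stage 1

    def compile_to_ir(src, args):
        root = ast.parse(src)                             # Python source -> Python AST
        var_usage = determine_variable_usage(root, args)  # stage 1: declarations and r/w status
        # Stage 2: a second AST visitor (see parser.py) walks `root` again and
        # emits HalideIR, consulting `var_usage` to decide where each variable
        # is realized.
        ...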
Attributes
~~~~~~~~~~
So far, ONLY tensors' ``shape`` attribute is supported. You can see ``visit_Subscript``
in ``python/tvm/hybrid/parser.py`` for more details. This is a hacky solution: I just
check the attributes when visiting a subscript.
Loops
~~~~~
In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.
.. note::

    Unlike in HalideIR, in ``loop_type(a, b)``, ``a`` is the starting point and ``b``
    is the end point, i.e. ``loop_type(a, b)`` iterates over ``[a, b)``. Thus, when lowering it
    to HalideIR, we need to do ``start, extent = a, b - a``.
.. note::

    In HalideIR these loop types are enums, so they are in passive form.
    Here we use the active form to annotate loops, because they are ready to run.
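A tiny sketch of the bound conversion described in the first note above (the helper name is
illustrative only):

.. code-block:: python

    def to_halide_bounds(a, b):
        # A hybrid loop over [a, b) becomes a HalideIR For node whose
        # `min` is `start` and whose `extent` is `b - a`.
        start, extent = a, b - a
        return start, extent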
Variables
~~~~~~~~~
Because there are no variables in ``HalideIR``, all mutable variables are lowered to arrays of size 1.
The first store of a variable is taken as its declaration.
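For example, in the following hybrid function (a hypothetical ``moving_sum``, mirroring the
``fanout`` unit test shipped with this patch), ``s`` is lowered to a one-element buffer whose
declaration is its first store:

.. code-block:: python

    @tvm.hybrid.script
    def moving_sum(a, b):
        for i in range(a.shape[0] - 3):
            s = 0.0                  # first store: becomes a 1-element buffer in HalideIR
            for j in range(3):
                s = s + a[i + j]     # reads/writes go through element 0 of that buffer
            b[i] = s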
Math intrinsics
~~~~~~~~~~~~~~~
So far, these math intrinsics are supported: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, and ``popcount``.
Math intrinsics will be imported by the decorator. Most of the intrinsics are borrowed from the library implementation,
except ``popcount`` and ``sigmoid``, which I implemented manually.
...@@ -10,3 +10,4 @@ In this part of documentation, we share the rationale for the specific choices m
runtime
nnvm_json_spec
nnvm_overview
hybrid_script
.. _hybrid-langref-label:
Hybrid Frontend Language Reference
==================================
Overview
--------
This hybrid frontend allows users to write preliminary versions of some idioms that have not
yet been officially supported by TVM.
Features
--------
Software Emulation
~~~~~~~~~~~~~~~~~~
Both software emulation and compilation are supported. To define a function,
you need to use the ``tvm.hybrid.script`` decorator to indicate that it is a hybrid function:
.. code-block:: python

    @tvm.hybrid.script
    def outer_product(a, b, c):
        for i in range(a.shape[0]):
            for j in range(b.shape[0]):
                c[i, j] = a[i] * b[j]

    a = numpy.random.rand(100)
    b = numpy.random.rand(99)
    c = numpy.zeros((100, 99))
    outer_product(a, b, c)
This decorator will automatically import the required `Keywords`_ during software emulation.
After software emulation is done, the imported keywords will be cleaned up. Users do not need to
worry about keyword conflicts or pollution.

Every element passed for software emulation in the argument list is either a Python variable
or a ``numpy`` numeric type.
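For instance (a hypothetical ``scale`` function, following the unit tests shipped with this patch),
the injected ``unroll`` keyword is available during the emulated call and removed afterwards:

.. code-block:: python

    import numpy
    import tvm

    @tvm.hybrid.script
    def scale(a, b):
        for i in unroll(a.shape[0]):   # `unroll` is injected by the decorator
            b[i] = a[i] * 2.0

    a = numpy.random.rand(10).astype('float32')
    b = numpy.zeros(10).astype('float32')
    scale(a, b)                        # software emulation with numpy arguments

    assert 'unroll' not in globals()   # the injected keyword has been cleaned up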
Backend Compilation
~~~~~~~~~~~~~~~~~~~
The current parse interface looks like:
.. code-block:: python

    a = tvm.placeholder((100, ), name='a')
    b = tvm.placeholder((99, ), name='b')
    c = tvm.placeholder((100, 99), name='c')
    tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function
If we pass these tvm tensors to this function, it returns an op node:
**Under construction, we are still deciding what kind of node should be returned.**
.. code-block:: python

    a = tvm.placeholder((100, ), name='a')
    b = tvm.placeholder((99, ), name='b')
    c = tvm.placeholder((100, 99), name='c')
    op = outer_product(a, b, c) # return the corresponding op node
Tuning
~~~~~~
**Under construction, not truly supported yet.**
Following the example above, you can use some TVM-like interfaces to tune the code:
.. code-block:: python

    sch = tvm.create_schedule(op)
    jo, ji = sch.split(j, 4)
    sch.vectorize(ji)
``split``, ``reorder``, and loop_annotation will be supported!
Loops
~~~~~
In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.
Here we use these **4** keywords, ``range`` (aka ``serial``), ``unroll``, ``parallel``, and ``vectorize``,
to annotate the corresponding types of for loops.
The usage is roughly the same as Python's standard ``range``.
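For example (taken from the unit tests shipped with this patch):

.. code-block:: python

    @tvm.hybrid.script
    def looptype(a):
        for i in parallel(6):
            a[i] = i
        for j in vectorize(6):
            a[j] = j
        for k in unroll(6):
            a[k] = k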
Variables
~~~~~~~~~
All mutable variables will be lowered to arrays of size 1.
The first store of a variable is regarded as its declaration.
.. note::

    Unlike conventional Python, in hybrid script, the declared variable
    can only be used in the scope level where it is declared.

.. note::

    Currently, you can ONLY use basic-typed variables, i.e. the type of the
    variable should be either ``float32`` or ``int32``.
.. code-block:: python

    for i in range(5):
        s = 0 # declaration, this s will be a 1-array in lowered IR
        for j in range(5):
            s += a[i, j] # do something with s
        b[i] = s # you can still use s in this scope level
    a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
    b = (1, 2) # this has NOT been supported yet!
Attributes
~~~~~~~~~~
So far, ONLY tensors' ``shape`` attribute is supported! The ``shape`` attribute is essentially a
tuple, so you MUST access it as an array. Also, currently, only constant-indexed access is supported.
.. code-block:: python

    x = a.shape[2] # OK!
    for i in range(3):
        for j in a.shape[i]: # BAD! i is not a constant!
            # do something
Conditional Statement and Expression
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code-block:: python

    if condition:
        # do something

    a = b if condition else c
However, the ``True`` and ``False`` keywords are NOT supported yet.
Math Intrinsics
~~~~~~~~~~~~~~~
So far, these math intrinsics are supported: ``log``, ``exp``, ``sigmoid``,
``tanh``, ``power``, and ``popcount``.
No import is required; as mentioned in `Software Emulation`_, just use them!
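For example (adapted from the unit tests shipped with this patch; the real-valued intrinsics
operate on a float tensor, while ``popcount`` operates on integer values):

.. code-block:: python

    @tvm.hybrid.script
    def intrin_real(a):          # `a` is a 6-element float32 tensor
        a[0] = log(a[0])
        a[1] = exp(a[1])
        a[2] = sigmoid(a[2])
        a[3] = tanh(a[3])
        a[4] = power(a[4], a[5])

    @tvm.hybrid.script
    def intrin_int(a):           # `a` is an int32 tensor
        a[0] = popcount(a[0])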
Array Allocation
~~~~~~~~~~~~~~~~
**Under construction, this function will be supported later!**
Use a function call ``allocate(shape, type, share/local)`` to declare an array buffer.
The basic usage is roughly the same as for a normal array.
Thread Bind
~~~~~~~~~~~
You can also do loop-thread binding by writing code like this:
.. code-block:: python

    for tx in bind("threadIdx.x", 100):
        a[tx] = b[tx]
Keywords
~~~~~~~~
- ``for``-loop keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``
- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``
...@@ -2,3 +2,8 @@ Language Reference
==================
This document provide references to
embedded languages in TVM stack.
.. toctree::
:maxdepth: 2
hybrid_script
...@@ -332,12 +332,20 @@ def lower(sch,
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
    # Phase 0
    if isinstance(sch, schedule.Schedule):
        # normalize schedule first
        sch = sch.normalize()
        bounds = schedule.InferBound(sch)
        stmt = schedule.ScheduleOps(sch, bounds)
        stmt = ir_pass.InjectPrefetch(stmt)
    else:
        # So far there is no op for hybrid script, so a plain ir body is given
        if not isinstance(sch, _stmt.Stmt):
            raise ValueError("sch should be either a Schedule or a Stmt")
        stmt = sch
    for f in lower_phase0:
        stmt = f(stmt)
    # Phase 1
...
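# Note (based on the unit tests added in this patch): with the new branch above,
# `tvm.lower` also accepts a plain HalideIR statement produced by the hybrid
# frontend instead of a Schedule, for example:
#
#   ir = outer_product(n, m, a, b, c)                 # hybrid call on tvm types returns a Stmt
#   func = tvm.build(tvm.lower(ir, [n, m, a, b, c]))  # lower and build the plain IR body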
"""Hybrid Programming APIs of TVM Python Package.
This package maps a subset of Python to HalideIR so that:
1. Users can write preliminary versions of computation patterns that
have not been supported yet, and verify them across the real execution and
the Python semantic emulation.
2. Developers can build HalideIR by writing Python code.
"""
from .api import script, parse
"""APIs of lowering the Python subset to HalideIR"""
from __future__ import absolute_import as _abs
import types
import decorator
from .parser import parse_python
@decorator.decorator
def script(func, *args):
"""If the arguments are tvm types, compile it to HalideIR.
O.W. return the python emulated result"""
from .util import _enter_hybrid_runtime, _restore_runtime, _is_tvm_arg_types
if _is_tvm_arg_types(args):
return parse(func, args)
else:
intersect = _enter_hybrid_runtime(func)
func(*args)
_restore_runtime(func, intersect)
return func
def parse(func, args):
"""Parse a subset of Python to HalideIR
Parameters
----------
func : str or types.FunctionType
If it is a string, parse the source code
If it is a function, parse the function
args : list of Buffer or Tensor or Var
The argument lists to the function.
Leave it None if no buffer is related to the function to be parsed
Returns
-------
root : Stmt
The resulting HalideIR of the parsed function body.
"""
from .util import _pruned_source
if isinstance(func, str):
src = func
else:
assert isinstance(func, types.FunctionType)
src = _pruned_source(func)
return parse_python(src, args)
"""Intrinsics of TVM-Python Hybrid Script for Python runtime"""
import numpy
from ..stmt import For
class _range(object):
"""Base class of the loop ranges in hybrid script"""
def __init__(self, a, b=None):
if b is None:
self.low = 0
self.ext = a
else:
self.low = a
self.ext = b
def __iter__(self):
i = 0
while i < self.ext:
yield i + self.low
i += 1
class bind(_range): #pylint: disable=invalid-name
def __init__(self, tag, ext):
super(bind, self).__init__(ext)
self.tag = tag
unroll = vectorize = parallel = _range #pylint: disable=invalid-name
def allocate(shape, dtype='float32'):
"""Allocate a buffer with given shape
Parameters
----------
shape: Tuple
The shape of the tensor to be allocated
dtype: string
The data type of the tensor
Returns
-------
tensor: numpy.array
The tensor allocated
"""
return numpy.zeros(shape).astype(dtype)
def popcount(x):
"""
Count ones in the binary representation of number x
Parameters
----------
x: Integer
The number to be counted
Returns
-------
cnt: Integer
The number of ones in the binary representation of number x
"""
cnt = 0
while x:
x -= x & -x
cnt += 1
return cnt
def sigmoid(x):
"""
Sigmoid function of x, aka 1/(1+exp(-x)).
Parameters
----------
x: a real number
Returns
-------
res: a real number
The result of sigmoid function
"""
return 1 / (1 + numpy.exp(-x))
HYBRID_GLOBALS = {
'unroll' : unroll,
'vectorize' : vectorize,
'parallel' : parallel,
'allocate' : allocate,
'bind' : bind,
'sqrt' : numpy.sqrt,
'log' : numpy.log,
'tanh' : numpy.tanh,
'power' : numpy.power,
'exp' : numpy.exp,
'sigmoid' : sigmoid,
'popcount' : popcount
}
LOOP_INTRIN = {
'range' : For.Serial,
'unroll' : For.Unrolled,
'parallel' : For.Parallel,
'vectorize': For.Vectorized,
'bind' : None
}
MATH_INTRIN = ['sqrt', 'log', 'exp', 'tanh', 'sigmoid', 'power', 'popcount']
"""Internal utilities for parsing Python subset to HalideIR"""
import inspect
import numpy
from .intrin import HYBRID_GLOBALS
from .._ffi.base import numeric_types
from .. import api as _api
from .. import make as _make
from .. import expr as _expr
from ..tensor import Tensor
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
# Useful constants. To avoid runtime dependencies, we use function calls to return them.
def make_nop():
"""Returns a 'no operation' node in HalideIR."""
return _make.Evaluate(_api.const(0, dtype='int32'))
def make_range_one():
"""Returns a [0, 1] range node in HalideIR."""
return _make.range_by_min_extent(0, 1)
def make_const_true():
"""Returns a constant True node in HalideIR."""
return _api.convert(True)
def _pruned_source(func):
"""Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n')
leading_space = len(lines[0]) - len(lines[0].lstrip(' '))
lines = [line[leading_space:] for line in lines]
return '\n'.join(lines)
def _is_tvm_arg_types(args):
"""Determine a list of element is either a list of tvm arguments of a list of numpy arguments.
If neither is true, raise a value error."""
if isinstance(args[0], tvm_arg_types):
for elem in args[1:]:
if not isinstance(elem, tvm_arg_types):
raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem)))
return True
if not isinstance(args[0], np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(args[0])))
for elem in args[1:]:
if not isinstance(elem, np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(elem)))
return False
def _enter_hybrid_runtime(func):
"""Put hybrid runtime variables into the global scope"""
_globals = func.__globals__
intersect = []
for elem in list(HYBRID_GLOBALS.keys()):
if elem in _globals.keys():
intersect.append((elem, _globals[elem]))
_globals[elem] = HYBRID_GLOBALS[elem]
return intersect
def _restore_runtime(func, intersect):
"""Rollback the modification caused by hybrid runtime"""
_globals = func.__globals__
for elem in list(HYBRID_GLOBALS.keys()):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
"""Determines the declaration, r/w status, and last use of each variable"""
import ast
import sys
from .intrin import HYBRID_GLOBALS
class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def __init__(self, args):
self.status = {}
self.scope_level = []
self._args = {}
self.args = args
def visit_FunctionDef(self, node):
self.scope_level.append(node)
if len(node.args.args) != len(self.args):
raise ValueError('#arguments passed should be the same as #arguments defined')
for idx, arg in enumerate(node.args.args):
_attr = 'id' if sys.version_info[0] < 3 else 'arg' # To make py2 and 3 compatible
self._args[getattr(arg, _attr)] = self.args[idx]
for i in node.body:
self.visit(i)
def visit_For(self, node):
if not isinstance(node.target, ast.Name):
raise ValueError("For's iterator should be an id")
self.visit(node.iter)
self.scope_level.append(node)
for i in node.body:
self.visit(i)
self.scope_level.pop()
def visit_Call(self, node):
#No function pointer supported so far
if not isinstance(node.func, ast.Name):
raise ValueError("Function call should be an id")
if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range':
raise ValueError("Function call id not in intrinsics' list")
for elem in node.args:
self.visit(elem)
def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
return
fors = [loop.target.id for loop in self.scope_level if isinstance(loop, ast.For)]
if node.id in fors:
return
# The loop variable cannot be overwritten when iteration
if isinstance(node.ctx, ast.Store) and node.id in fors:
raise ValueError("Iter var cannot be overwritten")
if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
loop = self.scope_level[-1]
usage.add(type(node.ctx))
self.status[node.id] = (decl, loop, usage)
def determine_variable_usage(root, args):
"""The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args)
visitor.visit(root)
return visitor.status
import tvm, inspect, sys, traceback, numpy
from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS
@script
def outer_product(n, m, a, b, c):
for i in range(n):
for j in range(m):
c[i, j] = a[i] * b[j]
#Test global function
#Test bridge between frontend and backend
def test_outer_product():
n = tvm.var('n')
m = tvm.var('m')
a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((m, ), name='b')
c = tvm.placeholder((n, m), name='c')
ir = outer_product(n, m, a, b, c)
#Check for i in (0, n)
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert ir.extent.name == 'n'
ibody = ir.body
assert isinstance(ibody, tvm.stmt.For)
#Check for j in (0, m)
assert ibody.loop_var.name == 'j'
assert ibody.min.value == 0
assert ibody.extent.name == 'm'
#Check loop body
jbody = ibody.body
assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c'
assert len(jbody.args) == 2
assert jbody.args[0].name == 'i'
assert jbody.args[1].name == 'j'
assert isinstance(jbody.value, tvm.expr.Mul)
mul = jbody.value
assert isinstance(mul.a, tvm.expr.Call)
assert mul.a.name == 'a'
assert mul.b.name == 'b'
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)
_n = 999
_m = 1001
_a = numpy.random.rand(_n).astype('float32')
_b = numpy.random.rand(_m).astype('float32')
c_python = numpy.zeros((_n, _m), dtype='float32')
outer_product(_n, _m, _a, _b, c_python)
tvm_a = tvm.ndarray.array(_a)
tvm_b = tvm.ndarray.array(_b)
tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32'))
func(_n, _m, tvm_a, tvm_b, tvm_c)
numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5)
for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
assert key not in outer_product.__globals__.keys()
#Test local function
#Test allocation of local variable
def test_fanout():
@script
def fanout(n, a, b):
three = 3.0
for i in range(a.shape[0] - 3):
sigma = 0.0
for j in range(3):
sigma = sigma + a[i + j]
sigma = sigma / three
b[i] = sigma
n = tvm.var('n')
a = tvm.placeholder((n, ), name='a')
b = tvm.placeholder((n-3, ), name='b')
ir = fanout(n, a, b)
#Check for i in (0, n-3)
assert isinstance(ir, tvm.stmt.For)
assert ir.loop_var.name == 'i'
assert ir.min.value == 0
assert tvm.ir_pass.Equal(ir.extent, n - 3)
#Check loopbody
ibody = ir.body
assert isinstance(ibody, tvm.stmt.Realize)
assert ibody.bounds[0].min.value == 0
assert ibody.bounds[0].extent.value == 1
assert ibody.func.name == 'sigma'
#Check i loop body
rbody = ibody.body
assert isinstance(rbody.first, tvm.stmt.Provide)
assert rbody.first.func.name == 'sigma'
assert len(rbody.first.args) == 1
assert rbody.first.args[0].value == 0
#Check fanout loop
jloop = rbody.rest.first
assert jloop.loop_var.name == 'j'
assert jloop.min.value == 0
assert jloop.extent.value == 3
jbody = jloop.body
assert isinstance(jbody, tvm.stmt.Provide)
assert len(jbody.args) == 1
assert jbody.args[0].value == 0
assert jbody.func.name == 'sigma'
assert isinstance(jbody.value, tvm.expr.Add)
value = jbody.value
assert isinstance(value.a, tvm.expr.Call)
assert value.a.name == 'sigma'
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
assert value.b.name == 'a'
assert len(value.b.args) == 1
assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
divide = rbody.rest.rest.first
assert isinstance(divide, tvm.stmt.Provide)
assert len(divide.args) == 1
assert divide.args[0].value == 0
value = divide.value
assert isinstance(value, tvm.expr.Mul)
assert value.a.name == 'sigma'
assert len(value.a.args) == 1
assert value.a.args[0].value == 0
assert abs(value.b.value - (1 / 3.0)) < 1e-5
write = rbody.rest.rest.rest
assert isinstance(write, tvm.stmt.Provide)
assert write.func.name == 'b'
assert write.value.name == 'sigma'
assert len(write.value.args) == 1
assert write.value.args[0].value == 0
@script
def failure():
for i in range(1, 100):
i = 0
def test_failure():
try:
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Python2 cannot do the failure case because "%s"' % str(err))
except Exception as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'
def test_looptype():
@script
def looptype(a):
for i in parallel(6):
a[i] = i
for j in vectorize(6):
a[j] = j
for k in unroll(6):
a[k] = k
a = tvm.placeholder((6, ), name='a')
ir = looptype(a)
iloop = ir.first
jloop = ir.rest.first
kloop = ir.rest.rest
assert iloop.for_type == tvm.stmt.For.Parallel
assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled
def test_if():
@script
def if_then_else(a, b):
for i in range(10):
if i % 2 == 0:
a[i] = -1
else:
a[i] = 1
for i in unroll(10):
b[i] = -1 if i % 2 == 0 else 1
a = tvm.placeholder((10, ), dtype='int32', name='a')
b = tvm.placeholder((10, ), dtype='int32', name='b')
ir = if_then_else(a, b)
func = tvm.lower(ir, [a, b])
func = tvm.build(func)
assert func
_a = numpy.zeros((10, ), dtype = 'int32')
_b = numpy.zeros((10, ), dtype = 'int32')
if_then_else(_a, _b)
tvm_a = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32'))
tvm_b = tvm.ndarray.array(numpy.zeros((10, ), dtype='int32'))
func(tvm_a, tvm_b)
numpy.testing.assert_allclose(tvm_a.asnumpy(), _a, rtol=1e-5)
numpy.testing.assert_allclose(tvm_b.asnumpy(), _b, rtol=1e-5)
numpy.testing.assert_allclose(tvm_a.asnumpy(), tvm_b.asnumpy(), rtol=1e-5)
def test_bind():
if not tvm.gpu(0).exist:
print('No GPU found! Skip this test!')
return
@script
def vec_add(a, b, c):
for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx]
a = tvm.placeholder((1000, ), dtype='float32', name='a')
b = tvm.placeholder((1000, ), dtype='float32', name='b')
c = tvm.placeholder((1000, ), dtype='float32', name='c')
ir = vec_add(a, b, c)
func = tvm.lower(ir, [a, b, c])
func = tvm.build(func, target = 'cuda')
_a = numpy.random.rand(1000).astype('float32')
_b = numpy.random.rand(1000).astype('float32')
_c = numpy.zeros((1000, ), dtype = 'float32')
tvm_a = tvm.ndarray.array(_a, tvm.gpu(0))
tvm_b = tvm.ndarray.array(_b, tvm.gpu(0))
tvm_c = tvm.ndarray.array(_c, tvm.gpu(0))
func(tvm_a, tvm_b, tvm_c)
vec_add(_a, _b, _c)
numpy.testing.assert_allclose(_c, tvm_c.asnumpy(), rtol=1e-5)
def test_math_intrin():
@script
def intrin_real(a):
a[0] = sqrt(a[0])
a[1] = log(a[1])
a[2] = exp(a[2])
a[3] = sigmoid(a[3])
a[4] = power(a[4], a[5])
a[5] = tanh(a[5])
a6 = tvm.placeholder((6, ), dtype='float32', name='a')
ir = intrin_real(a6)
func = tvm.build(tvm.lower(ir, [a6]))
assert func
a = numpy.arange(2, 8).astype('float32')
tvm_a = tvm.ndarray.array(a)
func(tvm_a)
intrin_real(a)
numpy.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5)
@script
def intrin_int(a):
a[0] = popcount(a[0])
a1 = tvm.placeholder((1, ), dtype='int32')
ir = intrin_int(a1)
func = tvm.build(tvm.lower(ir, [a1]))
assert func
a = numpy.array([1234567890]).astype('int32')
tvm_a = tvm.ndarray.array(a)
intrin_int(a)
func(tvm_a)
assert tvm_a.asnumpy()[0] == a[0]
def test_allocate_buffer():
def blur(a):
for i in serial(32):
h_blur = allocate((4, 36))
for j in serial(4):
for k in serial(36):
s = allocate((1, ), 'float32')
for dj in serial(4):
s[0] = s[0] + a[i, j + dj]
h_blur[j, k] = s[0] / 4.
for j in serial(32):
s = 0.
for di in serial(4):
s = s + h_blur[di, j]
h_blur[i, j] = s / 4.
if __name__ == "__main__":
test_outer_product()
test_fanout()
test_failure()
test_looptype()
test_if()
test_bind()
test_math_intrin()