Commit 979623e5 by Tianqi Chen Committed by GitHub

[Tutorial] External Tensor Op (#137)

parent 553657eb
......@@ -16,3 +16,9 @@ tvm.contrib.util
~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.util
:members:
tvm.contrib.cblas
~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.cblas
:members:
......@@ -5,6 +5,9 @@ tvm.schedule
.. autoclass:: tvm.schedule.IterVar
:members:
.. autoclass:: tvm.schedule.Buffer
:members:
.. autofunction:: tvm.create_schedule
.. autoclass:: tvm.schedule.Schedule
......
......@@ -236,14 +236,13 @@ def scan(init, update, state_placeholder, inputs=None, name="scan"):
res = [op.output(i) for i in range(len(update))]
return res[0] if len(res) == 1 else res
def extern(shape, inputs, fcompute,
name="extern", dtype=None):
"""Compute several tensor via extern function.
Parameters
----------
shape: Shape tuple or list of shapes.
shape: tuple or list of tuples.
The shape of the outputs.
inputs: list of Tensor
......@@ -251,6 +250,17 @@ def extern(shape, inputs, fcompute,
fcompute: lambda function of inputs, outputs -> stmt
Specifies the IR statement that does the computation.
See the following note for the function signature of fcompute
.. note::
**Parameters**
- **ins** (list of :any:`Buffer`) - Placeholder for each input
- **outs** (list of :any:`Buffer`) - Placeholder for each output
**Returns**
- **stmt** (:any:`Stmt`) - The statement that carries out array computation.
name: str, optional
The name hint of the tensor
......@@ -263,9 +273,23 @@ def extern(shape, inputs, fcompute,
-------
tensor: Tensor or list of Tensors
The created tensor, or tuple of tensors if it contains multiple outputs.
Example
-------
In the code below, C is generated by calling the external PackedFunc
`tvm.contrib.cblas.matmul`
.. code-block:: python
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], 0, 0), name="C")
"""
if isinstance(shape[0], _expr.Expr):
shape = [shape]
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
input_placeholders = []
output_placeholders = []
types = set()
......@@ -305,6 +329,8 @@ def decl_buffer(shape, dtype=None,
Normally buffers are created automatically during lower and build.
This is only needed if users want to specify their own buffer layout.
See the note below for detailed discussion on usage of buffer.
Parameters
----------
shape : tuple of Expr
......@@ -332,8 +358,21 @@ def decl_buffer(shape, dtype=None,
-------
buffer : Buffer
The created buffer
Note
----
The Buffer data structure reflects the DLTensor structure in dlpack.
While the DLTensor data structure is very general, it is usually helpful
to create a function that only handles a specific case of the data
structure, so that the compiled function can benefit from it.
If strides and byte_offset are passed as None when constructing the
function, then the function will be specialized for DLTensors that are
compact and aligned.
If a fully generic symbolic array is passed as strides, then the
resulting function becomes fully generic (see the sketch below).
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if data is None:
......
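To make the note above concrete, here is a minimal sketch of declaring a
buffer with a symbolic stride, assuming :any:`tvm.build` accepts a ``binds``
dictionary the same way :any:`lower` does:

.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
    s = tvm.create_schedule(B.op)
    # A symbolic stride keeps the compiled function generic over
    # non-compact one-dimensional DLTensors bound to A.
    sA = tvm.var("sA")
    Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", strides=[sA])
    f = tvm.build(s, [A, B], "llvm", binds={A: Ab})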
......@@ -13,12 +13,47 @@ from . import collections
from . import module
from . import codegen
def get_binds(args, binds=None):
"""Internal function to get binds and arg_list given arguments.
Parameters
----------
args : list of Buffer or Tensor or Var
The argument list to the function.
binds : dict, optional
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
Returns
-------
binds: dict
The bind specification
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, expr.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
return binds, arg_list
def lower(sch,
args,
name="default_function",
binds=None,
with_api_wrapper=True,
simple_mode=False,
max_auto_unroll_step=0):
"""Lowering step before build into target.
......@@ -37,8 +72,9 @@ def lower(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
with_api_wrapper : bool, optional
Whether to add the API wrapper during lowering.
simple_mode : bool, optional
Whether to output only a simple and compact statement; this will skip
LoopPartition, API wrapper generation and unrolling (see the sketch below).
max_auto_unroll_step: int, optional
Maximum step to perform automatic unrolling
......@@ -49,33 +85,22 @@ def lower(sch,
The result function. If simple_mode=True,
the Stmt before MakeAPI is returned instead.
"""
binds = {} if binds is None else binds.copy()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
elif isinstance(x, schedule.Buffer):
arg_list.append(x)
elif isinstance(x, expr.Var):
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
binds, arg_list = get_binds(args, binds)
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.LoopPartition(stmt)
if not simple_mode:
stmt = ir_pass.LoopPartition(stmt)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.StorageRewrite(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
if not simple_mode:
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
if not with_api_wrapper:
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0)
......
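To illustrate the new ``simple_mode`` flag, a minimal sketch using a
trivial elementwise computation:

.. code-block:: python

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] * 2, name="B")
    s = tvm.create_schedule(B.op)
    # simple_mode=True skips LoopPartition and unrolling and returns the
    # compact Stmt before MakeAPI, which is handy for inspecting a schedule.
    print(tvm.lower(s, [A, B], simple_mode=True))
    # The default path runs the full pipeline and returns the MakeAPI result.
    f = tvm.lower(s, [A, B])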
......@@ -10,7 +10,18 @@ from ._ffi.function import _init_api
@register_node
class Buffer(NodeBase):
"""Represent a symbolic buffer in TVM."""
"""Symbolic data buffer in TVM.
Buffer provides a way to represent the data layout
specialization of a data structure in TVM.
Do not construct it directly; use :any:`decl_buffer` instead.
See the documentation of :any:`decl_buffer` for more details.
See Also
--------
decl_buffer : Declare a buffer
"""
pass
@register_node
......
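As the docstring says, a Buffer is obtained through :any:`decl_buffer`
rather than constructed directly; a minimal sketch:

.. code-block:: python

    import tvm

    # decl_buffer is the intended entry point; Buffer itself is opaque.
    buf = tvm.decl_buffer((1024,), dtype="float32", name="buf")
    assert isinstance(buf, tvm.schedule.Buffer)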
......@@ -5,6 +5,27 @@ def collect_visit(stmt, f):
tvm.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
return ret
def lower(sch, args):
binds = {}
arg_list = []
for x in args:
if isinstance(x, tvm.tensor.Tensor):
buf = tvm.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
sch = sch.normalize()
bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, bounds)
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
return stmt
def test_basic():
n = tvm.var('n')
A = tvm.placeholder((n, ), name='A')
......@@ -92,7 +113,7 @@ def test_vectorize():
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
s[C].vectorize(x)
stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
stmt = lower(s, [A, B])
body = stmt.body.body.body.body.body
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.expr.Ramp))))
......@@ -123,7 +144,7 @@ def test_thread_axis2():
_, x = s[C].split(x, factor=m)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
stmt = tvm.lower(s, [A, B], name='ewise_add', with_api_wrapper=False)
stmt = lower(s, [A, B])
for_body = stmt.body.body.body.body.body.first
assert('threadIdx' not in str(for_body.extent))
......
......@@ -59,7 +59,7 @@ def test_lstm_cell_inline():
s[forget_gate].compute_inline()
s[out_gate].compute_inline()
# verify we can lower correctly
tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c], with_api_wrapper=False)
tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
if __name__ == "__main__":
test_lstm_cell_inline()
......@@ -109,7 +109,7 @@ def test_scan_inline1():
[s_state1, s_state2])
s = tvm.create_schedule(res1.op)
s[s_x1].compute_inline()
stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
stmt = tvm.lower(s, [x, res1, res2])
def test_scan_inline2():
m = tvm.var("m")
......@@ -131,7 +131,7 @@ def test_scan_inline2():
s[s_xx].compute_inline()
s[s_x1].compute_inline()
s[s_x2].compute_inline()
stmt = tvm.lower(s, [x, res1, res2], with_api_wrapper=False)
stmt = tvm.lower(s, [x, res1, res2])
def test_schedule_cache():
......
"""
External Tensor Functions
=========================
**Author**: `Tianqi Chen <https://tqchen.github.io>`_
While TVM supports transparent code generation, sometimes
it is also helpful to incorporate hand-written code into
the pipeline. For example, we might want to use cuDNN for
some of the convolution kernels and define the rest of the stages.
TVM supports these black box function calls natively.
Specifically, TVM supports all tensor functions that are DLPack compatible,
which means we can call any function that takes POD types (pointer, int, float)
or a pointer to a DLTensor as arguments.
"""
from __future__ import absolute_import, print_function
import tvm
import numpy as np
from tvm.contrib import cblas
######################################################################
# Use Extern Tensor Function
# --------------------------
# In the example below, we use :any:`tvm.extern` to add an extern
# array function call. In the extern call, we declare the shape
# of output tensors. In the second argument we provide the list of inputs.
#
# The user needs to provide a function describing how to compute the result.
# The compute function takes a list of symbolic placeholders for the inputs
# and a list of symbolic placeholders for the outputs, and returns the
# executing statement.
#
# In this case we simply call a registered TVM function, which invokes a CBLAS call.
# TVM does not control the internals of the extern array function and treats it
# as a black box. We can further mix in schedulable TVM computations that add a
# bias term to the result.
#
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = tvm.extern((n, m), [A, B],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], False, False), name="C")
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
######################################################################
# Verify the Result
# -----------------
# We can verify that the result matches what we expected.
#
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
np.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 10)
######################################################################
# Extern Contrib Wrappers
# -----------------------
# TVM also provides extern contrib wrappers for useful extern calls;
# the following lines are equivalent to the previous example.
#
from tvm.contrib import cblas
C = cblas.matmul(A, B)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)
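# We can rebuild and verify that the wrapper matches the earlier result.
# (A quick sanity check sketch, reusing a, b, ctx, n and m from the
# verification section above.)
f = tvm.build(s, [A, B, D, bias], "llvm")
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
f(a, b, d, 10.0)
np.testing.assert_allclose(
    d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + 10)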
######################################################################
# Hook Python Function as Extern
# ------------------------------
# Since we can call into any PackedFunc in TVM, we can use the extern
# function to call back into Python.
#
# The following example registers a Python function into the TVM runtime
# system and uses it to complete one stage of the computation.
# This makes TVM much more flexible. For example, we can insert front-end
# callbacks to inspect the intermediate results or mix customized code
# with TVM.
#
@tvm.register_func("tvm.contrib.my_tvm_addone")
def my_tvm_addone(x, y):
print("my_tvm_addone signatures: %s, %s" % (type(x), type(y)))
tvm.nd.array(x.asnumpy() + 1).copyto(y)
A = tvm.placeholder((n,), name='A')
B = tvm.extern(A.shape, [A], lambda ins, outs: tvm.call_packed(
"tvm.contrib.my_tvm_addone", ins[0], outs[0]), name="C")
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")
a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)
######################################################################
# Summary
# -------
# - TVM calls extern tensor functions via :any:`tvm.extern`
# - Use contrib wrappers as convenient shorthands for extern tensor calls.
# - We can hook front-end functions in as extern tensor callbacks.
#
......@@ -50,7 +50,7 @@ B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
# Before doing anything, let us print out the IR code of default schedule.
#
s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# You can find that the IR code is quite like the C code.
......@@ -61,13 +61,13 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
#
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# If we are building a GPU kernel, we can bind the rows of B to GPU threads.
s[B.op].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B.op].bind(xi, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# Reduction Factoring and Parallelization
......@@ -84,7 +84,7 @@ print(tvm.lower(s, [A, B], with_api_wrapper=False))
s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# The scheduled operator of B also gets rewritten to be a sum over
......
......@@ -39,10 +39,10 @@ C = tvm.compute((m, n), lambda i, j: A[i, j] * B[i, j], name='C')
s = tvm.create_schedule([C.op])
# lower will transform the computation from definition to the real
# callable function. With argument `with_api_wrapper=False`, it will
# callable function. With argument `simple_mode=True`, it will
# return a readable C-like statement; we use it here to print the
# schedule result.
print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
print(tvm.lower(s, [A, B, C], simple_mode=True))
######################################################################
# One schedule is composed of multiple stages, and one
......@@ -59,7 +59,7 @@ B = tvm.compute((m,), lambda i: A[i]*2, name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# You can also split an axis by :code:`nparts`, which splits the axis
......@@ -69,7 +69,7 @@ B = tvm.compute((m,), lambda i: A[i], name='B')
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# tile
......@@ -81,7 +81,7 @@ B = tvm.compute((m, n), lambda i, j: A[i, j], name='B')
s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# fuse
......@@ -95,7 +95,7 @@ s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(yi, xi)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# reorder
......@@ -109,7 +109,7 @@ s = tvm.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# bind
......@@ -123,7 +123,7 @@ s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], with_api_wrapper=False))
print(tvm.lower(s, [A, B], simple_mode=True))
######################################################################
# compute_at
......@@ -135,7 +135,7 @@ B = tvm.compute((m,), lambda i: A[i]+1, name='B')
C = tvm.compute((m,), lambda i: B[i]*2, name='C')
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
print(tvm.lower(s, [A, B, C], simple_mode=True))
######################################################################
# :code:`compute_at` can move computation of `B` into the first axis
......@@ -146,7 +146,7 @@ C = tvm.compute((m,), lambda i: B[i]*2, name='C')
s = tvm.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
print(tvm.lower(s, [A, B, C], simple_mode=True))
######################################################################
# compute_inline
......@@ -160,7 +160,7 @@ C = tvm.compute((m,), lambda i: B[i]*2, name='C')
s = tvm.create_schedule(C.op)
s[B].compute_inline()
print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
print(tvm.lower(s, [A, B, C], simple_mode=True))
######################################################################
# compute_root
......@@ -173,7 +173,7 @@ C = tvm.compute((m,), lambda i: B[i]*2, name='C')
s = tvm.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
print(tvm.lower(s, [A, B, C], with_api_wrapper=False))
print(tvm.lower(s, [A, B, C], simple_mode=True))
######################################################################
# Summary
......