Commit 1c04f389 by Tianqi Chen Committed by GitHub

[DEV/IR] Python IRBuilder (#102)

parent d17b10f0
Subproject commit d024efd80694556c1239c4435c5b3e70853a4896
Subproject commit 398edacd956c6de82185821ffd9f482598182e51
......@@ -35,9 +35,6 @@ tvm.stmt
tvm.ir_pass
~~~~~~~~~~~
.. automodule:: tvm.ir_pass
:members:
.. autosummary::
tvm.ir_pass.Inline
......@@ -58,6 +55,13 @@ tvm.ir_pass
tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.NarrowChannelAccess
.. automodule:: tvm.ir_pass
:members:
tvm.ir_builder
~~~~~~~~~~~~~~
.. automodule:: tvm.ir_builder
:members:
tvm.make
~~~~~~~~
......
......@@ -13,6 +13,7 @@ from . import collections
from . import schedule
from . import module
from . import node
from . import ir_builder
from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, cl, vpi
......
......@@ -13,7 +13,7 @@ from .._base import c_str, py_str, string_types
from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from ._node import NodeBase, SliceBase, convert_to_node
from ._node import NodeBase, NodeGeneric, convert_to_node
from ._ndarray import NDArrayBase
FunctionHandle = ctypes.c_void_p
......@@ -114,7 +114,7 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, SliceBase)):
elif isinstance(arg, (list, tuple, dict, NodeGeneric)):
arg = convert_to_node(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
......
......@@ -41,9 +41,12 @@ C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)
class SliceBase(object):
"""base class of slice object"""
pass
class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()
class NodeBase(object):
"""NodeBase is the base class of all TVM language AST object."""
......@@ -176,8 +179,8 @@ def convert_to_node(value):
vlist.append(it[0])
vlist.append(convert_to_node(it[1]))
return _api_internal._Map(*vlist)
elif isinstance(value, SliceBase):
return value.tensor(*value.indices)
elif isinstance(value, NodeGeneric):
return value.asnode()
else:
raise ValueError("don't know how to convert type %s to node" % type(value))
......
"""Addon utilities to python"""
"""Addon utilities to TVM python package.
These features are useful to have not not essential to TVM.
"""
"""Information about nnvm."""
"""Verilog simulator modules."""
from __future__ import absolute_import
import subprocess
......
......@@ -50,6 +50,9 @@ class ExprOp(object):
def __rtruediv__(self, other):
return self.__rdiv__(other)
def __mod__(self, other):
return _make.Mod(self, other)
def __neg__(self):
return self.__mul__(-1)
......@@ -118,10 +121,6 @@ class Cast(Expr):
pass
@register_node
class Variable(Expr):
pass
@register_node
class Add(BinaryOpExpr):
pass
......
......@@ -57,7 +57,6 @@ def call_pure_extern(dtype, func_name, *args):
return _make.Call(
dtype, func_name, convert(args), _Call.PureExtern, None, 0)
def exp(x):
"""Take exponetial of input x.
......
"""Developer API of IR node builder make function."""
from __future__ import absolute_import as _abs
from . import api as _api
from . import stmt as _stmt
from . import expr as _expr
from . import make as _make
from . import ir_pass as _pass
from . import collections as _collections
from ._base import string_types
from ._ctypes._node import NodeGeneric
class WithScope(object):
"""Auxiliary scope with"""
def __init__(self, enter_value, exit_cb):
self._enter_value = enter_value
self._exit_cb = exit_cb
def __enter__(self):
return self._enter_value
def __exit__(self, ptype, value, trace):
self._exit_cb()
class BufferVar(NodeGeneric):
"""Buffer variable with content type, makes load store easily.
Do not create it directly, create use IRBuilder.
Examples
--------
In the follow example, x is BufferVar.
:code:`x[0] = ...` directly emit a store to the IRBuilder,
:code:`x[10]` translates to Load.
.. code-block:: python
# The following code generate IR for x[0] = x[
ib = tvm.ir_builder.create()
x = ib.pointer("float32")
x[0] = x[10] + 1
See Also
--------
IRBuilder.pointer
IRBuilder.buffer_ptr
IRBuilder.allocate
"""
def __init__(self, builder, buffer_var, content_type):
self._builder = builder
self._buffer_var = buffer_var
self._content_type = content_type
def asnode(self):
return self._buffer_var
def __getitem__(self, index):
return _make.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value):
value = _api.convert(value)
if value.dtype != self._content_type:
raise ValueError(
"data type does not match content type %s vs %s" % (
value.dtype, self._content_type))
self._builder.emit(_make.Store(self._buffer_var, value, index))
class IRBuilder(object):
"""Auxiliary builder to build IR for testing and dev.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
n = tvm.var("n")
A = ib.allocate("float32", n, name="A")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2) == 0):
A[i] = A[i] + 1
# The result stmt.
stmt = ib.get()
"""
def __init__(self):
self._seq_stack = [[]]
def _pop_seq(self):
"""Pop sequence from stack"""
seq = self._seq_stack.pop()
if len(seq) == 0 or callable(seq[-1]):
seq.append(_make.Evaluate(0))
stmt = seq[-1]
for s in reversed(seq[:-1]):
if callable(s):
stmt = s(stmt)
else:
assert isinstance(s, _stmt.Stmt)
stmt = _make.Block(s, stmt)
return stmt
def emit(self, stmt):
"""Emit a statement to the end of current scope.
Parameters
----------
stmt : Stmt or callable.
The statement to be emitted or callable that build stmt given body.
"""
if isinstance(stmt, _expr.Call):
stmt = _make.Evaluate(stmt)
assert isinstance(stmt, _stmt.Stmt) or callable(stmt)
self._seq_stack[-1].append(stmt)
def scope_attr(self, node, attr_key, value):
"""Create an AttrStmt at current scope.
Parameters
----------
attr_key : str
The key of the attribute type.
node : Node
The attribute node to annottate on.
value : Expr
Attribute value.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
x = ib.pointer("float32")
ib.scope_attr(x, "storage_scope", "global")
x[i] = x[i - 1] + 1
"""
if isinstance(node, string_types):
node = _make.StringImm(node)
if isinstance(value, string_types):
value = _make.StringImm(value)
self.emit(lambda x: _make.AttrStmt(node, attr_key, value, x))
def for_range(self, begin, end, name="i", dtype="int32"):
"""Create a for iteration scope.
Parameters
----------
begin : Expr
The min iteration scope.
end : Expr
The end iteration scope
name : str, optional
The name of iteration variable
dtype : str, optional
The data type of iteration variable.
Returns
-------
loop_scope : With.Scope of Var
The for scope, when enters returns loop_var
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
x = ib.pointer("float32")
with ib.for_range(1, 10, name="i") as i:
x[i] = x[i - 1] + 1
"""
self._seq_stack.append([])
loop_var = _api.var(name, dtype=dtype)
extent = end if begin == 0 else _pass.Simplify(end - begin)
def _exit_cb():
self.emit(_make.For(
loop_var, begin, extent, 0, 0, self._pop_seq()))
return WithScope(loop_var, _exit_cb)
def if_scope(self, cond):
"""Create an if scope.
Parameters
----------
cond : Expr
The condition.
Returns
-------
if_scope : WithScope
The result if scope.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
x = ib.pointer("float32")
with ib.if_scope((i % 2) == 0):
x[i] = x[i - 1] + 1
"""
self._seq_stack.append([])
def _exit_cb():
self.emit(_make.IfThenElse(cond, self._pop_seq(), None))
return WithScope(None, _exit_cb)
def else_scope(self):
"""Create an else scope.
This can only be used right after an if scope.
Returns
-------
else_scope : WithScope
The result else scope.
Examples
--------
.. code-block:: python
ib = tvm.ir_builder.create()
i = tvm.var("i")
x = ib.pointer("float32")
with ib.if_scope((i % 2) == 0):
x[i] = x[i - 1] + 1
with ib.else_scope():
x[i] = x[i - 1] + 2
"""
if len(self._seq_stack[-1]) == 0:
raise RuntimeError("else_scope can only follow an if_scope")
prev = self._seq_stack[-1][-1]
if not isinstance(prev, _stmt.IfThenElse) or prev.else_case:
raise RuntimeError("else_scope can only follow an if_scope")
self._seq_stack[-1].pop()
self._seq_stack.append([])
def _exit_cb():
self.emit(_make.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
return WithScope(None, _exit_cb)
def allocate(self, dtype, shape, name="buf", scope=None):
"""Create a allocate statement.
Parameters
----------
dtype : str
The content data type.
shape : tuple of Expr
The shape of array to be allocated.
name : str, optional
The name of the buffer.
scope : str, optional
The scope of the buffer.
Returns
-------
buffer : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _api.var(name, dtype="handle")
if not isinstance(shape, (list, tuple, _collections.Array)):
shape = [shape]
if scope:
self.scope_attr(buffer_var, "storage_scope", scope)
self.emit(lambda x: _make.Allocate(
buffer_var, dtype, shape, _api.const(1, dtype="uint1"), x))
return BufferVar(self, buffer_var, dtype)
def pointer(self, content_type, name="ptr"):
"""Create pointer variable with content type.
Parameters
----------
content_type : str
The content data type.
name : str, optional
The name of the pointer.
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
buffer_var = _api.var(name, dtype="handle")
return BufferVar(self, buffer_var, content_type)
def buffer_ptr(self, buf):
"""Create pointer variable corresponds to buffer ptr.
Parameters
----------
buf : Buffer
The buffer to be extracted.
Returns
-------
ptr : BufferVar
The buffer var representing the buffer.
"""
return BufferVar(self, buf.data, buf.dtype)
def get(self):
"""Return the builded IR.
Returns
-------
stmt : Stmt
The result statement.
"""
seq = self._pop_seq()
if len(self._seq_stack) != 0:
raise RuntimeError("cannot call get inside construction scope")
return seq
def create():
"""Create a new IRBuilder
Returns
-------
builder : IRBuilder
The created IRBuilder
"""
return IRBuilder()
......@@ -51,6 +51,10 @@ class Allocate(Stmt):
pass
@register_node
class AttrStmt(Stmt):
pass
@register_node
class Free(Stmt):
pass
......
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, SliceBase, register_node, convert_to_node
from ._ctypes._node import NodeBase, NodeGeneric, register_node, convert_to_node
from . import _api_internal
from . import make as _make
from . import expr as _expr
class TensorSlice(SliceBase, _expr.ExprOp):
class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def __init__(self, tensor, indices):
if not isinstance(indices, tuple):
......@@ -19,6 +19,10 @@ class TensorSlice(SliceBase, _expr.ExprOp):
indices = (indices,)
return TensorSlice(self.tensor, self.indices + indices)
def asnode(self):
"""Convert slice to node."""
return self.tensor(*self.indices)
@property
def dtype(self):
"""Data content of the tensor."""
......
......@@ -28,20 +28,19 @@ def test_stack_vm_basic():
def tvm_stack_vm_print(*x):
print(x)
def test_stack_vm_loop():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.Block(
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1),
tvm.make.Evaluate(tvm.call_packed("tvm_stack_vm_print", i))))
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1
ib.emit(tvm.call_packed("tvm_stack_vm_print", i))
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
......@@ -54,16 +53,16 @@ def test_stack_vm_cond():
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.IfThenElse(
tvm.make.EQ(i, 4),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1, i + 1),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
ib = tvm.ir_builder.create()
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n - 1, "i") as i:
with ib.if_scope(tvm.make.EQ(i, 4)):
A[i + 1] = A[i] + 1
with ib.else_scope():
A[i + 1] = A[i] + 2
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
......
import tvm
def test_for():
ib = tvm.ir_builder.create()
n = tvm.var("n")
A = ib.allocate("float32", n, name="A", scope="global")
with ib.for_range(0, n, name="i") as i:
A[i] = A[i] + 1
with ib.for_range(0, 10, name="j") as j:
A[j] = A[j] + 2
body = ib.get()
print(body)
assert isinstance(body, tvm.stmt.AttrStmt)
body = body.body
assert isinstance(body, tvm.stmt.Allocate)
body = body.body
assert isinstance(body, tvm.stmt.For)
body = body.body
assert isinstance(body, tvm.stmt.Block)
assert isinstance(body.rest, tvm.stmt.For)
def test_if():
ib = tvm.ir_builder.create()
n = tvm.var("n")
A = ib.pointer("float32", name="A")
with ib.for_range(0, n, name="i") as i:
with ib.if_scope((i % 2) == 0):
A[i] = A[i] + 1
with ib.else_scope():
A[0] = A[i] + 2
body = ib.get()
assert isinstance(body, tvm.stmt.For)
body = body.body
assert isinstance(body, tvm.stmt.IfThenElse)
assert isinstance(body.then_case.index, tvm.expr.Var)
assert body.else_case.index.value == 0
if __name__ == "__main__":
test_if()
test_for()
import tvm
def collect_visit(stmt, f):
ret = []
tvm.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
return ret
def test_basic():
n = tvm.var('n')
A = tvm.placeholder((n, ), name='A')
......@@ -16,22 +21,20 @@ def test_basic():
print(stmt)
def test_multi_loop():
i = tvm.var('i')
j = tvm.var('j')
k = tvm.var('k')
ib = tvm.ir_builder.create()
m = tvm.var('m')
n = tvm.var('n')
stmt = tvm.make.For(
i, 0, 4, 0, 0,
tvm.make.For(
j, 0, n, 0, 0,
tvm.make.For(
k, 0, m, 0, 0,
tvm.make.IfThenElse(
(i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n)))))
with ib.for_range(0, 4, "i") as i:
with ib.for_range(0, n, "j") as j:
with ib.for_range(0, m, "k") as k:
with ib.if_scope(i*m+j+k < n):
ib.emit(tvm.make.Evaluate(m))
with ib.else_scope():
ib.emit(tvm.make.Evaluate(n))
stmt = ib.get()
stmt = tvm.ir_pass.LoopPartition(stmt)
assert('if' not in str(stmt.body.first))
print(stmt)
assert(not any(collect_visit(stmt.body.first,
lambda x: isinstance(x, tvm.stmt.IfThenElse))))
def test_multi_if():
i = tvm.var('i')
......@@ -74,7 +77,7 @@ def test_thread_axis():
print(stmt_)
if __name__ == "__main__":
test_basic()
test_multi_loop()
test_basic()
test_multi_if()
test_thread_axis()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment