ir_builder.py 11.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
"""Developer API of IR node builder make function."""
18
from tvm._ffi.base import string_types
19
from tvm.runtime import ObjectGeneric, DataType, convert, const
20
from tvm.ir import container as _container
21 22 23 24

from . import stmt as _stmt
from . import expr as _expr
from . import ir_pass as _pass
25

26 27 28 29 30 31 32 33 34 35 36 37 38 39

class WithScope(object):
    """Auxiliary ``with`` scope.

    Entering the scope returns ``enter_value``; leaving it calls ``exit_cb``
    with no arguments. The callback runs whether or not the body raised,
    and any exception from the body is never suppressed.
    """

    def __init__(self, enter_value, exit_cb):
        self._exit_cb = exit_cb
        self._enter_value = enter_value

    def __enter__(self):
        entered = self._enter_value
        return entered

    def __exit__(self, ptype, value, trace):
        callback = self._exit_cb
        callback()
        # A falsy return value lets any in-flight exception propagate.
        return None


40
class BufferVar(ObjectGeneric):
    """Buffer variable with content type, makes load and store easy.

    Do not create it directly; create it through IRBuilder.

    Examples
    --------
    In the following example, x is a BufferVar.
    :code:`x[0] = ...` directly emits a store to the IRBuilder,
    :code:`x[10]` translates to Load.

    .. code-block:: python

        # The following code generates IR for x[0] = x[10] + 1
        ib = tvm.tir.ir_builder.create()
        x = ib.pointer("float32")
        x[0] = x[10] + 1

    See Also
    --------
    IRBuilder.pointer
    IRBuilder.buffer_ptr
    IRBuilder.allocate
    """
    def __init__(self, builder, buffer_var, content_type):
        self._builder = builder
        self._buffer_var = buffer_var
        self._content_type = content_type

    def asobject(self):
        """Return the underlying buffer variable this object stands for."""
        return self._buffer_var

    @property
    def dtype(self):
        """The content data type of the buffer."""
        return self._content_type

    def _lane_index(self, index):
        """Translate an element index into the index used by Load/Store.

        For a content type with ``lanes > 1`` this returns
        ``Ramp(index * lanes, 1, lanes)``; otherwise ``index`` is returned
        unchanged.
        """
        t = DataType(self._content_type)
        if t.lanes > 1:
            return _expr.Ramp(index * t.lanes, 1, t.lanes)
        return index

    def __getitem__(self, index):
        """Return a Load expression reading element ``index``."""
        return _expr.Load(self._content_type, self._buffer_var,
                          self._lane_index(index))

    def __setitem__(self, index, value):
        """Emit a Store of ``value`` to element ``index`` into the builder.

        Raises
        ------
        ValueError
            If the converted value's dtype differs from the content type.
        """
        value = convert(value)
        if value.dtype != self._content_type:
            raise ValueError(
                "data type does not match content type %s vs %s" % (
                    value.dtype, self._content_type))
        self._builder.emit(_stmt.Store(self._buffer_var, value,
                                       self._lane_index(index)))
92 93 94 95 96 97 98 99 100


class IRBuilder(object):
    """Auxiliary builder to build IR for testing and dev.

    The builder keeps a stack of statement sequences. Scope helpers such as
    :py:meth:`for_range`, :py:meth:`if_scope`, :py:meth:`else_scope` and
    :py:meth:`new_scope` push a fresh sequence on entry and, on exit, pop it
    and emit the statement that wraps it into the enclosing sequence.

    Examples
    --------
    .. code-block:: python

        ib = tvm.tir.ir_builder.create()
        n = te.var("n")
        A = ib.allocate("float32", n, name="A")
        with ib.for_range(0, n, name="i") as i:
            with ib.if_scope((i % 2) == 0):
                A[i] = A[i] + 1
        # The result stmt.
        stmt = ib.get()
    """

    # Loop kinds accepted by for_range, mapped to the integer code passed as
    # the for-type argument of _stmt.For.
    _FOR_TYPE_IDS = {"serial": 0, "parallel": 1, "vectorize": 2, "unroll": 3}

    def __init__(self):
        # Stack of open statement sequences; the innermost scope is last.
        # Entries are Stmt objects or callables that take a body Stmt.
        self._seq_stack = [[]]
        # Counter used by for_range to generate default loop variable names.
        self.nidx = 0

    def _pop_seq(self):
        """Pop the innermost sequence from the stack and fold it into one Stmt.

        Plain statements are kept in order. A callable entry (as emitted by
        ``scope_attr`` or ``allocate``) receives, as its body, everything that
        follows it in the same sequence. If the sequence is empty or ends
        with a callable, an ``Evaluate(0)`` is appended so there is always a
        body to wrap.

        Returns
        -------
        stmt : Stmt
            A single statement, or a SeqStmt when several remain.
        """
        seq = self._seq_stack.pop()
        if not seq or callable(seq[-1]):
            seq.append(_stmt.Evaluate(0))

        def seqwrap(stmts):
            # ``stmts`` is collected back-to-front, so restore program order.
            return stmts[0] if len(stmts) == 1 else _stmt.SeqStmt(list(reversed(stmts)))

        ret_seq = [seq[-1]]
        for s in reversed(seq[:-1]):
            if callable(s):
                ret_seq = [s(seqwrap(ret_seq))]
            else:
                assert isinstance(s, _stmt.Stmt)
                ret_seq.append(s)
        return seqwrap(ret_seq)

    def emit(self, stmt):
        """Emit a statement to the end of current scope.

        Parameters
        ----------
        stmt : Stmt, Call or callable.
           The statement to be emitted, or a callable that builds a stmt given
           its body. A Call expression is wrapped in an Evaluate statement.
        """
        if isinstance(stmt, _expr.Call):
            stmt = _stmt.Evaluate(stmt)
        assert isinstance(stmt, _stmt.Stmt) or callable(stmt)
        self._seq_stack[-1].append(stmt)

    def scope_attr(self, node, attr_key, value):
        """Create an AttrStmt at current scope.

        The attribute wraps every statement emitted after this call in the
        current scope.

        Parameters
        ----------
        node : Node or str
            The attribute node to annotate on. A str is wrapped in StringImm.

        attr_key : str
            The key of the attribute type.

        value : Expr or str
            Attribute value. A str is wrapped in StringImm.

        Examples
        --------
        .. code-block:: python

            ib = tvm.tir.ir_builder.create()
            i = te.var("i")
            x = ib.pointer("float32")
            ib.scope_attr(x, "storage_scope", "global")
            x[i] = x[i - 1] + 1
        """
        if isinstance(node, string_types):
            node = _expr.StringImm(node)
        if isinstance(value, string_types):
            value = _expr.StringImm(value)
        self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

    def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
        """Create a for iteration scope.

        Parameters
        ----------
        begin : Expr
            The first value of the loop variable.

        end : Expr
            The exclusive upper bound of the loop variable.

        name : str, optional
            The name of the iteration variable. When left as ``"i"`` a name
            is generated from the builder's counter: ``i``, ``j``, ``k`` for
            the first three such loops, then ``i_0``, ``i_1``, ...

        dtype : str, optional
            The data type of iteration variable.

        for_type : str, optional
            The special tag on the for loop: one of "serial", "parallel",
            "vectorize" or "unroll".

        Returns
        -------
        loop_scope : With.Scope of Var
            The for scope, when enters returns loop_var

        Raises
        ------
        ValueError
            If ``for_type`` is not one of the supported tags.

        Examples
        --------
        .. code-block:: python

            ib = tvm.tir.ir_builder.create()
            x = ib.pointer("float32")
            with ib.for_range(1, 10, name="i") as i:
                x[i] = x[i - 1] + 1
        """
        # Validate up front so an invalid tag fails before any builder state
        # (name counter, scope stack) changes and before the body is built.
        if for_type not in self._FOR_TYPE_IDS:
            raise ValueError("Unknown for_type: %s" % (for_type,))
        for_type_id = self._FOR_TYPE_IDS[for_type]
        if name == 'i':
            if self.nidx < 3:
                name = chr(ord(name) + self.nidx)
            else:
                name = name + "_" + str(self.nidx - 3)
            self.nidx += 1
        self._seq_stack.append([])
        loop_var = _expr.Var(name, dtype=dtype)
        extent = end if begin == 0 else _pass.Simplify(end - begin)

        def _exit_cb():
            self.emit(_stmt.For(
                loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
        return WithScope(loop_var, _exit_cb)

    def if_scope(self, cond):
        """Create an if scope.

        Parameters
        ----------
        cond : Expr
            The condition.

        Returns
        -------
        if_scope : WithScope
           The result if scope.

        Examples
        --------
        .. code-block:: python

            ib = tvm.tir.ir_builder.create()
            i = te.var("i")
            x = ib.pointer("float32")
            with ib.if_scope((i % 2) == 0):
                x[i] = x[i - 1] + 1
        """
        self._seq_stack.append([])

        def _exit_cb():
            self.emit(_stmt.IfThenElse(cond, self._pop_seq(), None))
        return WithScope(None, _exit_cb)

    def else_scope(self):
        """Create an else scope.

        This can only be used right after an if scope. The preceding
        IfThenElse (which must not already have an else branch) is replaced
        by one whose else branch is the body of this scope.

        Returns
        -------
        else_scope : WithScope
           The result else scope.

        Raises
        ------
        RuntimeError
            If the last emitted statement is not an IfThenElse without an
            else branch.

        Examples
        --------
        .. code-block:: python

            ib = tvm.tir.ir_builder.create()
            i = te.var("i")
            x = ib.pointer("float32")
            with ib.if_scope((i % 2) == 0):
                x[i] = x[i - 1] + 1
            with ib.else_scope():
                x[i] = x[i - 1] + 2
        """
        if not self._seq_stack[-1]:
            raise RuntimeError("else_scope can only follow an if_scope")
        prev = self._seq_stack[-1][-1]
        if not isinstance(prev, _stmt.IfThenElse) or prev.else_case:
            raise RuntimeError("else_scope can only follow an if_scope")
        self._seq_stack[-1].pop()
        self._seq_stack.append([])

        def _exit_cb():
            self.emit(_stmt.IfThenElse(prev.condition, prev.then_case, self._pop_seq()))
        return WithScope(None, _exit_cb)

    def new_scope(self):
        """Create new scope,

        this is useful to set boundary of attr and allocate.

        Returns
        -------
        new_scope : WithScope
           The result new scope.
        """
        self._seq_stack.append([])

        def _exit_cb():
            self.emit(self._pop_seq())
        return WithScope(None, _exit_cb)

    def allocate(self, dtype, shape, name="buf", scope=None):
        """Create an allocate statement.

        The allocation wraps every statement emitted after it in the current
        scope.

        Parameters
        ----------
        dtype : str
            The content data type.

        shape : Expr, or list/tuple/Array of Expr
            The shape of array to be allocated. A single Expr is treated as a
            one-dimensional shape.

        name : str, optional
            The name of the buffer.

        scope : str, optional
            The storage scope of the buffer. When given, a "storage_scope"
            attribute is emitted around the allocation.

        Returns
        -------
        buffer : BufferVar
            The buffer var representing the buffer.
        """
        buffer_var = _expr.Var(name, dtype="handle")
        if not isinstance(shape, (list, tuple, _container.Array)):
            shape = [shape]
        if scope:
            self.scope_attr(buffer_var, "storage_scope", scope)
        self.emit(lambda x: _stmt.Allocate(
            buffer_var, dtype, shape, const(1, dtype="uint1"), x))
        return BufferVar(self, buffer_var, dtype)

    def pointer(self, content_type, name="ptr"):
        """Create pointer variable with content type.

        Parameters
        ----------
        content_type : str
            The content data type.

        name : str, optional
            The name of the pointer.

        Returns
        -------
        ptr : BufferVar
            The buffer var representing the buffer.
        """
        buffer_var = _expr.Var(name, dtype="handle")
        return BufferVar(self, buffer_var, content_type)

    def buffer_ptr(self, buf):
        """Create pointer variable corresponds to buffer ptr.

        Parameters
        ----------
        buf : Buffer
            The buffer to be extracted.

        Returns
        -------
        ptr : BufferVar
            The buffer var representing the buffer.
        """
        return BufferVar(self, buf.data, buf.dtype)

    def likely(self, expr):
        """Add likely tag for expression.

        Parameters
        ----------
        expr : Expr
            The expression. Usually a condition expression.

        Returns
        -------
        expr : Expr
            The expression wrapped in a "likely" intrinsic call.
        """
        return _expr.Call(expr.dtype, "likely", [expr],
                          _expr.Call.PureIntrinsic, None, 0)

    def get(self):
        """Return the built IR.

        Returns
        -------
        stmt : Stmt
           The result statement.

        Raises
        ------
        RuntimeError
            If called while a construction scope is still open.
        """
        # Check before popping so misuse inside an open scope does not
        # discard that scope's statements.
        if len(self._seq_stack) > 1:
            raise RuntimeError("cannot call get inside construction scope")
        return self._pop_seq()


def create():
    """Create a new IRBuilder.

    Returns
    -------
    builder : IRBuilder
        A fresh builder whose only open scope is the empty top-level one.
    """
    builder = IRBuilder()
    return builder