stmt.py 9.19 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23 24 25 26 27
"""Statement AST Node in TVM.

Users do not need to deal with AST nodes directly,
but they can be helpful for developers doing quick prototyping.
Although the fields are not listed in the documentation or in this Python file,
each statement node has subfields that can be visited from the Python side.

.. code-block:: python

    x = tvm.var("n")
    a = tvm.var("array", tvm.handle)
    st = tvm.tir.stmt.Store(a, x + 1, 1)
    assert isinstance(st, tvm.tir.stmt.Store)
    assert(st.buffer_var == a)
"""
32
import tvm._ffi
33 34

from tvm.runtime import Object
35
from . import _ffi_api
36

tqchen committed
37

38
class Stmt(Object):
    """Common base class for every statement node defined in this module."""

tqchen committed
41

42
@tvm._ffi.register_object
class LetStmt(Stmt):
    """LetStmt node, binding a variable to a value for a body statement.

    Parameters
    ----------
    var : Var
        The variable being bound.

    value : PrimExpr
        The value to be bound to ``var``.

    body : Stmt
        The body statement.
    """
    def __init__(self, var, value, body):
        make = _ffi_api.LetStmt
        self.__init_handle_by_constructor__(make, var, value, body)
60

tqchen committed
61

62
@tvm._ffi.register_object
class AssertStmt(Stmt):
    """AssertStmt node.

    Parameters
    ----------
    condition : PrimExpr
        The condition being asserted.

    message : PrimExpr
        The error message associated with the assertion.

    body : Stmt
        The body statement.
    """
    def __init__(self, condition, message, body):
        ctor_args = (condition, message, body)
        self.__init_handle_by_constructor__(_ffi_api.AssertStmt, *ctor_args)
80

tqchen committed
81

82
@tvm._ffi.register_object
class ProducerConsumer(Stmt):
    """ProducerConsumer node.

    Parameters
    ----------
    func : Operation
        The operation this node refers to.

    is_producer : bool
        Whether the node is a producer.

    body : Stmt
        The body statement.
    """
    def __init__(self, func, is_producer, body):
        make = _ffi_api.ProducerConsumer
        self.__init_handle_by_constructor__(make, func, is_producer, body)
100

tqchen committed
101

102
@tvm._ffi.register_object
class For(Stmt):
    """For loop node.

    Parameters
    ----------
    loop_var : Var
        The loop variable.

    min_val : PrimExpr
        The beginning value of the loop variable.

    extent : PrimExpr
        The length of the loop.

    for_type : int
        The for type.

    device_api : int
        The device api type.

    body : Stmt
        The body statement.
    """
    # Integer loop-kind codes; presumably the values meant for ``for_type``.
    Serial, Parallel, Vectorized, Unrolled = 0, 1, 2, 3

    def __init__(self, loop_var, min_val, extent, for_type, device_api, body):
        fields = (loop_var, min_val, extent, for_type, device_api, body)
        self.__init_handle_by_constructor__(_ffi_api.For, *fields)

tqchen committed
141

142
@tvm._ffi.register_object
class Store(Stmt):
    """Store node.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.

    value : PrimExpr
        The value to store.

    index : PrimExpr
        The index in the store expression.

    predicate : PrimExpr, optional
        The store predicate. When None, no predicate argument is passed
        to the underlying constructor at all.
    """
    def __init__(self, buffer_var, value, index, predicate=None):
        fields = [buffer_var, value, index]
        if predicate is not None:
            fields.append(predicate)
        self.__init_handle_by_constructor__(_ffi_api.Store, *fields)
164

tqchen committed
165

166
@tvm._ffi.register_object
class Provide(Stmt):
    """Provide node.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    value : PrimExpr
        The value to be stored.

    args : list of Expr
        The index arguments of the Provide.
    """
    def __init__(self, func, value_index, value, args):
        fn = _ffi_api.Provide
        self.__init_handle_by_constructor__(fn, func, value_index, value, args)
187

tqchen committed
188

189
@tvm._ffi.register_object
class Allocate(Stmt):
    """Allocate node.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable.

    dtype : str
        The data type of the buffer.

    extents : list of Expr
        The extents of the allocation.

    condition : PrimExpr
        The condition.

    body : Stmt
        The body statement.
    """
    def __init__(self, buffer_var, dtype, extents, condition, body):
        handle_ctor = _ffi_api.Allocate
        self.__init_handle_by_constructor__(
            handle_ctor, buffer_var, dtype, extents, condition, body)

tqchen committed
220

221
@tvm._ffi.register_object
class AttrStmt(Stmt):
    """AttrStmt node.

    Parameters
    ----------
    node : Node
        The node to annotate with the attribute.

    attr_key : str
        The attribute type key.

    value : PrimExpr
        The value of the attribute.

    body : Stmt
        The body statement.
    """
    def __init__(self, node, attr_key, value, body):
        ctor_args = (node, attr_key, value, body)
        self.__init_handle_by_constructor__(_ffi_api.AttrStmt, *ctor_args)
242

243

244
@tvm._ffi.register_object
class Free(Stmt):
    """Free node.

    Parameters
    ----------
    buffer_var : Var
        The buffer variable to free.
    """
    def __init__(self, buffer_var):
        self.__init_handle_by_constructor__(_ffi_api.Free, buffer_var)
256

tqchen committed
257

258
@tvm._ffi.register_object
class Realize(Stmt):
    """Realize node.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    dtype : str
        The data type of the operation.

    bounds : list of Range
        The bounds of the realize.

    condition : PrimExpr
        The realize condition.

    body : Stmt
        The realize body.
    """
    def __init__(self, func, value_index, dtype, bounds, condition, body):
        fields = (func, value_index, dtype, bounds, condition, body)
        self.__init_handle_by_constructor__(_ffi_api.Realize, *fields)

tqchen committed
293

294
@tvm._ffi.register_object
class SeqStmt(Stmt):
    """Sequence of statements.

    Supports ``len()`` and integer indexing, both delegated to the
    underlying ``seq`` field.

    Parameters
    ----------
    seq : List[Stmt]
        The statements.
    """
    def __init__(self, seq):
        make = _ffi_api.SeqStmt
        self.__init_handle_by_constructor__(make, seq)

    def __len__(self):
        statements = self.seq
        return len(statements)

    def __getitem__(self, i):
        statements = self.seq
        return statements[i]
312

tqchen committed
313

314
@tvm._ffi.register_object
class IfThenElse(Stmt):
    """IfThenElse node.

    Parameters
    ----------
    condition : PrimExpr
        The branch condition.

    then_case : Stmt
        The statement to execute if the condition is true.

    else_case : Stmt
        The statement to execute if the condition is false.
    """
    def __init__(self, condition, then_case, else_case):
        branches = (then_case, else_case)
        self.__init_handle_by_constructor__(
            _ffi_api.IfThenElse, condition, *branches)
332

tqchen committed
333

334
@tvm._ffi.register_object
class Evaluate(Stmt):
    """Evaluate node, which wraps an expression as a statement.

    Parameters
    ----------
    value : PrimExpr
        The expression to be evaluated.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(_ffi_api.Evaluate, value)
346

347

348
@tvm._ffi.register_object
class Prefetch(Stmt):
    """Prefetch node.

    Parameters
    ----------
    func : Operation
        The operation to create the function.

    value_index : int
        The output value index.

    dtype : str
        The data type to be prefetched.

    bounds : list of Range
        The bounds to be prefetched.
    """
    def __init__(self, func, value_index, dtype, bounds):
        fn = _ffi_api.Prefetch
        self.__init_handle_by_constructor__(fn, func, value_index, dtype, bounds)
369 370


371 372 373 374 375 376 377 378
@tvm._ffi.register_object
class LoweredFunc(Object):
    """Python-side handle for a LoweredFunc object in TVM."""
    # Integer function-kind codes: mixed, host and device.
    MixedFunc, HostFunc, DeviceFunc = 0, 1, 2


379 380 381 382 383 384 385 386 387 388 389 390 391
def stmt_seq(*args):
    """Combine the given items into one statement.

    Parameters
    ----------
    args : list of Stmt or PrimExpr
        Items to be combined, in order. An item that is not already a
        ``Stmt`` is wrapped in an ``Evaluate`` node.

    Returns
    -------
    stmt : Stmt
        The sole statement when exactly one item is given. Otherwise a
        ``SeqStmt`` holding all of them, which includes the case of no items.
    """
    stmts = [item if isinstance(item, Stmt) else Evaluate(item) for item in args]
    return stmts[0] if len(stmts) == 1 else SeqStmt(stmts)
400 401 402 403 404 405 406 407 408 409 410 411 412 413


def stmt_list(stmt):
    """Flatten a (possibly nested) statement into a flat list.

    ``SeqStmt`` nodes are expanded recursively in order. ``ProducerConsumer``
    nodes are replaced by their flattened ``body``. Any other statement is
    returned as a single-element list.

    Parameters
    ----------
    stmt : Stmt
        The statement to unpack.

    Returns
    -------
    stmt_list : list of Stmt
        The unpacked list of statements.
    """
    if isinstance(stmt, SeqStmt):
        # Iterating a SeqStmt relies on its __getitem__ (sequence protocol).
        return [leaf for child in stmt for leaf in stmt_list(child)]
    if isinstance(stmt, ProducerConsumer):
        return stmt_list(stmt.body)
    return [stmt]