ir_pass.h 12.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

tqchen committed
20
/*!
21
 * \file tvm/tir/ir_pass.h
tqchen committed
22 23
 * \brief Collection of IR pass functions
 *
24 25
 *  When the pass functions in this file are for Stmt,
 *  we can use PassFunction(Evaluate(expr)) to apply it to Expr
tqchen committed
26
 */
27 28
#ifndef TVM_TIR_IR_PASS_H_
#define TVM_TIR_IR_PASS_H_
tqchen committed
29

30
#include <tvm/te/schedule.h>
31 32
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
33
#include <tvm/tir/function.h>
34

tqchen committed
35
#include <unordered_map>
36
#include <unordered_set>
tqchen committed
37
#include <vector>
38
#include <string>
39

tqchen committed
40 41

namespace tvm {
42
namespace tir {
tqchen committed
43

44 45 46 47 48 49
/*!
 * \brief Simplify the expression.
 * \param expr The expression to be simplified.
 * \param vrange The range information about the variable.
 * \return Canonicalized expression.
 */
50
TVM_DLL PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange = Map<Var, Range>());
51

52 53 54 55 56 57 58
/*!
 * \brief Simplify the statement.
 * \param stmt The statement to be simplified.
 * \param vrange The range information about the variable.
 *        Defaults to a default-constructed Map<Var, Range> when omitted.
 * \return Canonicalized statement.
 */
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
59 60

/*!
61 62
 * \brief Simplify by applying canonical form.
 * \param stmt The statement to be canonically simplified.
63
 * \param vrange The range information about the variable.
64 65
 * \return Canonicalized statement.
 */
66 67
Stmt CanonicalSimplify(Stmt stmt,
                       Map<Var, Range> vrange = Map<Var, Range>());
68 69 70 71

/*!
 * \brief Simplify by applying canonical form.
 * \param expr The expression to be canonically simplified.
72
 * \param vrange The range information about the variable.
73 74
 * \return Canonicalized expression.
 */
75
TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr,
76
                                   Map<Var, Range> vrange = Map<Var, Range>());
77 78

/*!
tqchen committed
79 80 81 82 83 84
 * \brief Verify whether the IR stmt or Expr is in SSA form.
 *  That is: each VarExpr is defined and assigned once (in Let/For).
 *
 * \param ir The root of the IR DAG.
 * \return Whether IR is in SSA form.
 * \note All the passes in this file use SSA form and output SSA form.
tqchen committed
85
 */
86
TVM_DLL bool VerifySSA(const Stmt& ir);
tqchen committed
87 88

/*!
89 90 91
 * \brief Whether the expression has side effects.
 * \param e The expression to be checked.
 * \return Whether the expression has side effects.
 */
92
TVM_DLL bool HasSideEffect(const PrimExpr& e);
93 94

/*!
95 96 97 98 99
 * \brief Whether expression e uses variable v.
 * \param e The expression to be checked.
 * \param v The variable.
 * \return Whether e uses v.
 */
100
bool ExprUseVar(const PrimExpr& e, const Var& v);
101 102

/*!
103 104 105 106 107
 * \brief Whether expression e uses any variable in the variable set.
 * \param e The expression to be checked.
 * \param vset The variable set.
 * \return Whether e uses any variable in vset.
 */
108
bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vset);
109 110

/*!
tqchen committed
111 112 113 114
 * \brief Convert an IR node to SSA form.
 * \param stmt The source statement to be converted.
 * \return The converted form.
 */
115
TVM_DLL Stmt ConvertSSA(Stmt stmt);
tqchen committed
116 117

/*!
118 119 120 121 122
 * \brief Substitute each variable that is a key of value_map with its mapped value.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
123
Stmt Substitute(Stmt stmt,
124
                const std::unordered_map<const VarNode*, PrimExpr>& value_map);
125 126 127 128 129 130 131

/*!
 * \brief Substitute each variable that is a key of value_map with its mapped value.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
132 133
PrimExpr Substitute(PrimExpr expr,
                const std::unordered_map<const VarNode*, PrimExpr>& value_map);
134 135 136 137 138 139 140

/*!
 * \brief Substitute each variable that is a key of value_map with its mapped value.
 * \param stmt The source statement to be substituted
 * \param value_map The map of new values.
 * \return The converted form.
 */
141
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);
142 143 144 145 146 147 148

/*!
 * \brief Substitute each variable that is a key of value_map with its mapped value.
 * \param expr The source expression to be substituted
 * \param value_map The map of new values.
 * \return The converted expression.
 */
149
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);
150 151

/*!
tqchen committed
152 153
 * \brief Inline all calls of f in stmt.
 *
154
 * \param stmt The statement to apply inline optimization.
tqchen committed
155 156
 * \param f The function reference to be inlined
 * \param args The argument variables of the function.
157
 * \param body The definition body of the function.
tqchen committed
158 159 160 161
 * \return The result stmt
 *
 * \note All the passes in this file use SSA form and output SSA form.
 */
162 163
Stmt Inline(Stmt stmt,
            FunctionRef f,
tqchen committed
164
            Array<Var> args,
165
            PrimExpr body);
166 167 168 169 170 171 172 173

/*!
 * \brief Flatten the multi-dimensional read/write
 *  to single dimensional Load/Store
 *
 * \param stmt The stmt to be transformed.
 * \param extern_buffer Map specifies external
 *    buffer assignment of input and outputs.
174
 * \param cache_line_size The size of CPU cache line.
175
 * \param create_bound_attribute Whether to create bound attributes.
176
 * \return Transformed stmt.
177 178
 */
Stmt StorageFlatten(Stmt stmt,
179
                    Map<te::Tensor, Buffer> extern_buffer,
180 181
                    int cache_line_size,
                    bool create_bound_attribute = false);
182 183 184 185 186 187 188 189 190 191 192

/*!
 * \brief Try to modify the AST to support TensorCore
 *
 * \param stmt The stmt to be transformed.
 * \param schedule The original schedule.
 * \param extern_buffer Map specifies external
 *    buffer assignment of input and outputs.
 * \return Transformed stmt.
 */
Stmt RewriteForTensorCore(Stmt stmt,
193 194
                          te::Schedule schedule,
                          Map<te::Tensor, Buffer> extern_buffer);
195

196 197 198 199 200 201 202 203
/*!
 * \brief Verify whether any argument is bound to a compact buffer.
 *
 * \param stmt The stmt to be verified.
 * \return true if any buffer_bind_scope attribute is found,
 *        otherwise false.
 */
bool VerifyCompactBuffer(Stmt stmt);
tqchen committed
204

205
/*!
Tianqi Chen committed
206 207 208 209 210 211 212
 * \brief Remove No Op from the Stmt.
 * \param stmt The stmt to be transformed.
 * \return Transformed stmt.
 */
Stmt RemoveNoOp(Stmt stmt);

/*!
213 214 215
 * \brief Unroll the constant loops marked by unroll.
 * This pass also automatically attaches the pragma unroll tag to loops that meet the standard.
 *
216
 * \param stmt The statement to be unrolled.
217
 * \param auto_max_step The maximum step before we stop attaching automatic unroll
218
 * \param auto_max_depth The maximum depth before we stop attaching automatic unroll
219
 * \param auto_max_extent The maximum extent of the loop we can unroll,
Siju committed
220
 *                     this is a legacy option that does not take the loop's total steps into account.
221
 * \param explicit_unroll Whether to explicitly unroll the loop, or leave the unroll annotation to codegen.
222
 * \return Transformed stmt.
223
 */
224 225
Stmt UnrollLoop(Stmt stmt,
                int auto_max_step,
226
                int auto_max_depth,
227 228
                int auto_max_extent,
                bool explicit_unroll);
229 230

/*!
231
 * \brief vectorize the constant loops
232
 * \param stmt The statement to be vectorized.
233
 * \return Transformed stmt.
234 235 236 237
 */
Stmt VectorizeLoop(Stmt stmt);

/*!
238 239 240 241 242 243 244
 * \brief convert vectorized loops into serialized loops
 * \param stmt The statement to skip vectorization on.
 * \return Transformed stmt.
 */
Stmt SkipVectorize(Stmt stmt);

/*!
245
* \brief instruments bound checkers.
246 247
* \param stmt The statement to be instrumented.
* \return Instrumented stmt.
248 249 250 251
*/
Stmt InstrumentBoundCheckers(Stmt stmt);

/*!
252
 * \brief Inject virtual thread loops into stmt.
253
 * \param stmt The statement to be transformed.
254 255 256 257 258
 * \return Transformed stmt.
 */
Stmt InjectVirtualThread(Stmt stmt);

/*!
259
 * \brief Inject prefetch instructions into stmt.
260
 * \param stmt The statement to be transformed.
261 262 263 264 265
 * \return Transformed stmt.
 */
Stmt InjectPrefetch(Stmt stmt);

/*!
266
 * \brief Inject double buffer into stmt.
267
 * \param stmt The statement to be transformed.
268
 * \param split_loop Loop splitting factor.
269 270
 * \return Transformed stmt.
 */
271
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
272 273

/*!
274 275
 * \brief Inject copy intrinsics with optional pad.
 *
276
 * \param stmt The statement to be transformed.
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
 * \param pragma_key The pragma key for hint of copy.
 * \param fintrin The function with signature
 *
 *   Stmt fintrin(Buffer src,
 *                Buffer dst,
 *                Array<Expr> pad_before,
 *                Array<Expr> pad_after,
 *                Expr pad_value)
 * \return Transformed stmt.
 */
Stmt InjectCopyIntrin(Stmt stmt,
                      const std::string& pragma_key,
                      const runtime::PackedFunc& fintrin);

/*!
292 293 294 295
 * \brief Rewrite storage allocation pattern.
 *  Moves the allocation to outer most possible scope.
 *  Trying to share space between allocations to make
 *  a static allocation plan when possible.
296
 *
297
 * \param stmt The stmt to be transformed
298 299
 * \return Transformed stmt.
 */
300
Stmt StorageRewrite(Stmt stmt);
301 302

/*!
303 304
 * \brief partition loops in the stmt
 * \param stmt The stmt to do loop partition
305
 * \param split_const_loop flag to enable partition for const loop
306 307
 * \return Transformed stmt.
 */
308
Stmt LoopPartition(Stmt stmt, bool split_const_loop);
309 310

/*!
311 312
 * \brief Detect and insert sync points to co-processor.
 *
313
 * \param stmt The stmt to be transformed
314 315 316 317 318
 * \return Transformed stmt.
 */
Stmt CoProcSync(Stmt stmt);

/*!
319 320
 * \brief Lift common attrs with attr_key to outer scope.
 *
321
 * \param stmt The stmt to be transformed
322 323 324 325 326 327
 * \param attr_key The attribute key to be checked.
 * \return Transformed stmt.
 */
Stmt LiftAttrScope(Stmt stmt, std::string attr_key);

/*!
328
 * \brief Detect and rewrite unsafe select that contains memory access.
329
 * \param stmt The statement to be rewritten.
330 331 332 333 334
 * \return Transformed stmt.
 */
Stmt RewriteUnsafeSelect(Stmt stmt);

/*!
335 336 337
 * \brief Lower attached storage access information.
 * Do this pass after all storage access analysis finish.
 *
338
 * \param stmt The stmt to be transformed
339 340 341 342 343
 * \return Transformed stmt.
 */
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
344
 * \brief Decorate the stmt with a device scope, this is helpful for
345 346
 * hardware accelerator without thread blocks.
 *
347
 * \param stmt The stmt to be transformed
348 349 350 351 352
 * \return Transformed stmt.
 */
Stmt DecorateDeviceScope(Stmt stmt);

/*!
353 354 355 356 357 358 359
 * \brief Loop invariant code motion which locates and hoists if statements.
 * \param stmt The stmt to do if statement hoisting.
 * \return Transformed stmt.
 */
Stmt HoistIfThenElse(Stmt stmt);

/*!
360 361 362 363 364 365 366 367 368
 * \brief Narrow down PrimExpr datatype in stmt to target_bits.
 * \note  Run this pass after StorageFlatten.
 * \param stmt The stmt to do datatype rewrite
 * \param target_bits The number of bits of the target datatype.
 * \return Transformed stmt.
 */
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
369 370 371 372 373 374 375 376 377 378 379
 * \brief Rewrite the pointer content type of arguments,
 *  as well as Alloc internal to the function to use
 *  the most frequently accessed type for load/store
 *  to avoid pointer casting in backend when possible.
 *
 * \note implemented in storage_rewrite.cc
 * \param f The function to be transformed
 * \return Transformed function.
 */
PrimFunc PointerValueTypeRewrite(PrimFunc f);

380
/*!
381 382 383 384 385 386 387 388 389
 * \brief Verify the correctness of GPU code.
 *        It checks whether the amount of memory usage or the number of threads
 *        in a block exceeds the limit.
 * \param stmt The statement to be checked
 * \param constraints The dict to specify constraints to check.
 *        Possible keys are
 *
 *        "max_local_memory_per_block": Total amount of local memory per block (in bytes).
 *        "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
390
 *        "max_threads_per_block": Maximum number of threads per block.
391 392 393 394 395 396 397 398 399
 *        "max_thread_x": Maximum length of threadIdx.x.
 *        "max_thread_y": Maximum length of threadIdx.y.
 *        "max_thread_z": Maximum length of threadIdx.z.
 *
 *        If one key is missing in this argument, the pass won't check for that item.
 * \return Whether it is valid GPU code.
 *
 */
bool VerifyGPUCode(Stmt stmt,
400
                   Map<std::string, PrimExpr> constraints);
401

402
}  // namespace tir
tqchen committed
403
}  // namespace tvm
404
#endif  // TVM_TIR_IR_PASS_H_