/*! * Copyright (c) 2016 by Contributors * \file tvm/ir_pass.h * \brief Collection of IR pass functions * * When the pass functions in this file are for Stmt, * we can use PassFunction(Evaluate(expr)) to apply it to Expr */ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ #include <arithmetic/Simplify.h> #include <unordered_map> #include <vector> #include <string> #include "expr.h" #include "buffer.h" #include "schedule.h" #include "lowered_func.h" namespace tvm { namespace ir { /*! * \brief Simplify the expression. * \param expr The expression to be simplifed. * \param vrange The range information about the variable. * \return Canonicalized statement. */ EXPORT Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Simplify the statement. * \param stmt The statement to be simplifed. * \param vrange The range information about the variable. * \return Canonicalized statement. */ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Simplify by applying canonical form. * \param stmt The statement to be canonically simplifed. * \param vrange The range information about the variable. * \return Canonicalized statement. */ Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Simplify by applying canonical form. * \param expr The statement to be canonically simplifed. * \param vrange The range information about the variable. * \return Canonicalized expression. */ EXPORT Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Deep compare lhs and rhs * \param lhs The left operand * \param rhs The right operand * \return The comparison result. */ EXPORT bool Equal(const Expr& lhs, const Expr& rhs); /*! * \brief Deep compare lhs and rhs * \param lhs The left operand * \param rhs The right operand * \return The comparison result. */ bool Equal(const Stmt& lhs, const Stmt& rhs); /*! * \brief Deep compare lhs and rhs. * * If you only want equality comparison, use Equal * which will also tie definitions. The compare mode * will give order of expression in total order. * * \param lhs The left operand * \param rhs The right operand * \return The comparison result. */ int Compare(const Expr& lhs, const Expr& rhs); /*! * \brief verifies whether the IR stmt or Expr is in SSA form. * That is: each VarExpr is defined and assigned once(in Let/For) * * \param ir The root of the IR DAG. * \return Whether IR is in SSA form. * \note All the passes in this file uses SSA form and outputs SSA form. */ bool VerifySSA(const Stmt& ir); /*! * \brief Whether the expression have side effect. * \return whether expression have side effect */ bool HasSideEffect(const Expr& e); /*! * \brief Whether e expression used var. * \param e The expression to be checked. * \param v The variable. * \return Whether e uses v. */ bool ExprUseVar(const Expr& e, const Var& v); /*! * \brief Whether e expression used any var in variable set.. * \param e The expression to be checked. * \param vset The variable set. * \return Whether e uses vset. */ bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset); /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. * \return The converted form. */ Stmt ConvertSSA(Stmt stmt); /*! * \brief Substitute the var specified in key->var to be value. * \param stmt The source statement to be substituted * \param value_map The map of new values. * \return The converted form. */ Stmt Substitute(Stmt stmt, const std::unordered_map<const Variable*, Expr>& value_map); /*! * \brief Substitute the var specified in key->var to be value. * \param expr The source expression to be substituted * \param value_map The map of new values. * \return The converted expression. */ Expr Substitute(Expr expr, const std::unordered_map<const Variable*, Expr>& value_map); /*! * \brief Substitute the var specified in key->var to be value. * \param stmt The source statement to be substituted * \param value_map The map of new values. * \return The converted form. */ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map); /*! * \brief Substitute the var specified in key->var to be value. * \param expr The source expression to be substituted * \param value_map The map of new values. * \return The converted expression. */ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map); /*! * \brief inline all calls of f in stmt. * * \param stmt The statement to apply inline optimization. * \param f The function reference to be inlined * \param args The arguments variable of the function. * \param body The definition body of the function. * \return The result stmt * * \note All the passes in this file uses SSA form and outputs SSA form. */ Stmt Inline(Stmt stmt, FunctionRef f, Array<Var> args, Expr body); /*! * \brief Flatten the multi-dimensional read/write * to single dimensional Load/Store * * \param stmt The stmt to be trasnformed. * \param extern_buffer Map specifies external * buffer assignment of input and outputs. * \param cache_line_size The size of CPU cache line. * \return Transformed stmt. */ Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer, int cache_line_size); /*! * \brief Remove No Op from the Stmt. * \param stmt The stmt to be trasnformed * \return Transformed stmt. */ Stmt RemoveNoOp(Stmt stmt); /*! * \brief Split statement into pipeine stages. * \param stmt The stmt to be splitted * \param split_load Whether split load into its own stage. * \return Transformed stmt. */ Stmt SplitPipeline(Stmt stmt, bool split_load); /*! * \brief Narrow channel access to smaller range. * \param stmt The stmt to do access rewriting. * \return Transformed stmt. */ Stmt NarrowChannelAccess(Stmt stmt); /*! * \brief unroll the constant loop marked by unroll. * This pass also automatically attach pragma unroll tag to loops which meets the standard. * * \param stmt The statment to be unrolled. * \param auto_max_step The maximum step before stop attach automatic unroll * \param auto_max_depth The maximum depth before stop attach automatic unroll * \param auto_max_extent The maximum extent of the loop we can unroll, * this is an legacy option that do not take the loop total steps into account. * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. * \return Transformed stmt. */ Stmt UnrollLoop(Stmt stmt, int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll); /*! * \brief vectorize the constant loops * \param stmt The statment to be vectorized. * \return Transformed stmt. */ Stmt VectorizeLoop(Stmt stmt); /*! * \brief Inject virtual thread loops into stmt. * \param stmt The statment to be transformed. * \return Transformed stmt. */ Stmt InjectVirtualThread(Stmt stmt); /*! * \brief Inject prefetch instructions into stmt. * \param stmt The statment to be transformed. * \return Transformed stmt. */ Stmt InjectPrefetch(Stmt stmt); /*! * \brief Inject double buffer into stmt. * \param stmt The statment to be transformed. * \param split_loop Loop splitting factor. * \return Transformed stmt. */ Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); /*! * \brief Inject copy intrinsics with optional pad. * * \param stmt The statment to be transformed. * \param pragma_key The pragma key for hint of copy. * \param fintrin The function with signature * * Stmt fintrin(Buffer src, * Buffer dst, * Array<Expr> pad_before, * Array<Expr> pad_after, * Expr pad_value) * \return Transformed stmt. */ Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, const runtime::PackedFunc& fintrin); /*! * \brief Rewrite storage allocation pattern. * Moves the allocation to outer most possible scope. * Trying to share space between allocations to make * a static allocation plan when possible. * * \param stmt The stmt to be trasnformed * \return Transformed stmt. */ Stmt StorageRewrite(Stmt stmt); /*! * \brief partition loops in the stmt * \param stmt The stmt to do loop partition * \param split_const_loop flag to enable partition for const loop * \return Transformed stmt. */ Stmt LoopPartition(Stmt stmt, bool split_const_loop); /*! * \brief Detect and insert sync points to co-processor. * * \param stmt The stmt to be trasnformed * \return Transformed stmt. */ Stmt CoProcSync(Stmt stmt); /*! * \brief Lift common attrs with attr_key to outer scope. * * \param stmt The stmt to be trasnformed * \param attr_key The attribute key to be checked. * \return Transformed stmt. */ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); /*! * \brief Detect and rewrite unsafe select that contains memory access. * \param stmt The statment to be rewritten. * \return Transformed stmt. */ Stmt RewriteUnsafeSelect(Stmt stmt); /*! * \brief Lower attached storage access information. * Do this pass after all storage access analysis finish. * * \param stmt The stmt to be trasnformed * \return Transformed stmt. */ Stmt LowerStorageAccessInfo(Stmt stmt); /*! * \brief Make an user callable API LoweredFunc. * * The main task of this function is to create code to : * - Map the values in the api_args to of Var that is required by body. * - Insert assertions to check type/value of the passed arguments. * * \param body The body of the function. * \param name The name of the function. * \param api_args Arguments to the function, can be either Var, or Buffer * \param num_unpacked_args Number of arguments that * are processed in plain form instead of packed form. * \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap. * It is recommended to set to true for optimized code if such invariant holds. * * \return a LoweredFunc with the specified signiture. * * \note * The function signiture have two cases * * let num_packed_args = len(api_args) - num_unpacked_args; * * if num_packed_args is zero: * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) * * if num_packed_args is not zero: * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, * api_arg_k, api_arg_k+1, ... api_arg_n) * * where n == len(api_args), k == num_packed_args * * There is no thread_axis in generated function. */ LoweredFunc MakeAPI(Stmt body, std::string name, Array<NodeRef> api_args, int num_unpacked_args, bool is_restricted); /*! * \brief Bind the device type of host function to be device_type. * \param func The function to be binded. * \param device_type The device type to be binded. * \return The binded function. */ LoweredFunc BindDeviceType(LoweredFunc func, int device_type); /*! * \brief Find undefined vars in the statment. * \param stmt The function to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. */ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs); /*! * \brief Split the function into a host function and device functions. * \param func The function to be splitted. * * \return Array of functions, the first one is host function, * the others are device functions. */ Array<LoweredFunc> SplitHostDevice(LoweredFunc func); /*! * \brief Insert sync between parallel read/write of shared buffers. * * \param stmt The stmt to be trasnformed. * \param storage_scope The storage scope considered. */ LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope); /*! * \brief Lower cross thread alleduce in the stmt. * \param f The device function to be lowered. * \param warp_size the size of warp where no sync is needed. * \return Transformed function. */ LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size); /*! * \brief Lower warp memory in stmt. * \param f The device function to be lowered. * \param warp_size the size of warp where no sync is needed. * this function will only take in effect if warp_size is bigger than one. * \return Transformed function. */ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); /*! * \brief Remap the thread axis * * This can be used to get equivalent program which uses * threadIdx.y in place of threadIdx.x by passing * {"threadIdx.x": thread_axis("threadIdx.y")} * * * \param f The device function to be lowered. * \param axis_map The map from StringImm -> ItrVar * \return Transformed function. */ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> axis_map); /*! * \brief Lower packed function call. * \param f The function to be lowered. * \return Transformed function. */ LoweredFunc LowerTVMBuiltin(LoweredFunc f); /*! * \brief Combine context function calls. * \param f The host function to be lowered. * \return Transformed function. */ LoweredFunc CombineContextCall(LoweredFunc f); /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use * the most frequently accessed type for load/store * to avoid pointer casting in backend when possible. * * \note implemeneted in storage_rewrite.cc * \param f The function to be trasnformed * \return Transformed function. */ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); /*! * \brief Lower intrinsic function calls. * \param f The device function to be lowered. * \param target The target device. * \return Transformed function. */ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); /*! * \brief Verify if memory accesses are legal for a specific target device type. * * In the case that tgt is cuda, if not all workload is bound with * threads, CPU code is generated that tries to access GPU memory, * which is illegal. This pass performs verification for this case. * * \param func The function to be verified. * \param device_type The target device type. * \return Success of memory verification. */ bool VerifyMemory(LoweredFunc func, int device_type); /*! * \brief Verify the correctness of a GPU code * It will check the whether the amount of memory usage or the number of threads * in a block exceeds the limit * \param stmt The statement to be checked * \param constraints The dict to specify constraints to check. * Possible keys are * * "max_local_memory_per_block": Total amount of local memory per block (in bytes). * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). * "max_threads_per_block": Maximum number of threads per block. * "max_thread_x": Maximum length of threadIdx.x. * "max_thread_y": Maximum length of threadIdx.y. * "max_thread_z": Maximum length of threadIdx.z. * * If one key is missing in this argument, the pass won't check for that item. * \return valid Whether it is a valid GPU code * */ bool VerifyGPUCode(Stmt stmt, Map<std::string, Expr> constraints); } // namespace ir } // namespace tvm #endif // TVM_IR_PASS_H_