/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.h
 * \brief Common base class for code generators that emit LLVM IR.
 */
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "./llvm_common.h"

namespace tvm {
namespace codegen {

using namespace ir;

/*!
 * \brief A base class to generate a LLVM.
 */
27 28 29
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
30 31 32 33
 public:
  /*!
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
34
   * \param target_triple The target triple, can be empty.
35 36
   * \param ctx The context.
   */
37 38 39
  void Init(const std::string& module_name,
            const std::string& target_triple,
            llvm::LLVMContext* ctx);
40 41 42 43 44 45
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
  void AddFunction(const LoweredFunc& f);
  /*!
46 47 48 49 50
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
  void AddMainFunction(const std::string& entry_func_name);
  /*!
51 52 53 54 55 56 57 58 59 60
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
  std::unique_ptr<llvm::Module> Finish();
  /*!
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const Expr& e) {
61
    return VisitExpr(e);
62 63 64 65 66 67
  }
  // Short hande code to get a constant int 32
  llvm::Constant* ConstInt32(unsigned value) const {
    return llvm::ConstantInt::get(t_int32_, value);
  }
  // override codegen
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
  llvm::Value* VisitExpr_(const Variable* op) override;
  llvm::Value* VisitExpr_(const Cast* op) override;
  llvm::Value* VisitExpr_(const IntImm* op) override;
  llvm::Value* VisitExpr_(const UIntImm* op) override;
  llvm::Value* VisitExpr_(const FloatImm* op) override;
  llvm::Value* VisitExpr_(const StringImm* op) override;
  llvm::Value* VisitExpr_(const Add* op) override;
  llvm::Value* VisitExpr_(const Sub* op) override;
  llvm::Value* VisitExpr_(const Mul* op) override;
  llvm::Value* VisitExpr_(const Div* op) override;
  llvm::Value* VisitExpr_(const Mod* op) override;
  llvm::Value* VisitExpr_(const Min* op) override;
  llvm::Value* VisitExpr_(const Max* op) override;
  llvm::Value* VisitExpr_(const LT* op) override;
  llvm::Value* VisitExpr_(const LE* op) override;
  llvm::Value* VisitExpr_(const GT* op) override;
  llvm::Value* VisitExpr_(const GE* op) override;
  llvm::Value* VisitExpr_(const EQ* op) override;
  llvm::Value* VisitExpr_(const NE* op) override;
  llvm::Value* VisitExpr_(const And* op) override;
  llvm::Value* VisitExpr_(const Or* op) override;
  llvm::Value* VisitExpr_(const Not* op) override;
  llvm::Value* VisitExpr_(const Select* op) override;
  llvm::Value* VisitExpr_(const Let* op) override;
  llvm::Value* VisitExpr_(const Load* op) override;
  llvm::Value* VisitExpr_(const Call* op) override;
  llvm::Value* VisitExpr_(const Ramp* op) override;
  llvm::Value* VisitExpr_(const Broadcast* op) override;
96
  // stmt
97 98 99 100 101 102 103 104 105 106
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
107 108 109 110 111 112
  // create intrinstic given call
  virtual llvm::Value* CreateIntrinstic(const Call* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const Call* op);
  // create call into tvm packed function.
  virtual llvm::Value* CreateCallPacked(const Call* op);
113 114 115 116
  // Scalarize e by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const Expr& e,
                         std::function<void(int i, llvm::Value* v)> f);
117 118 119 120 121 122
 protected:
  /*!
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const Type& t) const;
123 124 125 126 127 128
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index.
  void GetAlignment(
      Type t, const Variable* buf_var, const Expr& index,
      int* p_alignment, int* p_native_bits);
129 130 131 132 133
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
  // apply optimization on the module.
  virtual void Optimize();
134 135
  // Get the maximim storage align bits of buffer pointer given storage scope.
  virtual int NativeVectorBits(const std::string& storage_scope) const;
136 137 138 139 140 141 142 143
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_;
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
144
  std::unique_ptr<llvm::DataLayout> data_layout_;
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
  llvm::Type* t_void_p_{nullptr};
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
  // branch
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
  // TVM related data types
163
  llvm::Type* t_tvm_shape_index_{nullptr};
164 165 166 167 168
  llvm::Type* t_tvm_func_handle_{nullptr};
  llvm::StructType* t_tvm_context_{nullptr};
  llvm::StructType* t_tvm_type_{nullptr};
  llvm::StructType* t_tvm_array_{nullptr};
  llvm::StructType* t_tvm_value_{nullptr};
169
  llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
170 171
  // tvm api functions
  llvm::Function* f_tvm_func_call_{nullptr};
172
  llvm::Function* f_tvm_get_func_from_env_{nullptr};
173
  llvm::Function* f_tvm_api_set_last_error_{nullptr};
174
  llvm::Function* f_tvm_parallel_for_{nullptr};
175 176
  // The acting body
  llvm::BasicBlock* block_{nullptr};
177 178
  /*! \brief the storage scope of allocation */
  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
179 180 181 182 183 184 185 186 187 188 189 190 191 192

 private:
  // comparison op
  llvm::Value* GetVarValue(const Variable* v) const;
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  llvm::Value* GetConstString(const std::string& str);
  llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
193
  llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
194
  llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
195
  llvm::Value* GetPackedFuncHandle(const std::string& str);
196 197 198 199 200
  // Vector concatenation.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
201 202 203 204 205
  // Create parallel for.
  void CreateParallelFor(const For* op);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
                       const VarExpr& loop_var, const Stmt& body);
206 207 208
  // Check if the call to packed function is successful
  // if not directly finalize function and pass on return code.
  // return the end block after the check
209
  llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
210 211 212 213
  // Initialize target
  void InitTarget(const std::string& target);
  // Add a function to set global module context
  void InitGlobalContext();
214
  // add alias information.
215
  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
216 217 218 219
  // The definition of local variable.
  std::unordered_map<const Variable*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
220 221
  // The alignment information
  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
222 223
  // The local module_context
  llvm::GlobalVariable* gv_mod_ctx_{nullptr};
224 225 226 227 228 229 230
  // global to packed function handle
  std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_