/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.h
 * \brief Common base class for generating into LLVM IR
 */
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

using namespace ir;

26

27 28 29
/*!
 * \brief A base class to generate a LLVM.
 */
30 31 32
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
33 34
 public:
  /*!
35 36 37 38 39 40
   * \brief Create new code generator based on target machine.
   * \param tm The target machine
   * \return The created llvm generator.
   */
  static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
  /*!
41 42
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
43
   * \param tm Target machine model
44
   * \param ctx The context.
45
   * \param system_lib Whether to insert system library registration.
46 47
   * \param dynamic_lookup Whether dynamically lookup runtime function
   *                       or use the runtime function table passed by caller.
48
   */
49 50 51 52 53
  virtual void Init(const std::string& module_name,
                    llvm::TargetMachine* tm,
                    llvm::LLVMContext* ctx,
                    bool system_lib,
                    bool dynamic_lookup);
54 55 56 57
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
58
  virtual void AddFunction(const LoweredFunc& f);
59
  /*!
60 61 62
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
63
  virtual void AddMainFunction(const std::string& entry_func_name);
64
  /*!
65 66 67
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
68
  virtual std::unique_ptr<llvm::Module> Finish();
69
  /*!
70 71 72 73 74
   * \brief Add mod to be linked with the generated module
   * \param mod The module to be linked.
   */
  void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
  /*!
75 76 77 78 79
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const Expr& e) {
80
    return VisitExpr(e);
81 82
  }
  // Short hande code to get a constant int 32
83 84
  llvm::Constant* ConstInt32(int64_t value) const {
    return llvm::ConstantInt::getSigned(t_int32_, value);
85 86
  }
  // override codegen
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
  llvm::Value* VisitExpr_(const Variable* op) override;
  llvm::Value* VisitExpr_(const Cast* op) override;
  llvm::Value* VisitExpr_(const IntImm* op) override;
  llvm::Value* VisitExpr_(const UIntImm* op) override;
  llvm::Value* VisitExpr_(const FloatImm* op) override;
  llvm::Value* VisitExpr_(const StringImm* op) override;
  llvm::Value* VisitExpr_(const Add* op) override;
  llvm::Value* VisitExpr_(const Sub* op) override;
  llvm::Value* VisitExpr_(const Mul* op) override;
  llvm::Value* VisitExpr_(const Div* op) override;
  llvm::Value* VisitExpr_(const Mod* op) override;
  llvm::Value* VisitExpr_(const Min* op) override;
  llvm::Value* VisitExpr_(const Max* op) override;
  llvm::Value* VisitExpr_(const LT* op) override;
  llvm::Value* VisitExpr_(const LE* op) override;
  llvm::Value* VisitExpr_(const GT* op) override;
  llvm::Value* VisitExpr_(const GE* op) override;
  llvm::Value* VisitExpr_(const EQ* op) override;
  llvm::Value* VisitExpr_(const NE* op) override;
  llvm::Value* VisitExpr_(const And* op) override;
  llvm::Value* VisitExpr_(const Or* op) override;
  llvm::Value* VisitExpr_(const Not* op) override;
  llvm::Value* VisitExpr_(const Select* op) override;
  llvm::Value* VisitExpr_(const Let* op) override;
  llvm::Value* VisitExpr_(const Load* op) override;
  llvm::Value* VisitExpr_(const Call* op) override;
  llvm::Value* VisitExpr_(const Ramp* op) override;
  llvm::Value* VisitExpr_(const Broadcast* op) override;
115
  // stmt
116 117 118 119 120 121 122 123 124 125
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
126

127
 protected:
128 129 130
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
131
    runtime::StorageScope scope;
132 133 134
    /*! \brief The alignment of allocation */
    int alignment{0};
  };
135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
  /*!
   * \brief Execute falloca at the beginning of the
   *  currrent function and obtain its return value.
   *
   *  This is a helper function to make sure that
   *  alloca always happen in the beginning of the function.
   *
   * \param falloca The allocation function to be executed.
   * \tparam F The function to be executed.
   * \return The result.
   */
  template<typename F>
  inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
    llvm::BasicBlock* current = builder_->GetInsertBlock();
    llvm::BasicBlock* entry = &(function_->getEntryBlock());
    builder_->SetInsertPoint(entry, entry->begin());
    llvm::AllocaInst* res = falloca();
    builder_->SetInsertPoint(current);
    return res;
  }
155 156 157 158 159 160 161 162 163 164 165 166 167 168
  // create intrinstic given call
  virtual llvm::Value* CreateIntrinsic(const Call* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const Call* op);
  // Get the corresponding thread index
  virtual llvm::Value* GetThreadIndex(const IterVar& iv);
  // Get the corresponding thread index
  virtual llvm::Value* CreateStorageSync(const Call* op);
  // apply optimization on the module.
  virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
  // Scalarize by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const Expr& e,
                         std::function<void(int i, llvm::Value* v)> f);
169 170 171 172 173 174 175
  // Initialize target
  virtual void InitTarget(llvm::TargetMachine* tm);
  // Add module startup function if needed.
  virtual void AddStartupFunction() {}
  // apply optimization on the module.
  virtual void Optimize();
  // Get the maximim storage align bits of buffer pointer given storage scope.
176
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
177 178 179
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

180 181 182 183 184
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
                                   const std::string& name,
                                   const std::vector<llvm::Value*>& value);
185 186 187 188 189
  /*!
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const Type& t) const;
190 191 192 193 194 195
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index.
  void GetAlignment(
      Type t, const Variable* buf_var, const Expr& index,
      int* p_alignment, int* p_native_bits);
196 197
  // Get constant string
  llvm::Value* GetConstString(const std::string& str);
198 199 200
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
201 202
  // handle module import
  void HandleImport(const std::string& code);
203 204 205 206 207 208 209 210 211 212 213 214 215
  // cast operatpr
  llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
  // comparison op
  llvm::Value* GetVarValue(const Variable* v) const;
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
216
  llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
217 218 219 220 221 222 223 224 225 226 227 228
  // Vector concatenation.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const VarExpr& loop_var, const Stmt& body);
  // add alias information.
  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
229 230 231 232 233 234 235 236
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_;
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
237
  std::unique_ptr<llvm::DataLayout> data_layout_;
238 239
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
240 241
  // llvm target machine
  llvm::TargetMachine* target_machine_{nullptr};
242 243 244 245
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
246
  llvm::PointerType* t_void_p_{nullptr};
247 248 249 250 251 252 253
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
254
  // meta data
255 256
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
257
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
258 259
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module> > link_modules_;
260 261
  /*! \brief native vector bits of current targetx*/
  int native_vector_bits_{0};
262
  /*! \brief the storage scope of allocation */
263
  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
264 265 266 267
  // The definition of local variable.
  std::unordered_map<const Variable*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
268 269
  // Whether current function is restricted
  bool is_restricted_{true};
270 271
  // The alignment information
  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
272 273
  // set of var that are not restricted(can alias)
  std::unordered_set<const Variable*> alias_var_set_;
274 275
  // set of volatile buffer.
  std::unordered_set<const Variable*> volatile_buf_;
276 277 278 279 280
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_