/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.h
 * \brief Common base class for generating LLVM IR.
 */
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

using namespace ir;

26

27 28 29
/*!
 * \brief A base class to generate a LLVM.
 */
30 31 32
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
33 34
 public:
  /*!
35 36 37 38 39 40
   * \brief Create new code generator based on target machine.
   * \param tm The target machine
   * \return The created llvm generator.
   */
  static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
  /*!
41 42
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
43
   * \param tm Target machine model
44
   * \param ctx The context.
45
   * \param system_lib Whether to insert system library registration.
46 47
   * \param dynamic_lookup Whether dynamically lookup runtime function
   *                       or use the runtime function table passed by caller.
48
   */
49 50 51 52 53
  virtual void Init(const std::string& module_name,
                    llvm::TargetMachine* tm,
                    llvm::LLVMContext* ctx,
                    bool system_lib,
                    bool dynamic_lookup);
54 55 56 57
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
58
  virtual void AddFunction(const LoweredFunc& f);
59
  /*!
60 61 62
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
63
  virtual void AddMainFunction(const std::string& entry_func_name);
64
  /*!
65 66 67
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
68
  virtual std::unique_ptr<llvm::Module> Finish();
69
  /*!
70 71 72 73 74
   * \brief Add mod to be linked with the generated module
   * \param mod The module to be linked.
   */
  void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
  /*!
75 76 77 78 79
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const Expr& e) {
80
    return VisitExpr(e);
81 82
  }
  // Short hande code to get a constant int 32
83 84
  llvm::Constant* ConstInt32(int64_t value) const {
    return llvm::ConstantInt::getSigned(t_int32_, value);
85 86
  }
  // override codegen
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
  llvm::Value* VisitExpr_(const Variable* op) override;
  llvm::Value* VisitExpr_(const Cast* op) override;
  llvm::Value* VisitExpr_(const IntImm* op) override;
  llvm::Value* VisitExpr_(const UIntImm* op) override;
  llvm::Value* VisitExpr_(const FloatImm* op) override;
  llvm::Value* VisitExpr_(const StringImm* op) override;
  llvm::Value* VisitExpr_(const Add* op) override;
  llvm::Value* VisitExpr_(const Sub* op) override;
  llvm::Value* VisitExpr_(const Mul* op) override;
  llvm::Value* VisitExpr_(const Div* op) override;
  llvm::Value* VisitExpr_(const Mod* op) override;
  llvm::Value* VisitExpr_(const Min* op) override;
  llvm::Value* VisitExpr_(const Max* op) override;
  llvm::Value* VisitExpr_(const LT* op) override;
  llvm::Value* VisitExpr_(const LE* op) override;
  llvm::Value* VisitExpr_(const GT* op) override;
  llvm::Value* VisitExpr_(const GE* op) override;
  llvm::Value* VisitExpr_(const EQ* op) override;
  llvm::Value* VisitExpr_(const NE* op) override;
  llvm::Value* VisitExpr_(const And* op) override;
  llvm::Value* VisitExpr_(const Or* op) override;
  llvm::Value* VisitExpr_(const Not* op) override;
  llvm::Value* VisitExpr_(const Select* op) override;
  llvm::Value* VisitExpr_(const Let* op) override;
  llvm::Value* VisitExpr_(const Load* op) override;
  llvm::Value* VisitExpr_(const Call* op) override;
  llvm::Value* VisitExpr_(const Ramp* op) override;
  llvm::Value* VisitExpr_(const Broadcast* op) override;
115
  // stmt
116 117 118 119 120 121 122 123 124 125
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
126

127
 protected:
128 129 130
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
131
    runtime::StorageScope scope;
132 133 134
    /*! \brief The alignment of allocation */
    int alignment{0};
  };
135 136 137 138 139 140 141 142 143 144 145 146 147 148
  // create intrinstic given call
  virtual llvm::Value* CreateIntrinsic(const Call* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const Call* op);
  // Get the corresponding thread index
  virtual llvm::Value* GetThreadIndex(const IterVar& iv);
  // Get the corresponding thread index
  virtual llvm::Value* CreateStorageSync(const Call* op);
  // apply optimization on the module.
  virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
  // Scalarize by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const Expr& e,
                         std::function<void(int i, llvm::Value* v)> f);
149 150 151 152 153 154 155
  // Initialize target
  virtual void InitTarget(llvm::TargetMachine* tm);
  // Add module startup function if needed.
  virtual void AddStartupFunction() {}
  // apply optimization on the module.
  virtual void Optimize();
  // Get the maximim storage align bits of buffer pointer given storage scope.
156
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
157 158 159
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

160 161 162 163 164
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
                                   const std::string& name,
                                   const std::vector<llvm::Value*>& value);
165 166 167 168 169
  /*!
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const Type& t) const;
170 171 172 173 174 175
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index.
  void GetAlignment(
      Type t, const Variable* buf_var, const Expr& index,
      int* p_alignment, int* p_native_bits);
176 177
  // Get constant string
  llvm::Value* GetConstString(const std::string& str);
178 179 180
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
181 182
  // handle module import
  void HandleImport(const std::string& code);
183 184 185 186 187 188 189 190 191 192 193 194 195
  // cast operatpr
  llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
  // comparison op
  llvm::Value* GetVarValue(const Variable* v) const;
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
196
  llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
197 198 199 200 201 202 203 204 205 206 207 208
  // Vector concatenation.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const VarExpr& loop_var, const Stmt& body);
  // add alias information.
  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
209 210 211 212 213 214 215 216
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_;
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
217
  std::unique_ptr<llvm::DataLayout> data_layout_;
218 219
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
220 221
  // llvm target machine
  llvm::TargetMachine* target_machine_{nullptr};
222 223 224 225
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
226
  llvm::PointerType* t_void_p_{nullptr};
227 228 229 230 231 232 233
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
234
  // meta data
235 236
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
237
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
238 239
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module> > link_modules_;
240 241
  /*! \brief native vector bits of current targetx*/
  int native_vector_bits_{0};
242
  /*! \brief the storage scope of allocation */
243
  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
244 245 246 247
  // The definition of local variable.
  std::unordered_map<const Variable*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
248 249
  // Whether current function is restricted
  bool is_restricted_{true};
250 251
  // The alignment information
  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
252 253
  // set of var that are not restricted(can alias)
  std::unordered_set<const Variable*> alias_var_set_;
254 255
  // set of volatile buffer.
  std::unordered_set<const Variable*> volatile_buf_;
256 257 258 259 260
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_