/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29
/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.h
 * \brief Common base class for generating into LLVM IR
 */
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
30
#include <tvm/ir_functor_ext.h>
31
#include <tvm/codegen.h>
32
#include <tvm/arithmetic.h>
33
#include <memory>
34
#include <utility>
35 36
#include <vector>
#include <string>
37 38
#include <unordered_map>
#include <unordered_set>
39
#include "llvm_common.h"
40
#include "../../runtime/thread_storage_scope.h"
41 42 43 44 45 46 47 48 49

namespace tvm {
namespace codegen {

using namespace ir;

/*!
 * \brief A base class to generate LLVM IR.
 */
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
53 54
 public:
  /*!
55 56 57 58 59 60
   * \brief Create new code generator based on target machine.
   * \param tm The target machine
   * \return The created llvm generator.
   */
  static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
  /*!
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
   * \param tm Target machine model
   * \param ctx The context.
   * \param system_lib Whether to insert system library registration.
   * \param dynamic_lookup Whether to dynamically look up runtime functions
   *                       or use the runtime function table passed by caller.
   */
  virtual void Init(const std::string& module_name,
                    llvm::TargetMachine* tm,
                    llvm::LLVMContext* ctx,
                    bool system_lib,
                    bool dynamic_lookup);
74 75 76 77
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
78
  virtual void AddFunction(const LoweredFunc& f);
79
  /*!
80 81 82
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
83
  virtual void AddMainFunction(const std::string& entry_func_name);
84
  /*!
85 86 87
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
88
  virtual std::unique_ptr<llvm::Module> Finish();
89
  /*!
90 91 92 93 94
   * \brief Add mod to be linked with the generated module
   * \param mod The module to be linked.
   */
  void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
  /*!
95 96 97 98 99
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const Expr& e) {
    // Thin public wrapper: forwards to VisitExpr, which is not declared in
    // this class (presumably inherited from the ExprFunctor base).
    return VisitExpr(e);
  }
  // Shorthand to get a constant int32
  llvm::Constant* ConstInt32(int64_t value) const {
    // Build the constant with the cached t_int32_ type; getSigned
    // interprets `value` as a signed quantity.
    return llvm::ConstantInt::getSigned(t_int32_, value);
  }
  // override codegen
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
  llvm::Value* VisitExpr_(const Variable* op) override;
  llvm::Value* VisitExpr_(const Cast* op) override;
  llvm::Value* VisitExpr_(const IntImm* op) override;
  llvm::Value* VisitExpr_(const UIntImm* op) override;
  llvm::Value* VisitExpr_(const FloatImm* op) override;
  llvm::Value* VisitExpr_(const StringImm* op) override;
  llvm::Value* VisitExpr_(const Add* op) override;
  llvm::Value* VisitExpr_(const Sub* op) override;
  llvm::Value* VisitExpr_(const Mul* op) override;
  llvm::Value* VisitExpr_(const Div* op) override;
  llvm::Value* VisitExpr_(const Mod* op) override;
  llvm::Value* VisitExpr_(const Min* op) override;
  llvm::Value* VisitExpr_(const Max* op) override;
  llvm::Value* VisitExpr_(const LT* op) override;
  llvm::Value* VisitExpr_(const LE* op) override;
  llvm::Value* VisitExpr_(const GT* op) override;
  llvm::Value* VisitExpr_(const GE* op) override;
  llvm::Value* VisitExpr_(const EQ* op) override;
  llvm::Value* VisitExpr_(const NE* op) override;
  llvm::Value* VisitExpr_(const And* op) override;
  llvm::Value* VisitExpr_(const Or* op) override;
  llvm::Value* VisitExpr_(const Not* op) override;
  llvm::Value* VisitExpr_(const Select* op) override;
  llvm::Value* VisitExpr_(const Let* op) override;
  llvm::Value* VisitExpr_(const Load* op) override;
  llvm::Value* VisitExpr_(const Call* op) override;
  llvm::Value* VisitExpr_(const Ramp* op) override;
134
  llvm::Value* VisitExpr_(const Shuffle* op) override;
135
  llvm::Value* VisitExpr_(const Broadcast* op) override;
136
  // stmt
137 138 139 140 141 142 143 144 145 146
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
147

148
 protected:
149 150 151
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
152
    runtime::StorageScope scope;
153 154 155
    /*! \brief The alignment of allocation */
    int alignment{0};
  };
  /*!
   * \brief Execute falloca at the beginning of the
   *  current function and obtain its return value.
   *
   *  This is a helper function to make sure that
   *  alloca always happens at the beginning of the function.
   *
   * \param falloca The allocation function to be executed.
   * \tparam F The function to be executed.
   * \return The result.
   */
  template<typename F>
  inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
    // Remember the block currently being emitted into so it can be resumed.
    llvm::BasicBlock* resume_block = builder_->GetInsertBlock();
    // Park the builder at the first position of the function's entry block
    // and let the callback emit its alloca there.
    llvm::BasicBlock& entry_block = function_->getEntryBlock();
    builder_->SetInsertPoint(&entry_block, entry_block.begin());
    llvm::AllocaInst* alloca_inst = falloca();
    // Resume in the original block. NOTE: the block-only SetInsertPoint
    // overload positions the builder at the end of that block.
    builder_->SetInsertPoint(resume_block);
    return alloca_inst;
  }
  // create intrinsic given call
  virtual llvm::Value* CreateIntrinsic(const Call* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const Call* op);
  // Get the corresponding thread index
  virtual llvm::Value* GetThreadIndex(const IterVar& iv);
  // Create storage synchronization for the given call
  virtual llvm::Value* CreateStorageSync(const Call* op);
  // Initialize the pass manager builder.
  virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
  // Scalarize by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const Expr& e,
                         std::function<void(int i, llvm::Value* v)> f);
190 191 192 193 194 195 196
  // Initialize target
  virtual void InitTarget(llvm::TargetMachine* tm);
  // Add module startup function if needed.
  virtual void AddStartupFunction() {}
  // apply optimization on the module.
  virtual void Optimize();
  // Get the maximum storage align bits of buffer pointer given storage scope.
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
198 199 200
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

201 202 203 204 205
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
                                   const std::string& name,
                                   const std::vector<llvm::Value*>& value);
206 207 208 209 210
  /*!
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const Type& t) const;
211 212 213 214 215 216
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index.
  void GetAlignment(
      Type t, const Variable* buf_var, const Expr& index,
      int* p_alignment, int* p_native_bits);
217 218
  // Get constant string
  llvm::Value* GetConstString(const std::string& str);
219 220 221
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
222 223
  // handle module import
  void HandleImport(const std::string& code);
  // cast operator
  llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
  // variable lookup, comparison and arithmetic ops
  llvm::Value* GetVarValue(const Variable* v) const;
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
237
  llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
  // Vector manipulation: slice, flip, concatenate and pad.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const VarExpr& loop_var, const Stmt& body);
  // add alias information.
  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
250 251 252 253 254 255 256 257
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_;
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
258
  std::unique_ptr<llvm::DataLayout> data_layout_;
259 260
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
261 262
  // llvm target machine
  llvm::TargetMachine* target_machine_{nullptr};
263 264 265 266
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
267
  llvm::PointerType* t_void_p_{nullptr};
268 269 270 271 272 273 274
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
275
  // meta data
276 277
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
278
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
279 280
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module> > link_modules_;
  /*! \brief native vector bits of current target */
  int native_vector_bits_{0};
283
  /*! \brief the storage scope of allocation */
284
  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
285 286 287 288
  // The definition of local variable.
  std::unordered_map<const Variable*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
289 290
  // Whether current function is restricted
  bool is_restricted_{true};
291 292
  // The analyzer information
  std::unique_ptr<arith::Analyzer> analyzer_;
293 294
  // set of var that are not restricted(can alias)
  std::unordered_set<const Variable*> alias_var_set_;
295 296
  // set of volatile buffer.
  std::unordered_set<const Variable*> volatile_buf_;
297
  /*! \brief Helper struct for debug infos. */
298 299 300 301 302
  struct DebugInfo {
    // Builder for debug information, held through a unique_ptr.
    std::unique_ptr<llvm::DIBuilder> di_builder_;
    // Compile unit and file descriptors; both default to nullptr.
    llvm::DICompileUnit* compilation_unit_ = nullptr;
    llvm::DIFile* file_ = nullptr;
  };
303 304 305 306
  /*!
   * \brief Create a new DebugInfo struct from the given Module that
   *  initializes file and compilation_unit_ to TVM defaults.
   */
307
  static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
308 309 310 311 312
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_