/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file codegen_llvm.h
 * \brief Common base class for generating into LLVM IR
 */
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/codegen.h>
#include <tvm/arithmetic.h>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

using namespace ir;

/*!
 * \brief A base class to generate a LLVM.
 */
50 51 52
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
53 54
 public:
  /*!
55 56 57 58 59 60
   * \brief Create new code generator based on target machine.
   * \param tm The target machine
   * \return The created llvm generator.
   */
  static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
  /*!
61 62
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
63
   * \param tm Target machine model
64
   * \param ctx The context.
65
   * \param system_lib Whether to insert system library registration.
66 67
   * \param dynamic_lookup Whether dynamically lookup runtime function
   *                       or use the runtime function table passed by caller.
68
   */
69 70 71 72 73
  virtual void Init(const std::string& module_name,
                    llvm::TargetMachine* tm,
                    llvm::LLVMContext* ctx,
                    bool system_lib,
                    bool dynamic_lookup);
74 75 76 77
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
78
  virtual void AddFunction(const LoweredFunc& f);
79
  /*!
80 81 82
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
83
  virtual void AddMainFunction(const std::string& entry_func_name);
84
  /*!
85 86 87
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
88
  virtual std::unique_ptr<llvm::Module> Finish();
89
  /*!
90 91 92 93 94
   * \brief Add mod to be linked with the generated module
   * \param mod The module to be linked.
   */
  void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
  /*!
95 96 97 98 99
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const Expr& e) {
100
    return VisitExpr(e);
101 102
  }
  // Short hande code to get a constant int 32
103 104
  llvm::Constant* ConstInt32(int64_t value) const {
    return llvm::ConstantInt::getSigned(t_int32_, value);
105 106
  }
  // override codegen
107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
  llvm::Value* VisitExpr_(const Variable* op) override;
  llvm::Value* VisitExpr_(const Cast* op) override;
  llvm::Value* VisitExpr_(const IntImm* op) override;
  llvm::Value* VisitExpr_(const UIntImm* op) override;
  llvm::Value* VisitExpr_(const FloatImm* op) override;
  llvm::Value* VisitExpr_(const StringImm* op) override;
  llvm::Value* VisitExpr_(const Add* op) override;
  llvm::Value* VisitExpr_(const Sub* op) override;
  llvm::Value* VisitExpr_(const Mul* op) override;
  llvm::Value* VisitExpr_(const Div* op) override;
  llvm::Value* VisitExpr_(const Mod* op) override;
  llvm::Value* VisitExpr_(const Min* op) override;
  llvm::Value* VisitExpr_(const Max* op) override;
  llvm::Value* VisitExpr_(const LT* op) override;
  llvm::Value* VisitExpr_(const LE* op) override;
  llvm::Value* VisitExpr_(const GT* op) override;
  llvm::Value* VisitExpr_(const GE* op) override;
  llvm::Value* VisitExpr_(const EQ* op) override;
  llvm::Value* VisitExpr_(const NE* op) override;
  llvm::Value* VisitExpr_(const And* op) override;
  llvm::Value* VisitExpr_(const Or* op) override;
  llvm::Value* VisitExpr_(const Not* op) override;
  llvm::Value* VisitExpr_(const Select* op) override;
  llvm::Value* VisitExpr_(const Let* op) override;
  llvm::Value* VisitExpr_(const Load* op) override;
  llvm::Value* VisitExpr_(const Call* op) override;
  llvm::Value* VisitExpr_(const Ramp* op) override;
  llvm::Value* VisitExpr_(const Broadcast* op) override;
135
  // stmt
136 137 138 139 140 141 142 143 144 145
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;
146

147
 protected:
148 149 150
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
151
    runtime::StorageScope scope;
152 153 154
    /*! \brief The alignment of allocation */
    int alignment{0};
  };
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
  /*!
   * \brief Execute falloca at the beginning of the
   *  currrent function and obtain its return value.
   *
   *  This is a helper function to make sure that
   *  alloca always happen in the beginning of the function.
   *
   * \param falloca The allocation function to be executed.
   * \tparam F The function to be executed.
   * \return The result.
   */
  template<typename F>
  inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
    llvm::BasicBlock* current = builder_->GetInsertBlock();
    llvm::BasicBlock* entry = &(function_->getEntryBlock());
    builder_->SetInsertPoint(entry, entry->begin());
    llvm::AllocaInst* res = falloca();
    builder_->SetInsertPoint(current);
    return res;
  }
175 176 177 178 179 180 181 182 183 184 185 186 187 188
  // create intrinstic given call
  virtual llvm::Value* CreateIntrinsic(const Call* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const Call* op);
  // Get the corresponding thread index
  virtual llvm::Value* GetThreadIndex(const IterVar& iv);
  // Get the corresponding thread index
  virtual llvm::Value* CreateStorageSync(const Call* op);
  // apply optimization on the module.
  virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
  // Scalarize by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const Expr& e,
                         std::function<void(int i, llvm::Value* v)> f);
189 190 191 192 193 194 195
  // Initialize target
  virtual void InitTarget(llvm::TargetMachine* tm);
  // Add module startup function if needed.
  virtual void AddStartupFunction() {}
  // apply optimization on the module.
  virtual void Optimize();
  // Get the maximim storage align bits of buffer pointer given storage scope.
196
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
197 198 199
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

200 201 202 203 204
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
                                   const std::string& name,
                                   const std::vector<llvm::Value*>& value);
205 206 207 208 209
  /*!
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const Type& t) const;
210 211 212 213 214 215
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index.
  void GetAlignment(
      Type t, const Variable* buf_var, const Expr& index,
      int* p_alignment, int* p_native_bits);
216 217
  // Get constant string
  llvm::Value* GetConstString(const std::string& str);
218 219 220
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
221 222
  // handle module import
  void HandleImport(const std::string& code);
223 224 225 226 227 228 229 230 231 232 233 234 235
  // cast operatpr
  llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
  // comparison op
  llvm::Value* GetVarValue(const Variable* v) const;
  llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
236
  llvm::Value* CreateBufferVecPtr(Type t, llvm::Value* buffer, llvm::Value* index);
237 238 239 240 241 242 243 244 245 246 247 248
  // Vector concatenation.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const VarExpr& loop_var, const Stmt& body);
  // add alias information.
  void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index, Type type);
249 250 251 252 253 254 255 256
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function
  llvm::Function* function_;
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
257
  std::unique_ptr<llvm::DataLayout> data_layout_;
258 259
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
260 261
  // llvm target machine
  llvm::TargetMachine* target_machine_{nullptr};
262 263 264 265
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
266
  llvm::PointerType* t_void_p_{nullptr};
267 268 269 270 271 272 273
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
274
  // meta data
275 276
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
277
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
278 279
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module> > link_modules_;
280 281
  /*! \brief native vector bits of current targetx*/
  int native_vector_bits_{0};
282
  /*! \brief the storage scope of allocation */
283
  std::unordered_map<const Variable*, StorageInfo> alloc_storage_info_;
284 285 286 287
  // The definition of local variable.
  std::unordered_map<const Variable*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
288 289
  // Whether current function is restricted
  bool is_restricted_{true};
290 291
  // The analyzer information
  std::unique_ptr<arith::Analyzer> analyzer_;
292 293
  // set of var that are not restricted(can alias)
  std::unordered_set<const Variable*> alias_var_set_;
294 295
  // set of volatile buffer.
  std::unordered_set<const Variable*> volatile_buf_;
296 297 298 299 300
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_