/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_llvm.h
 * \brief Common base class for generating into LLVM IR
 */
#ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#define TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "llvm_common.h"
#include "../../runtime/thread_storage_scope.h"
#include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"

namespace tvm {
namespace codegen {

using namespace tir;

/*!
 * \brief A base class to generate LLVM IR.
 */
class CodeGenLLVM :
      public ExprFunctor<llvm::Value* (const PrimExpr&)>,
      public StmtFunctor<void(const Stmt&)> {
 public:
  /*!
   * \brief Create new code generator based on target machine.
   * \param tm The target machine
   * \return The created llvm generator.
   */
  static std::unique_ptr<CodeGenLLVM> Create(llvm::TargetMachine* tm);
  /*!
   * \brief Initialize the code generator with given context
   * \param module_name The name of the module.
   * \param tm Target machine model
   * \param ctx The context.
   * \param system_lib Whether to insert system library registration.
   * \param dynamic_lookup Whether dynamically lookup runtime function
   *                       or use the runtime function table passed by caller.
   */
  virtual void Init(const std::string& module_name,
                    llvm::TargetMachine* tm,
                    llvm::LLVMContext* ctx,
                    bool system_lib,
                    bool dynamic_lookup);
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   */
  virtual void AddFunction(const LoweredFunc& f);
  /*!
   * \brief Add main function as the entry name
   * \param entry_func_name The name of entry function to be added.
   */
  virtual void AddMainFunction(const std::string& entry_func_name);
  /*!
   * \brief Finish current pass of codegen, get the module.
   * \return the created module.
   */
  virtual std::unique_ptr<llvm::Module> Finish();
  /*!
   * \brief Add mod to be linked with the generated module
   * \param mod The module to be linked.
   */
  void AddLinkModule(std::unique_ptr<llvm::Module>&& mod);
  /*!
   * \brief Create Value for expression e
   *
   *  This simply dispatches to VisitExpr.
   *
   * \param e The expression to be created value for.
   * \return created value.
   */
  llvm::Value* MakeValue(const PrimExpr& e) {
    return VisitExpr(e);
  }
  /*!
   * \brief Shorthand to get a signed integer constant of type t_int32_.
   * \param value The value of the constant.
   * \return The constant created via llvm::ConstantInt::getSigned.
   */
  llvm::Constant* ConstInt32(int64_t value) const {
    return llvm::ConstantInt::getSigned(t_int32_, value);
  }
  // Expression visitors: each overload produces the llvm::Value
  // for the corresponding expression node.
  llvm::Value* VisitExpr_(const VarNode* op) override;
  llvm::Value* VisitExpr_(const CastNode* op) override;
  llvm::Value* VisitExpr_(const IntImmNode* op) override;
  llvm::Value* VisitExpr_(const FloatImmNode* op) override;
  llvm::Value* VisitExpr_(const StringImmNode* op) override;
  llvm::Value* VisitExpr_(const AddNode* op) override;
  llvm::Value* VisitExpr_(const SubNode* op) override;
  llvm::Value* VisitExpr_(const MulNode* op) override;
  llvm::Value* VisitExpr_(const DivNode* op) override;
  llvm::Value* VisitExpr_(const ModNode* op) override;
  llvm::Value* VisitExpr_(const MinNode* op) override;
  llvm::Value* VisitExpr_(const MaxNode* op) override;
  llvm::Value* VisitExpr_(const LTNode* op) override;
  llvm::Value* VisitExpr_(const LENode* op) override;
  llvm::Value* VisitExpr_(const GTNode* op) override;
  llvm::Value* VisitExpr_(const GENode* op) override;
  llvm::Value* VisitExpr_(const EQNode* op) override;
  llvm::Value* VisitExpr_(const NENode* op) override;
  llvm::Value* VisitExpr_(const AndNode* op) override;
  llvm::Value* VisitExpr_(const OrNode* op) override;
  llvm::Value* VisitExpr_(const NotNode* op) override;
  llvm::Value* VisitExpr_(const SelectNode* op) override;
  llvm::Value* VisitExpr_(const LetNode* op) override;
  llvm::Value* VisitExpr_(const LoadNode* op) override;
  llvm::Value* VisitExpr_(const CallNode* op) override;
  llvm::Value* VisitExpr_(const RampNode* op) override;
  llvm::Value* VisitExpr_(const ShuffleNode* op) override;
  llvm::Value* VisitExpr_(const BroadcastNode* op) override;
  // Statement visitors.
  void VisitStmt_(const StoreNode* op) override;
  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const AllocateNode* op) override;
  void VisitStmt_(const AttrStmtNode* op) override;
  void VisitStmt_(const AssertStmtNode* op) override;
  void VisitStmt_(const LetStmtNode* op) override;
  void VisitStmt_(const SeqStmtNode* op) override;
  void VisitStmt_(const EvaluateNode* op) override;
  void VisitStmt_(const ProducerConsumerNode* op) override;

 protected:
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
    runtime::StorageScope scope;
    /*! \brief The alignment of allocation */
    int alignment{0};
  };
  /*!
   * \brief Execute falloca at the beginning of the
   *  current function and obtain its return value.
   *
   *  This is a helper function to make sure that
   *  alloca always happen in the beginning of the function.
   *
   *  The builder is temporarily repositioned to the first instruction
   *  of the entry block of function_; afterwards the insertion point is
   *  set to the end of the block that was current before the call.
   *  function_ and the builder's insert block must therefore be valid
   *  when this is called.
   *
   * \param falloca The allocation function to be executed.
   * \tparam F The function to be executed.
   * \return The result.
   */
  template<typename F>
  inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
    llvm::BasicBlock* current = builder_->GetInsertBlock();
    llvm::BasicBlock* entry = &(function_->getEntryBlock());
    builder_->SetInsertPoint(entry, entry->begin());
    llvm::AllocaInst* res = falloca();
    // Resume emitting at the end of the block we were in.
    builder_->SetInsertPoint(current);
    return res;
  }
  // create intrinsic given call
  virtual llvm::Value* CreateIntrinsic(const CallNode* op);
  // create extern function call
  virtual llvm::Value* CreateCallExtern(const CallNode* op);
  // Get the corresponding thread index
  virtual llvm::Value* GetThreadIndex(const IterVar& iv);
  // Create the storage synchronization requested by call op.
  virtual llvm::Value* CreateStorageSync(const CallNode* op);
  // Set up the pass manager builder; overridable by subclasses.
  virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder);
  // Scalarize by iterating elements of e.
  // f is a callback that takes index and v.
  virtual void Scalarize(const PrimExpr& e,
                         std::function<void(int i, llvm::Value* v)> f);
  // Initialize target
  virtual void InitTarget(llvm::TargetMachine* tm);
  // Add module startup function if needed.
  // The base implementation does nothing.
  virtual void AddStartupFunction() {}
  // apply optimization on the module.
  virtual void Optimize();
  // Get the maximum storage align bits of buffer pointer given storage scope.
  virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
  // Get correct address space depending on the backend
  virtual unsigned GetGlobalAddressSpace();

  // Non-virtual worker for adding function f.
  // NOTE(review): ret_void presumably selects a void return type for the
  // generated function -- confirm against the implementation.
  void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
  // Create extern call to the function called name, with return type ret
  // and arguments value.
  llvm::CallInst* CreateCallExtern(llvm::Type* ret,
                                   const std::string& name,
                                   const std::vector<llvm::Value*>& value);
  /*!
   * \param t The original type.
   * \return LLVM type of t
   */
  llvm::Type* LLVMType(const DataType& t) const;
  // initialize the function state.
  void InitFuncState();
  // Get alignment given index.
  // Results are written through p_alignment and p_native_bits.
  void GetAlignment(
      DataType t, const VarNode* buf_var, const PrimExpr& index,
      int* p_alignment, int* p_native_bits);
  // Get constant string
  llvm::Value* GetConstString(const std::string& str);
  // do a scalarize call with f
  llvm::Value* CreateScalarizedCall(
      const CallNode* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
  // handle module import
  void HandleImport(const std::string& code);
  // cast operator
  llvm::Value* CreateCast(DataType from, DataType to, llvm::Value* value);
  // Get the LLVM value for variable v.
  llvm::Value* GetVarValue(const VarNode* v) const;
  // comparison op
  llvm::Value* CreateLT(DataType t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateLE(DataType t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGT(DataType t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateGE(DataType t, llvm::Value* a, llvm::Value* b);
  // arithmetic op
  llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
  llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
  // broadcast and buffer pointer helpers
  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
  llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index);
  llvm::Value* CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index);
  // Vector helpers: slice, flip, concatenation and padding.
  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
  llvm::Value* CreateVecFlip(llvm::Value* vec);
  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
  // Create serial for
  void CreateSerialFor(llvm::Value* begin,
                       llvm::Value* end,
                       llvm::Value* stride,
                       const Var& loop_var, const Stmt& body);
  // add alias information.
  void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type);
  // The IRBuilder.
  using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
  // The current function.
  // Defaults to nullptr like the other raw pointer members, so that it
  // never holds an indeterminate value before a function is being built.
  llvm::Function* function_{nullptr};
  // Internal builder
  std::unique_ptr<IRBuilder> builder_;
  // The module to be returned;
  std::unique_ptr<llvm::Module> module_;
  // Data layout.
  std::unique_ptr<llvm::DataLayout> data_layout_;
  // Internal metabuilder
  std::unique_ptr<llvm::MDBuilder> md_builder_;
  // llvm target machine
  llvm::TargetMachine* target_machine_{nullptr};
  // llvm context
  llvm::LLVMContext* ctx_{nullptr};
  // helpful data types
  llvm::Type* t_void_{nullptr};
  llvm::PointerType* t_void_p_{nullptr};
  llvm::Type* t_int_{nullptr};
  llvm::Type* t_char_{nullptr};
  llvm::Type* t_int8_{nullptr};
  llvm::Type* t_int16_{nullptr};
  llvm::Type* t_int32_{nullptr};
  llvm::Type* t_int64_{nullptr};
  llvm::Type* t_float64_{nullptr};
  // meta data
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
  llvm::MDNode* md_tbaa_alias_set_{nullptr};
  // modules to be linked.
  std::vector<std::unique_ptr<llvm::Module> > link_modules_;
  /*! \brief native vector bits of current target */
  int native_vector_bits_{0};
  /*! \brief the storage scope of allocation */
  std::unordered_map<const VarNode*, StorageInfo> alloc_storage_info_;
  // The definition of local variable.
  std::unordered_map<const VarNode*, llvm::Value*> var_map_;
  // global strings
  std::unordered_map<std::string, llvm::Constant*> str_map_;
  // Whether current function is restricted
  bool is_restricted_{true};
  // The analyzer information
  std::unique_ptr<arith::Analyzer> analyzer_;
  // set of var that are not restricted(can alias)
  std::unordered_set<const VarNode*> alias_var_set_;
  // set of volatile buffer.
  std::unordered_set<const VarNode*> volatile_buf_;
  /*! \brief Helper struct for debug infos. */
  struct DebugInfo {
    std::unique_ptr<llvm::DIBuilder> di_builder_;
    llvm::DICompileUnit* compilation_unit_{nullptr};
    llvm::DIFile* file_{nullptr};
  };
  /*!
   * \brief Create a new DebugInfo struct from the given Module that
   *  initializes file and compilation_unit_ to TVM defaults.
   * \param module The module to create the debug info for.
   * \return The created DebugInfo.
   */
  static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
};
}  // namespace codegen
}  // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif  // TVM_TARGET_LLVM_CODEGEN_LLVM_H_