codegen_spirv.h 4.49 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
/*!
 *  Copyright (c) 2018 by Contributors
 * \file ir_builder.h
 * \brief Utility for building SPIRV code block
 */
#ifndef TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/lowered_func.h>

#include <vector>

#include "./ir_builder.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

using namespace ir;

/*!
 * \brief Code generator into SPIRV
 */
class CodeGenSPIRV:
      public ExprFunctor<spirv::Value(const Expr&)>,
      public StmtFunctor<void(const Stmt&)> {
 public:
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   * \return The final spirv module.
   */
  virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f);
  /*!
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
  spirv::Value MakeValue(const Expr& e) {
    return VisitExpr(e);
  }
  // override codegen
  spirv::Value VisitExpr_(const Variable* op) override;
  spirv::Value VisitExpr_(const Cast* op) override;
  spirv::Value VisitExpr_(const IntImm* op) override;
  spirv::Value VisitExpr_(const UIntImm* op) override;
  spirv::Value VisitExpr_(const FloatImm* op) override;
  spirv::Value VisitExpr_(const StringImm* op) override;
  spirv::Value VisitExpr_(const Add* op) override;
  spirv::Value VisitExpr_(const Sub* op) override;
  spirv::Value VisitExpr_(const Mul* op) override;
  spirv::Value VisitExpr_(const Div* op) override;
  spirv::Value VisitExpr_(const Mod* op) override;
  spirv::Value VisitExpr_(const Min* op) override;
  spirv::Value VisitExpr_(const Max* op) override;
  spirv::Value VisitExpr_(const LT* op) override;
  spirv::Value VisitExpr_(const LE* op) override;
  spirv::Value VisitExpr_(const GT* op) override;
  spirv::Value VisitExpr_(const GE* op) override;
  spirv::Value VisitExpr_(const EQ* op) override;
  spirv::Value VisitExpr_(const NE* op) override;
  spirv::Value VisitExpr_(const And* op) override;
  spirv::Value VisitExpr_(const Or* op) override;
  spirv::Value VisitExpr_(const Not* op) override;
  spirv::Value VisitExpr_(const Select* op) override;
  spirv::Value VisitExpr_(const Let* op) override;
  spirv::Value VisitExpr_(const Call* op) override;
  spirv::Value VisitExpr_(const Ramp* op) override;
  spirv::Value VisitExpr_(const Broadcast* op) override;
  spirv::Value VisitExpr_(const Load* op) override;
  // stmt
  void VisitStmt_(const Store* op) override;
  void VisitStmt_(const For* op) override;
  void VisitStmt_(const IfThenElse* op) override;
  void VisitStmt_(const Allocate* op) override;
  void VisitStmt_(const AttrStmt* op) override;
  void VisitStmt_(const AssertStmt* op) override;
  void VisitStmt_(const LetStmt* op) override;
  void VisitStmt_(const Block* op) override;
  void VisitStmt_(const Evaluate* op) override;
  void VisitStmt_(const ProducerConsumer* op) override;

 protected:
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
    runtime::StorageScope scope;
    /*! \brief Whether it is volatile */
    bool is_volatile{false};
    /*! \brief Whether it is volatile */
    bool content_fixed{false};
    /*! \brief Current content type */
    Type content_type{Handle()};

    // Update content type if it hasn't beenupdated.
    void UpdateContentType(Type type) {
      if (content_fixed) {
        CHECK_EQ(type, content_type)
            << "Cannot use two different content type in GLSL model";
      } else {
        this->content_type = type;
        content_fixed = true;
      }
    }
  };
  // Reset the state so it works for a new function.
  void InitFuncState();
  // Get the thread index
  spirv::Value GetThreadIndex(const IterVar& iv, const Expr& extent);
  spirv::Value CreateStorageSync(const Call* op);
  void Scalarize(const Expr& e,
                 std::function<void(int i, spirv::Value v)> f);
  // The builder
  std::unique_ptr<spirv::IRBuilder> builder_;
  // Work group size of three
  uint32_t workgroup_size_[3];
  // Likely branch
  uint32_t weight_likely_branch_{128};
  // the storage scope of allocation
  std::unordered_map<const Variable*, StorageInfo> storage_info_;
  // The definition of local variable.
  std::unordered_map<const Variable*, spirv::Value> var_map_;
  // The alignment information
  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
};

}  // namespace codegen
}  // namespace tvm


#endif  // TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_