/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_spirv.h
 * \brief Code generator that translates TIR into SPIRV.
 */
#ifndef TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/tir/stmt_functor.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "ir_builder.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace codegen {

42
using namespace tir;
43 44 45 46 47

/*!
 * \brief Code generator into SPIRV
 */
class CodeGenSPIRV:
48
      public ExprFunctor<spirv::Value(const PrimExpr&)>,
49 50 51 52 53 54 55 56 57 58 59 60 61
      public StmtFunctor<void(const Stmt&)> {
 public:
  /*!
   * \brief Compile and add function f to the current module.
   * \param f The function to be added.
   * \return The final spirv module.
   */
  virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f);
  /*!
   * \brief Create Value for expression e
   * \param e The expression to be created value for.
   * \return created value.
   */
62
  spirv::Value MakeValue(const PrimExpr& e) {
63 64 65
    return VisitExpr(e);
  }
  // override codegen
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
  spirv::Value VisitExpr_(const VarNode* op) override;
  spirv::Value VisitExpr_(const CastNode* op) override;
  spirv::Value VisitExpr_(const IntImmNode* op) override;
  spirv::Value VisitExpr_(const FloatImmNode* op) override;
  spirv::Value VisitExpr_(const StringImmNode* op) override;
  spirv::Value VisitExpr_(const AddNode* op) override;
  spirv::Value VisitExpr_(const SubNode* op) override;
  spirv::Value VisitExpr_(const MulNode* op) override;
  spirv::Value VisitExpr_(const DivNode* op) override;
  spirv::Value VisitExpr_(const ModNode* op) override;
  spirv::Value VisitExpr_(const MinNode* op) override;
  spirv::Value VisitExpr_(const MaxNode* op) override;
  spirv::Value VisitExpr_(const LTNode* op) override;
  spirv::Value VisitExpr_(const LENode* op) override;
  spirv::Value VisitExpr_(const GTNode* op) override;
  spirv::Value VisitExpr_(const GENode* op) override;
  spirv::Value VisitExpr_(const EQNode* op) override;
  spirv::Value VisitExpr_(const NENode* op) override;
  spirv::Value VisitExpr_(const AndNode* op) override;
  spirv::Value VisitExpr_(const OrNode* op) override;
  spirv::Value VisitExpr_(const NotNode* op) override;
  spirv::Value VisitExpr_(const SelectNode* op) override;
  spirv::Value VisitExpr_(const LetNode* op) override;
  spirv::Value VisitExpr_(const CallNode* op) override;
  spirv::Value VisitExpr_(const RampNode* op) override;
  spirv::Value VisitExpr_(const BroadcastNode* op) override;
  spirv::Value VisitExpr_(const LoadNode* op) override;
93
  // stmt
94 95 96 97 98 99 100
  void VisitStmt_(const StoreNode* op) override;
  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const AllocateNode* op) override;
  void VisitStmt_(const AttrStmtNode* op) override;
  void VisitStmt_(const AssertStmtNode* op) override;
  void VisitStmt_(const LetStmtNode* op) override;
101
  void VisitStmt_(const SeqStmtNode* op) override;
102 103
  void VisitStmt_(const EvaluateNode* op) override;
  void VisitStmt_(const ProducerConsumerNode* op) override;
104 105 106 107 108 109 110 111 112 113 114

 protected:
  /*! \brief The storage information */
  struct StorageInfo {
    /*! \brief The storage scope */
    runtime::StorageScope scope;
    /*! \brief Whether it is volatile */
    bool is_volatile{false};
    /*! \brief Whether it is volatile */
    bool content_fixed{false};
    /*! \brief Current content type */
115
    DataType content_type{DataType::Handle()};
116 117

    // Update content type if it hasn't beenupdated.
118
    void UpdateContentType(DataType type) {
119 120 121 122 123 124 125 126 127 128 129 130
      if (content_fixed) {
        CHECK_EQ(type, content_type)
            << "Cannot use two different content type in GLSL model";
      } else {
        this->content_type = type;
        content_fixed = true;
      }
    }
  };
  // Reset the state so it works for a new function.
  void InitFuncState();
  // Get the thread index
131
  spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent);
132
  spirv::Value CreateStorageSync(const CallNode* op);
133
  void Scalarize(const PrimExpr& e,
134 135 136 137 138 139 140 141
                 std::function<void(int i, spirv::Value v)> f);
  // The builder
  std::unique_ptr<spirv::IRBuilder> builder_;
  // Work group size of three
  uint32_t workgroup_size_[3];
  // Likely branch
  uint32_t weight_likely_branch_{128};
  // the storage scope of allocation
142
  std::unordered_map<const VarNode*, StorageInfo> storage_info_;
143
  // The definition of local variable.
144
  std::unordered_map<const VarNode*, spirv::Value> var_map_;
145 146
  // The analyzer.
  std::unique_ptr<arith::Analyzer> analyzer_;
147 148 149 150 151 152
};

}  // namespace codegen
}  // namespace tvm


153
#endif  // TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_