/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file verify_gpu_code.cc
 * \brief Verify the correctness of a GPU IR.
 *        It checks whether the amount of memory used or the number of threads
 *        in a block exceeds the given limit.
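 *        The limits are supplied as a constraint map whose recognized keys are
 *        "max_local_memory_per_block", "max_shared_memory_per_block",
 *        "max_threads_per_block", "max_thread_x", "max_thread_y", and
 *        "max_thread_z"; a key that is not given imposes no limit.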
 */

#include <tvm/runtime/registry.h>

#include <tvm/ir.h>
#include <tvm/ir_functor_ext.h>

#include <string>
#include <unordered_set>

namespace tvm {
namespace ir {

class GPUCodeVerifier : public StmtVisitor {
 public:
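  /*!
   * \brief Check one kernel against per-block limits on local and shared memory,
   *        the total number of threads, and the extent of each thread dimension.
   * \return Whether all recorded usage stays within the given limits.
   */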
  bool Verify(tvm::Stmt stmt,
              int64_t max_local_memory_per_block,
              int64_t max_shared_memory_per_block,
              int64_t max_threads_per_block,
              int64_t max_thread_x,
              int64_t max_thread_y,
              int64_t max_thread_z) {
    max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
    max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
    max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
    max_thread_x_ = static_cast<size_t>(max_thread_x);
    max_thread_y_ = static_cast<size_t>(max_thread_y);
    max_thread_z_ = static_cast<size_t>(max_thread_z);

    Reset_();

    this->VisitStmt(stmt);

    return valid_;
  }

  void VisitStmt_(const ProducerConsumerNode* op) final {
    if (nest_level_ == 0) {
      // enter a new kernel, reset statistics
      Reset_();
    }

    if (op->is_producer) {
      nest_level_++;
      StmtVisitor::VisitStmt_(op);
      nest_level_--;
    } else {
      StmtVisitor::VisitStmt_(op);
    }

    if (nest_level_ == 0) {
      // exit a kernel, check the validity
      valid_ &= thread_per_block_ <= max_threads_per_block_;

      valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
      valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
    }
  }

  void VisitStmt_(const AllocateNode* op) final {
    StmtVisitor::VisitStmt_(op);
    // visiting an allocation of a buffer in local or shared memory; record its size
    if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
    } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
    }
  }

  void VisitStmt_(const AttrStmtNode* op) final {
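    // Record buffer storage scopes ("local"/"shared") and validate thread extents.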
    if (op->attr_key == attr::storage_scope) {
      std::string op_value = op->value.as<StringImmNode>()->value;
      if (op_value == "local") {
        visited_local_buffers_.insert(op->node.as<tvm::VarNode>());
      } else if (op_value == "shared") {
        visited_shared_buffers_.insert(op->node.as<tvm::VarNode>());
      }
    } else if (op->attr_key == attr::thread_extent) {
      Var var = op->node.as<tvm::IterVarNode>()->var;
      const auto *extent = op->value.as<IntImmNode>();
      CHECK(extent);

      // record the number of threads in a block
      std::string name = var.get()->name_hint;
      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
        size_t length = static_cast<size_t>(extent->value);
        if (!visited_threads_.count(name)) {
          visited_threads_.insert(name);
          thread_per_block_ *= length;

          if (name == "threadIdx.x") {
            valid_ &= length <= max_thread_x_;
            thread_x_extent_ = length;
          } else if (name == "threadIdx.y") {
            valid_ &= length <= max_thread_y_;
            thread_y_extent_ = length;
          } else if (name == "threadIdx.z") {
            valid_ &= length <= max_thread_z_;
            thread_z_extent_ = length;
          }
        } else {
          // the thread should be bound to axes with the same length
          if (name == "threadIdx.x") {
            valid_ &= length == thread_x_extent_;
          } else if (name == "threadIdx.y") {
            valid_ &= length == thread_y_extent_;
          } else if (name == "threadIdx.z") {
            valid_ &= length == thread_z_extent_;
          }
        }
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

 private:
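  // Depth of nested producer scopes; returning to zero marks the end of a kernel.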
  int nest_level_{0};

  std::unordered_set<const tvm::VarNode *> visited_local_buffers_;
  std::unordered_set<const tvm::VarNode *> visited_shared_buffers_;
  std::unordered_set<std::string> visited_threads_;

  size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;

  size_t local_memory_per_block_;
  size_t shared_memory_per_block_;
  size_t thread_per_block_;

  size_t max_local_memory_per_block_;
  size_t max_shared_memory_per_block_;
  size_t max_threads_per_block_;
  size_t max_thread_x_, max_thread_y_, max_thread_z_;

  bool valid_{true};

  void Reset_() {
    visited_local_buffers_.clear();
    visited_shared_buffers_.clear();
    local_memory_per_block_ = 0;
    shared_memory_per_block_ = 0;

    visited_threads_.clear();
    thread_per_block_ = 1;
  }
};

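/*!
 * \brief Verify a lowered GPU statement against hardware constraints.
 * \param stmt The statement to be checked.
 * \param constraints Map from constraint names to integer limits; an unspecified
 *        constraint defaults to INT64_MAX (no limit).
 * \return Whether the statement satisfies all constraints.
 */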
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, PrimExpr> constraints) {
  GPUCodeVerifier verifier;

  int64_t max_local_memory_per_block = INT64_MAX;
  int64_t max_shared_memory_per_block = INT64_MAX;
  int64_t max_threads_per_block = INT64_MAX;
  int64_t max_thread_x = INT64_MAX;
  int64_t max_thread_y = INT64_MAX;
  int64_t max_thread_z = INT64_MAX;

  for (auto iter : constraints) {
    const IntImmNode* val = iter.second.as<IntImmNode>();
    if (iter.first == "max_local_memory_per_block")
      max_local_memory_per_block = val->value;
    else if (iter.first == "max_shared_memory_per_block")
      max_shared_memory_per_block = val->value;
    else if (iter.first == "max_threads_per_block")
      max_threads_per_block = val->value;
    else if (iter.first == "max_thread_x")
      max_thread_x = val->value;
    else if (iter.first == "max_thread_y")
      max_thread_y = val->value;
    else if (iter.first == "max_thread_z")
      max_thread_z = val->value;
    else
      LOG(FATAL) << "Invalid check item: " << iter.first;
  }

  return verifier.Verify(stmt,
                         max_local_memory_per_block,
                         max_shared_memory_per_block,
                         max_threads_per_block,
                         max_thread_x,
                         max_thread_y,
                         max_thread_z);
}
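
// Example usage, as a minimal sketch: `kernel_stmt` is assumed to be a lowered
// device Stmt obtained elsewhere, and the CUDA-like limits below are assumed
// values rather than ones prescribed by this pass.
//
//   Map<std::string, PrimExpr> constraints;
//   constraints.Set("max_shared_memory_per_block", make_const(DataType::Int(64), 49152));
//   constraints.Set("max_threads_per_block", make_const(DataType::Int(64), 1024));
//   bool ok = VerifyGPUCode(kernel_stmt, constraints);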

}  // namespace ir
}  // namespace tvm