verify_gpu_code.cc 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/*!
 *  Copyright (c) 2018 by Contributors
 * \file verify_gpu_code.cc
 * \brief Verify the correctness of a GPU IR.
 *        It checks whether the amount of memory used or the number of threads
 *        in a block exceeds the given limits.
 */

#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

namespace tvm {
namespace ir {

class GPUCodeVerifier : public IRVisitor {
 public:
  bool Verify(tvm::Stmt stmt,
              int64_t max_local_memory_per_block,
              int64_t max_shared_memory_per_block,
21
              int64_t max_threads_per_block,
22 23 24 25 26
              int64_t max_thread_x,
              int64_t max_thread_y,
              int64_t max_thread_z) {
    max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
    max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
27
    max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
    max_thread_x_ = static_cast<size_t>(max_thread_x);
    max_thread_y_ = static_cast<size_t>(max_thread_y);
    max_thread_z_ = static_cast<size_t>(max_thread_z);

    Reset_();

    this->Visit(stmt);

    return valid_;
  }

  void Visit_(const ProducerConsumer *op) {
    if (nest_level_ == 0) {
      // enter a new kernel, reset statistics
      Reset_();
    }

    if (op->is_producer) {
      nest_level_++;
      IRVisitor::Visit_(op);
      nest_level_--;
    } else {
      IRVisitor::Visit_(op);
    }

    if (nest_level_ == 0) {
      // exit a kernel, check the validity
55
      valid_ &= thread_per_block_ <= max_threads_per_block_;
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75

      valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
      valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
    }
  }

  void Visit_(const Allocate *op) {
    IRVisitor::Visit_(op);
    // visit an allocation of a buffer in shared memory, record its size
    if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      local_memory_per_block_ += size * op->type.bytes();
    } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      shared_memory_per_block_ += size * op->type.bytes();
    }
  }

  void Visit_(const AttrStmt *op) {
    if (op->attr_key == attr::storage_scope) {
76 77
      std::string op_value = op->value.as<StringImm>()->value;
      if (op_value == "local") {
78
        visited_local_buffers_.insert(op->node.as<tvm::Variable>());
79
      } else if (op_value == "shared") {
80 81 82 83 84 85 86 87 88 89
        visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
      }
    } else if (op->attr_key == attr::thread_extent) {
      VarExpr var = op->node.as<tvm::IterVarNode>()->var;
      const auto *extent = op->value.as<IntImm>();
      CHECK(extent);

      // record the number of threads in a block
      std::string name = var.get()->name_hint;
      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
90
        size_t length = static_cast<size_t>(extent->value);
91 92 93 94 95 96
        if (!visited_threads_.count(name)) {
          visited_threads_.insert(name);
          thread_per_block_ *= length;

          if (name == "threadIdx.x") {
            valid_ &= length <= max_thread_x_;
97
            thread_x_extent_ = length;
98 99
          } else if (name == "threadIdx.y") {
            valid_ &= length <= max_thread_y_;
100
            thread_y_extent_ = length;
101 102
          } else if (name == "threadIdx.z") {
            valid_ &= length <= max_thread_z_;
103 104 105 106 107 108 109 110 111 112
            thread_z_extent_ = length;
          }
        } else {
          // the thread should be bound to axes with the same length
          if (name == "threadIdx.x") {
            valid_ &= length == thread_x_extent_;
          } else if (name == "threadIdx.y") {
            valid_ &= length == thread_y_extent_;
          } else if (name == "threadIdx.z") {
            valid_ &= length == thread_z_extent_;
113 114 115 116 117 118 119 120 121 122 123 124 125 126
          }
        }
      }
    }
    IRVisitor::Visit_(op);
  }

 private:
  int nest_level_{0};

  std::unordered_set<const tvm::Variable *> visited_local_buffers_;
  std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
  std::unordered_set<std::string> visited_threads_;

127 128
  size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;

129 130 131 132 133 134
  size_t local_memory_per_block_;
  size_t shared_memory_per_block_;
  size_t thread_per_block_;

  size_t max_local_memory_per_block_;
  size_t max_shared_memory_per_block_;
135
  size_t max_threads_per_block_;
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
  size_t max_thread_x_, max_thread_y_, max_thread_z_;

  bool valid_{true};

  void Reset_() {
    visited_local_buffers_.clear();
    visited_shared_buffers_.clear();
    local_memory_per_block_ = 0;
    shared_memory_per_block_ = 0;

    visited_threads_.clear();
    thread_per_block_ = 1;
  }
};

/*!
 * \brief Check whether a lowered GPU statement respects the given limits.
 * \param stmt The statement to verify.
 * \param constraints Map from limit name to an integer constant (IntImm).
 *        Recognized keys: "max_local_memory_per_block",
 *        "max_shared_memory_per_block", "max_threads_per_block",
 *        "max_thread_x", "max_thread_y", "max_thread_z". Any other key is a
 *        fatal error; an omitted key imposes no limit (INT64_MAX).
 * \return true if the statement satisfies every given limit.
 */
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, Expr> constraints) {
  GPUCodeVerifier verifier;

  int64_t max_local_memory_per_block = INT64_MAX;
  int64_t max_shared_memory_per_block = INT64_MAX;
  int64_t max_threads_per_block = INT64_MAX;
  int64_t max_thread_x = INT64_MAX;
  int64_t max_thread_y = INT64_MAX;
  int64_t max_thread_z = INT64_MAX;

  for (auto iter : constraints) {
    const std::string &key = iter.first;

    // Resolve the key first so unknown keys keep their original error message.
    int64_t *slot = nullptr;
    if (key == "max_local_memory_per_block") {
      slot = &max_local_memory_per_block;
    } else if (key == "max_shared_memory_per_block") {
      slot = &max_shared_memory_per_block;
    } else if (key == "max_threads_per_block") {
      slot = &max_threads_per_block;
    } else if (key == "max_thread_x") {
      slot = &max_thread_x;
    } else if (key == "max_thread_y") {
      slot = &max_thread_y;
    } else if (key == "max_thread_z") {
      slot = &max_thread_z;
    }
    if (slot == nullptr) {
      LOG(FATAL) << "Invalid check item: " << key;
    }

    // A non-integer value would otherwise be dereferenced as a null IntImm.
    const IntImm* val = iter.second.as<IntImm>();
    CHECK(val) << "Value of check item " << key << " must be an integer constant";
    *slot = val->value;
  }

  return verifier.Verify(stmt,
                         max_local_memory_per_block,
                         max_shared_memory_per_block,
                         max_threads_per_block,
                         max_thread_x,
                         max_thread_y,
                         max_thread_z);
}

}  // namespace ir
}  // namespace tvm