/*!
 *  Copyright (c) 2018 by Contributors
 * \file verify_gpu_code.cc
 * \brief Verify the correctness of a GPU IR.
 *        It checks whether the memory usage or the number of threads
 *        in a block exceeds the device limits.
 */

#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

#include <cstdint>
#include <string>
#include <unordered_set>

namespace tvm {
namespace ir {

class GPUCodeVerifier : public IRVisitor {
 public:
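  /*!
   * \brief Check whether a statement satisfies the given per-block limits.
   * \param stmt The GPU IR statement to verify.
   * \param max_local_memory_per_block Maximum local memory (in bytes) per block.
   * \param max_shared_memory_per_block Maximum shared memory (in bytes) per block.
   * \param max_thread_per_block Maximum number of threads per block.
   * \param max_thread_x Maximum extent of threadIdx.x.
   * \param max_thread_y Maximum extent of threadIdx.y.
   * \param max_thread_z Maximum extent of threadIdx.z.
   * \return Whether all constraints are satisfied.
   */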
  bool Verify(tvm::Stmt stmt,
              int64_t max_local_memory_per_block,
              int64_t max_shared_memory_per_block,
              int64_t max_thread_per_block,
              int64_t max_thread_x,
              int64_t max_thread_y,
              int64_t max_thread_z) {
    max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
    max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
    max_thread_per_block_ = static_cast<size_t>(max_thread_per_block);
    max_thread_x_ = static_cast<size_t>(max_thread_x);
    max_thread_y_ = static_cast<size_t>(max_thread_y);
    max_thread_z_ = static_cast<size_t>(max_thread_z);

    Reset_();

    this->Visit(stmt);

    return valid_;
  }

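  // A ProducerConsumer node at nest level zero marks a kernel boundary:
  // statistics are reset on entry and the per-block limits are checked on exit.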
  void Visit_(const ProducerConsumer *op) {
    if (nest_level_ == 0) {
      // enter a new kernel, reset statistics
      Reset_();
    }

    if (op->is_producer) {
      nest_level_++;
      IRVisitor::Visit_(op);
      nest_level_--;
    } else {
      IRVisitor::Visit_(op);
    }

    if (nest_level_ == 0) {
      // exit a kernel, check the validity
      valid_ &= thread_per_block_ <= max_thread_per_block_;

      valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
      valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
    }
  }

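  // Accumulate the sizes of allocations whose storage scope was recorded
  // as "local" or "shared" by the AttrStmt visitor below.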
  void Visit_(const Allocate *op) {
    IRVisitor::Visit_(op);
    // when visiting the allocation of a local or shared buffer, record its size
    if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      local_memory_per_block_ += size * op->type.bytes();
    } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      shared_memory_per_block_ += size * op->type.bytes();
    }
  }

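  // storage_scope attributes identify which buffers live in local or shared
  // memory; thread_extent attributes carry the launch extents of
  // threadIdx.{x,y,z}, which together bound the block size.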
  void Visit_(const AttrStmt *op) {
    if (op->attr_key == attr::storage_scope) {
      if (op->value.as<StringImm>()->value == "local") {
        visited_local_buffers_.insert(op->node.as<tvm::Variable>());
      } else if (op->value.as<StringImm>()->value == "shared") {
        visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
      }
    } else if (op->attr_key == attr::thread_extent) {
      VarExpr var = op->node.as<tvm::IterVarNode>()->var;
      const auto *extent = op->value.as<IntImm>();
      CHECK(extent);

      // record the number of threads in a block
      std::string name = var.get()->name_hint;
      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
        if (!visited_threads_.count(name)) {
          visited_threads_.insert(name);
          size_t length = static_cast<size_t>(extent->value);
          thread_per_block_ *= length;

          if (name == "threadIdx.x") {
            valid_ &= length <= max_thread_x_;
          } else if (name == "threadIdx.y") {
            valid_ &= length <= max_thread_y_;
          } else if (name == "threadIdx.z") {
            valid_ &= length <= max_thread_z_;
          }
        }
      }
    }
    IRVisitor::Visit_(op);
  }

 private:
  int nest_level_{0};

  std::unordered_set<const tvm::Variable *> visited_local_buffers_;
  std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
  std::unordered_set<std::string> visited_threads_;

  size_t local_memory_per_block_;
  size_t shared_memory_per_block_;
  size_t thread_per_block_;

  size_t max_local_memory_per_block_;
  size_t max_shared_memory_per_block_;
  size_t max_thread_per_block_;
  size_t max_thread_x_, max_thread_y_, max_thread_z_;

  bool valid_{true};

  void Reset_() {
    visited_local_buffers_.clear();
    visited_shared_buffers_.clear();
    local_memory_per_block_ = 0;
    shared_memory_per_block_ = 0;

    visited_threads_.clear();
    thread_per_block_ = 1;
  }
};

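/*!
 * \brief Verify the correctness of a GPU IR statement.
 * \param stmt The statement to be checked.
 * \param constraints Map from a constraint name ("max_local_memory_per_block",
 *        "max_shared_memory_per_block", "max_thread_per_block",
 *        "max_thread_x", "max_thread_y", "max_thread_z") to an IntImm limit.
 *        Missing entries default to INT64_MAX, i.e. unconstrained.
 * \return Whether the statement satisfies all given constraints.
 */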
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, Expr> constraints) {
  GPUCodeVerifier verifier;

  auto get_int = [&constraints](std::string key, int64_t def) {
    auto iter = constraints.find(key);
    if (iter != constraints.end()) {
      return ((*iter).second).as<IntImm>()->value;
    } else {
      return def;
    }
  };

  int64_t max_local_memory_per_block = get_int("max_local_memory_per_block", INT64_MAX);
  int64_t max_shared_memory_per_block = get_int("max_shared_memory_per_block", INT64_MAX);
  int64_t max_thread_per_block = get_int("max_thread_per_block", INT64_MAX);
  int64_t max_thread_x = get_int("max_thread_x", INT64_MAX);
  int64_t max_thread_y = get_int("max_thread_y", INT64_MAX);
  int64_t max_thread_z = get_int("max_thread_z", INT64_MAX);

  return verifier.Verify(stmt,
                         max_local_memory_per_block,
                         max_shared_memory_per_block,
                         max_thread_per_block,
                         max_thread_x,
                         max_thread_y,
                         max_thread_z);
}
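
// A minimal usage sketch (hypothetical helper, not part of the original pass).
// It checks a lowered device statement against typical CUDA limits; the
// numeric values below are illustrative defaults, not queried from any device,
// and IntImm::make / Map::Set are assumed available from the included headers.
bool VerifyGPUCodeCUDADefaults(Stmt stmt) {
  Map<std::string, Expr> constraints;
  constraints.Set("max_shared_memory_per_block", IntImm::make(Int(64), 49152));
  constraints.Set("max_thread_per_block", IntImm::make(Int(64), 1024));
  constraints.Set("max_thread_x", IntImm::make(Int(64), 1024));
  constraints.Set("max_thread_y", IntImm::make(Int(64), 1024));
  constraints.Set("max_thread_z", IntImm::make(Int(64), 64));
  return VerifyGPUCode(stmt, constraints);
}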

}  // namespace ir
}  // namespace tvm