/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file verify_gpu_code.cc
 * \brief Verify the correctness of a GPU IR.
 *        It will check the whether the amount of memory usage or the number of threads
 *        in a block exceeds the limit
 */

#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>

#include <cstdint>
#include <string>
#include <unordered_set>

namespace tvm {
namespace ir {

class GPUCodeVerifier : public IRVisitor {
 public:
  bool Verify(tvm::Stmt stmt,
              int64_t max_local_memory_per_block,
              int64_t max_shared_memory_per_block,
40
              int64_t max_threads_per_block,
41 42 43 44 45
              int64_t max_thread_x,
              int64_t max_thread_y,
              int64_t max_thread_z) {
    max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
    max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
46
    max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
    max_thread_x_ = static_cast<size_t>(max_thread_x);
    max_thread_y_ = static_cast<size_t>(max_thread_y);
    max_thread_z_ = static_cast<size_t>(max_thread_z);

    Reset_();

    this->Visit(stmt);

    return valid_;
  }

  void Visit_(const ProducerConsumer *op) {
    if (nest_level_ == 0) {
      // enter a new kernel, reset statistics
      Reset_();
    }

    if (op->is_producer) {
      nest_level_++;
      IRVisitor::Visit_(op);
      nest_level_--;
    } else {
      IRVisitor::Visit_(op);
    }

    if (nest_level_ == 0) {
      // exit a kernel, check the validity
74
      valid_ &= thread_per_block_ <= max_threads_per_block_;
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94

      valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
      valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
    }
  }

  void Visit_(const Allocate *op) {
    IRVisitor::Visit_(op);
    // visit an allocation of a buffer in shared memory, record its size
    if (visited_local_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      local_memory_per_block_ += size * op->type.bytes();
    } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) {
      size_t size = static_cast<size_t>(op->constant_allocation_size());
      shared_memory_per_block_ += size * op->type.bytes();
    }
  }

  void Visit_(const AttrStmt *op) {
    if (op->attr_key == attr::storage_scope) {
95 96
      std::string op_value = op->value.as<StringImm>()->value;
      if (op_value == "local") {
97
        visited_local_buffers_.insert(op->node.as<tvm::Variable>());
98
      } else if (op_value == "shared") {
99 100 101 102 103 104 105 106 107 108
        visited_shared_buffers_.insert(op->node.as<tvm::Variable>());
      }
    } else if (op->attr_key == attr::thread_extent) {
      VarExpr var = op->node.as<tvm::IterVarNode>()->var;
      const auto *extent = op->value.as<IntImm>();
      CHECK(extent);

      // record the number of threads in a block
      std::string name = var.get()->name_hint;
      if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
109
        size_t length = static_cast<size_t>(extent->value);
110 111 112 113 114 115
        if (!visited_threads_.count(name)) {
          visited_threads_.insert(name);
          thread_per_block_ *= length;

          if (name == "threadIdx.x") {
            valid_ &= length <= max_thread_x_;
116
            thread_x_extent_ = length;
117 118
          } else if (name == "threadIdx.y") {
            valid_ &= length <= max_thread_y_;
119
            thread_y_extent_ = length;
120 121
          } else if (name == "threadIdx.z") {
            valid_ &= length <= max_thread_z_;
122 123 124 125 126 127 128 129 130 131
            thread_z_extent_ = length;
          }
        } else {
          // the thread should be bound to axes with the same length
          if (name == "threadIdx.x") {
            valid_ &= length == thread_x_extent_;
          } else if (name == "threadIdx.y") {
            valid_ &= length == thread_y_extent_;
          } else if (name == "threadIdx.z") {
            valid_ &= length == thread_z_extent_;
132 133 134 135 136 137 138 139 140 141 142 143 144 145
          }
        }
      }
    }
    IRVisitor::Visit_(op);
  }

 private:
  int nest_level_{0};

  std::unordered_set<const tvm::Variable *> visited_local_buffers_;
  std::unordered_set<const tvm::Variable *> visited_shared_buffers_;
  std::unordered_set<std::string> visited_threads_;

146 147
  size_t thread_x_extent_, thread_y_extent_, thread_z_extent_;

148 149 150 151 152 153
  size_t local_memory_per_block_;
  size_t shared_memory_per_block_;
  size_t thread_per_block_;

  size_t max_local_memory_per_block_;
  size_t max_shared_memory_per_block_;
154
  size_t max_threads_per_block_;
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
  size_t max_thread_x_, max_thread_y_, max_thread_z_;

  bool valid_{true};

  void Reset_() {
    visited_local_buffers_.clear();
    visited_shared_buffers_.clear();
    local_memory_per_block_ = 0;
    shared_memory_per_block_ = 0;

    visited_threads_.clear();
    thread_per_block_ = 1;
  }
};

/*!
 * \brief Verify the correctness of a GPU IR statement against resource limits.
 * \param stmt The statement to be checked.
 * \param constraints Map from constraint name to an integer limit. Supported
 *        keys: "max_local_memory_per_block", "max_shared_memory_per_block",
 *        "max_threads_per_block", "max_thread_x", "max_thread_y",
 *        "max_thread_z". Any unrecognized key is a fatal error.
 * \return true if `stmt` satisfies every supplied constraint.
 */
bool VerifyGPUCode(Stmt stmt,
                   Map<std::string, Expr> constraints) {
  GPUCodeVerifier verifier;

  // Constraints that are not supplied default to "unlimited".
  int64_t max_local_memory_per_block = INT64_MAX;
  int64_t max_shared_memory_per_block = INT64_MAX;
  int64_t max_threads_per_block = INT64_MAX;
  int64_t max_thread_x = INT64_MAX;
  int64_t max_thread_y = INT64_MAX;
  int64_t max_thread_z = INT64_MAX;

  for (auto iter : constraints) {
    const IntImm* val = iter.second.as<IntImm>();
    // Guard the dereference: a non-integer constraint value would otherwise
    // crash with a null-pointer access instead of a diagnosable error.
    CHECK(val) << "Value of constraint \"" << iter.first
               << "\" must be an integer immediate";
    if (iter.first == "max_local_memory_per_block")
      max_local_memory_per_block = val->value;
    else if (iter.first == "max_shared_memory_per_block")
      max_shared_memory_per_block = val->value;
    else if (iter.first == "max_threads_per_block")
      max_threads_per_block = val->value;
    else if (iter.first == "max_thread_x")
      max_thread_x = val->value;
    else if (iter.first == "max_thread_y")
      max_thread_y = val->value;
    else if (iter.first == "max_thread_z")
      max_thread_z = val->value;
    else
      LOG(FATAL) << "Invalid check item: " << iter.first;
  }

  return verifier.Verify(stmt,
                         max_local_memory_per_block,
                         max_shared_memory_per_block,
                         max_threads_per_block,
                         max_thread_x,
                         max_thread_y,
                         max_thread_z);
}

}  // namespace ir
}  // namespace tvm