/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file verify_memory.cc
 * \brief Pass to check if memory accesses are legal.
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>

namespace tvm {
namespace ir {
namespace {

/*!
 * \brief Verify if memory accesses are legal.
 *
 *  In the case that the target is a GPU (e.g. cuda), if the workload is not
 *  bound to threads, CPU code is generated that tries to access GPU memory,
 *  which is illegal.
 *
 *  This pass performs such verification when the device type is a GPU or an
 *  FPGA: every load/store inside a Producer/Consumer that touches a variable
 *  derived from the function arguments must sit under a thread_extent or
 *  pipeline_exec_scope attribute.
 */
class MemoryAccessVerifier final : protected IRVisitor {
 public:
  /// Special member functions
  //@{
  explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
      : func_(f), dev_type_(device_type) {}
  virtual ~MemoryAccessVerifier() = default;
  MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
  MemoryAccessVerifier(MemoryAccessVerifier &&) = delete;
  MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete;
  MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete;
  //@}

  /// Interface to perform memory access verification
  void Run() {
58
    if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return;
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
    IRVisitor::Visit(func_->body);
  }

  /// Verification result
  bool Failed() const { return failure_; }

 protected:
  /// Visitor implementation
  //@{
  void Visit(const NodeRef &n) final {
    if (Failed()) return;
    IRVisitor::Visit(n);
  }

  void Visit_(const LetStmt *op) final {
    // Book keep definitions
    defs_[op->var.get()] = op->value;
    return IRVisitor::Visit_(op);
  }

  void Visit_(const AttrStmt *op) final {
    if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
                           op->attr_key == attr::pipeline_exec_scope)) {
      EnterThreadEnv();
      IRVisitor::Visit_(op);
      ExitThreadEnv();
    } else {
      IRVisitor::Visit_(op);
    }
  }

  void Visit_(const ProducerConsumer *op) final {
    EnterProducerConsumer(op);
    IRVisitor::Visit_(op);
    ExitProducerConsumer();
  }

  void Visit_(const Load *op) final {
    HandleLoadStoreToVariable(op->buffer_var);
    return IRVisitor::Visit_(op);
  }

  void Visit_(const Store *op) final {
    HandleLoadStoreToVariable(op->buffer_var);
    return IRVisitor::Visit_(op);
  }
  //@}

  /// Check if the value of a Variable comes from function argument.
  bool IsFromFunctionArgs(const Variable *var) const {
    const Variable *V = var;
    while (true) {
      CHECK(V) << "Invalid Variable\n";

      // Variable is from function args. Return true.
      if (V == func_->args[0].node_.get()) return true;

      // The value is expected to come from a tvm_struct_get Call.
      // Get the first argument of tvm_struct_get, and continue.
      const auto &iter = defs_.find(V);
      if (iter == defs_.end()) return false;
      const Call *C = iter->second.as<const Call>();
      if (!C || C->name != intrinsic::tvm_struct_get) return false;
      V = C->args[0].as<Variable>();
    }
    return false;
  }

  /// Handle memory access to a Variable
  void HandleLoadStoreToVariable(const VarExpr &var) {
    // We skip the access within thread env.
    if (InThreadEnv()) return;

    // We only check access within a producer/consumer.
    // Because for load/store out side of producer/consumer,
    // they don't have to be in thread env to stay legal (e.g. Load of args).
    if (!InProducerConsumer()) return;

    // We only handle the variable from function argument.
    // If it does not come from args, then it could be allocated internally,
    // it may possibly be in host or device address space.
    // We do not handle this case, and skip it conservatively.
    if (!IsFromFunctionArgs(var.get())) return;

    // The verification fails in this case.
    SetFailure();
  }

  /// Status getter/setter
  //@{
  bool InThreadEnv() const { return in_thread_env_; }
  void EnterThreadEnv() { in_thread_env_ = true; }
  void ExitThreadEnv() { in_thread_env_ = false; }
  bool InProducerConsumer() const { return pc_ != nullptr; }
  const ProducerConsumer *GetCurrentProducerConsumer() const { return pc_; }
  void EnterProducerConsumer(const ProducerConsumer *pc) { this->pc_ = pc; }
  void ExitProducerConsumer() { pc_ = nullptr; }
  void SetFailure() { failure_ = true; }
  //@}

  /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
  static bool IsGPUDevice(int dev_type) {
    return kDLGPU == dev_type || kDLOpenCL == dev_type ||
           kDLVulkan == dev_type || kDLMetal == dev_type ||
           kDLROCM == dev_type || kOpenGL == dev_type;
  }
165 166
  /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
  static bool IsFPGADevice(int dev_type) {
167
    return kDLSDAccel == dev_type || kDLAOCL == dev_type;
168
  }
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191

 private:
  /// Status of visitor
  //@{
  bool in_thread_env_{false};
  const ProducerConsumer *pc_{nullptr};
  bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
  //@}
  LoweredFunc func_{nullptr};  ///< Function to be verified.
  int dev_type_{kDLCPU};       ///< Device type
  std::unordered_map<const Variable *, Expr> defs_;  ///< Variable definitions
};
}  // namespace

/// Entry point of the VerifyMemory pass.
/// \param func The lowered function whose body is checked.
/// \param device_type Device type value handed to the verifier.
/// \return true when the verifier recorded no failure, false otherwise.
bool VerifyMemory(LoweredFunc func, int device_type) {
  MemoryAccessVerifier verifier(func, device_type);
  verifier.Run();
  const bool has_illegal_access = verifier.Failed();
  return !has_illegal_access;
}

}  // namespace ir
}  // namespace tvm