/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file verify_memory.cc
 * \brief Pass to check whether memory accesses are legal.
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>

#include <unordered_map>

namespace tvm {
namespace ir {
namespace {

/*!
 * \brief Verify if memory accesses are legal.
 *
 *  In the case that tgt is cuda, if workload is not bound with
 *  threads, CPU code is generated that tries to access GPU memory,
 *  which is illegal.
 *
 *  This pass performs such verification by checking if all Producer/Consumer
 *  with memory accesses are bound with threads when device type is GPU.
 */
43
class MemoryAccessVerifier final : protected StmtExprVisitor {
44 45 46 47 48 49 50 51 52 53 54 55 56 57
 public:
  /// Special member functions
  //@{
  explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
      : func_(f), dev_type_(device_type) {}
  virtual ~MemoryAccessVerifier() = default;
  MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
  MemoryAccessVerifier(MemoryAccessVerifier &&) = delete;
  MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete;
  MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete;
  //@}

  /// Interface to perform memory access verification
  void Run() {
58
    if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return;
59
    StmtExprVisitor::VisitStmt(func_->body);
60 61 62 63 64 65 66 67
  }

  /// Verification result
  bool Failed() const { return failure_; }

 protected:
  /// Visitor implementation
  //@{
68
  void VisitExpr(const PrimExpr &n) final {
69 70 71 72 73
    if (Failed()) return;
    StmtExprVisitor::VisitExpr(n);
  }

  void VisitStmt(const Stmt &n) final {
74
    if (Failed()) return;
75
    StmtExprVisitor::VisitStmt(n);
76 77
  }

78
  void VisitStmt_(const LetStmtNode* op) final {
79 80
    // Book keep definitions
    defs_[op->var.get()] = op->value;
81
    return StmtExprVisitor::VisitStmt_(op);
82 83
  }

84
  void VisitStmt_(const AttrStmtNode* op) final {
85 86 87
    if (!InThreadEnv() && (op->attr_key == attr::thread_extent ||
                           op->attr_key == attr::pipeline_exec_scope)) {
      EnterThreadEnv();
88
      StmtExprVisitor::VisitStmt_(op);
89 90
      ExitThreadEnv();
    } else {
91
      StmtExprVisitor::VisitStmt_(op);
92 93 94
    }
  }

95
  void VisitStmt_(const ProducerConsumerNode* op) final {
96
    EnterProducerConsumer(op);
97
    StmtExprVisitor::VisitStmt_(op);
98 99 100
    ExitProducerConsumer();
  }

101
  void VisitExpr_(const LoadNode* op) final {
102
    HandleLoadStoreToVariable(op->buffer_var);
103
    return StmtExprVisitor::VisitExpr_(op);
104 105
  }

106
  void VisitStmt_(const StoreNode* op) final {
107
    HandleLoadStoreToVariable(op->buffer_var);
108
    return StmtExprVisitor::VisitStmt_(op);
109 110 111 112
  }
  //@}

  /// Check if the value of a Variable comes from function argument.
113 114
  bool IsFromFunctionArgs(const VarNode *var) const {
    const VarNode *V = var;
115 116 117 118
    while (true) {
      CHECK(V) << "Invalid Variable\n";

      // Variable is from function args. Return true.
119
      if (V == func_->args[0].get()) return true;
120 121 122 123 124

      // The value is expected to come from a tvm_struct_get Call.
      // Get the first argument of tvm_struct_get, and continue.
      const auto &iter = defs_.find(V);
      if (iter == defs_.end()) return false;
125
      const CallNode *C = iter->second.as<const CallNode>();
126
      if (!C || C->name != intrinsic::tvm_struct_get) return false;
127
      V = C->args[0].as<VarNode>();
128 129 130 131 132
    }
    return false;
  }

  /// Handle memory access to a Variable
133
  void HandleLoadStoreToVariable(const Var &var) {
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
    // We skip the access within thread env.
    if (InThreadEnv()) return;

    // We only check access within a producer/consumer.
    // Because for load/store out side of producer/consumer,
    // they don't have to be in thread env to stay legal (e.g. Load of args).
    if (!InProducerConsumer()) return;

    // We only handle the variable from function argument.
    // If it does not come from args, then it could be allocated internally,
    // it may possibly be in host or device address space.
    // We do not handle this case, and skip it conservatively.
    if (!IsFromFunctionArgs(var.get())) return;

    // The verification fails in this case.
    SetFailure();
  }

  /// Status getter/setter
  //@{
  bool InThreadEnv() const { return in_thread_env_; }
  void EnterThreadEnv() { in_thread_env_ = true; }
  void ExitThreadEnv() { in_thread_env_ = false; }
  bool InProducerConsumer() const { return pc_ != nullptr; }
158 159
  const ProducerConsumerNode *GetCurrentProducerConsumer() const { return pc_; }
  void EnterProducerConsumer(const ProducerConsumerNode *pc) { this->pc_ = pc; }
160 161 162 163 164 165 166 167 168 169
  void ExitProducerConsumer() { pc_ = nullptr; }
  void SetFailure() { failure_ = true; }
  //@}

  /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
  static bool IsGPUDevice(int dev_type) {
    return kDLGPU == dev_type || kDLOpenCL == dev_type ||
           kDLVulkan == dev_type || kDLMetal == dev_type ||
           kDLROCM == dev_type || kOpenGL == dev_type;
  }
170 171
  /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
  static bool IsFPGADevice(int dev_type) {
172
    return kDLSDAccel == dev_type || kDLAOCL == dev_type;
173
  }
174 175 176 177 178

 private:
  /// Status of visitor
  //@{
  bool in_thread_env_{false};
179
  const ProducerConsumerNode *pc_{nullptr};
180 181 182 183
  bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
  //@}
  LoweredFunc func_{nullptr};  ///< Function to be verified.
  int dev_type_{kDLCPU};       ///< Device type
184
  std::unordered_map<const VarNode *, PrimExpr> defs_;  ///< Variable definitions
185 186 187 188 189 190 191 192 193 194 195 196
};
}  // namespace

/*!
 * \brief Entry point of the VerifyMemory pass.
 * \param func The lowered function to check.
 * \param device_type The DLDeviceType/TVMDeviceExtType of the target device.
 * \return true if no illegal memory access was detected, false otherwise.
 */
bool VerifyMemory(LoweredFunc func, int device_type) {
  MemoryAccessVerifier verifier(func, device_type);
  verifier.Run();
  const bool has_illegal_access = verifier.Failed();
  return !has_illegal_access;
}

}  // namespace ir
}  // namespace tvm