/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2018 by Contributors
 * \file verify_memory.cc
 * \brief Pass to check if memory accesses are legal.
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>

#include <unordered_map>

namespace tvm {
namespace ir {
namespace {

/*!
 * \brief Verify if memory accesses are legal.
 *
 *  In the case that tgt is cuda, if workload is not bound with
 *  threads, CPU code is generated that tries to access GPU memory,
 *  which is illegal.
 *
 *  This pass performs such verification by checking if all Producer/Consumer
 *  with memory accesses are bound with threads when device type is GPU.
 *
 *  Usage: construct with the function and device type, call Run(), then
 *  query Failed().
 */
class MemoryAccessVerifier final : protected IRVisitor {
 public:
  /// Special member functions
  //@{
  /*!
   * \brief Construct a verifier.
   * \param f The lowered function whose body is verified.
   * \param device_type DLDeviceType/TVMDeviceExtType value of the target.
   */
  explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
      : func_(f), dev_type_(device_type) {}
  virtual ~MemoryAccessVerifier() = default;
  MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
  MemoryAccessVerifier(MemoryAccessVerifier &&) = delete;
  MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete;
  MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete;
  //@}

  /*!
   * \brief Interface to perform memory access verification.
   *
   *  Nothing is visited (and the verification therefore cannot fail)
   *  unless the device type is a GPU or FPGA device.
   */
  void Run() {
    if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return;
    IRVisitor::Visit(func_->body);
  }

  /// Verification result: true if an illegal access was found.
  bool Failed() const { return failure_; }

 protected:
  /// Visitor implementation
  //@{
  /// Stop descending as soon as a failure has been recorded.
  void Visit(const NodeRef &n) final {
    if (Failed()) return;
    IRVisitor::Visit(n);
  }

  /// Record the value bound to the let variable before visiting the node.
  void Visit_(const LetStmt *op) final {
    // Book keep definitions
    defs_[op->var.get()] = op->value;
    return IRVisitor::Visit_(op);
  }

  /// Track entry into / exit from a thread environment.
  void Visit_(const AttrStmt *op) final {
    const bool opens_thread_env =
        !InThreadEnv() && (op->attr_key == attr::thread_extent ||
                           op->attr_key == attr::pipeline_exec_scope);
    if (!opens_thread_env) {
      // Either already inside a thread env (only the outermost attribute
      // toggles the flag) or the attribute is unrelated to threading.
      IRVisitor::Visit_(op);
      return;
    }
    EnterThreadEnv();
    IRVisitor::Visit_(op);
    ExitThreadEnv();
  }

  /// Track the innermost enclosing ProducerConsumer while visiting its body.
  void Visit_(const ProducerConsumer *op) final {
    // Remember the enclosing producer/consumer (if any) and restore it on
    // exit. Resetting to nullptr instead would make every access that
    // follows a nested ProducerConsumer look as if it were outside of any
    // producer/consumer, and it would silently escape verification.
    const ProducerConsumer *enclosing = GetCurrentProducerConsumer();
    EnterProducerConsumer(op);
    IRVisitor::Visit_(op);
    ExitProducerConsumer(enclosing);
  }

  /// Check the loaded buffer variable, then visit the node.
  void Visit_(const Load *op) final {
    HandleLoadStoreToVariable(op->buffer_var);
    return IRVisitor::Visit_(op);
  }

  /// Check the stored buffer variable, then visit the node.
  void Visit_(const Store *op) final {
    HandleLoadStoreToVariable(op->buffer_var);
    return IRVisitor::Visit_(op);
  }
  //@}

  /*!
   * \brief Check if the value of a Variable comes from function argument.
   *
   *  Follows the chain of recorded let-definitions: as long as the current
   *  variable is defined by a tvm_struct_get Call, continue with the first
   *  argument of that call.
   *
   * \param var The variable to trace; must not be null (CHECK-ed).
   * \return true if the chain reaches the first argument of the function,
   *         false if the function has no arguments, the variable has no
   *         recorded definition, or its definition is not a tvm_struct_get.
   */
  bool IsFromFunctionArgs(const Variable *var) const {
    // Without any argument nothing can originate from one; this also keeps
    // the args[0] access below within bounds.
    if (func_->args.size() == 0) return false;

    const Variable *V = var;
    while (true) {
      CHECK(V) << "Invalid Variable\n";

      // Variable is from function args. Return true.
      if (V == func_->args[0].node_.get()) return true;

      // The value is expected to come from a tvm_struct_get Call.
      // Get the first argument of tvm_struct_get, and continue.
      auto iter = defs_.find(V);
      if (iter == defs_.end()) return false;
      const Call *C = iter->second.as<const Call>();
      if (!C || C->name != intrinsic::tvm_struct_get) return false;
      V = C->args[0].as<Variable>();
    }
  }

  /*!
   * \brief Handle memory access to a Variable.
   *
   *  Marks the verification as failed when the access happens inside a
   *  producer/consumer, outside of a thread env, and the variable can be
   *  traced back to the function arguments.
   *
   * \param var The buffer variable being loaded from or stored to.
   */
  void HandleLoadStoreToVariable(const VarExpr &var) {
    // We skip the access within thread env.
    if (InThreadEnv()) return;

    // We only check access within a producer/consumer.
    // Because for load/store out side of producer/consumer,
    // they don't have to be in thread env to stay legal (e.g. Load of args).
    if (!InProducerConsumer()) return;

    // We only handle the variable from function argument.
    // If it does not come from args, then it could be allocated internally,
    // it may possibly be in host or device address space.
    // We do not handle this case, and skip it conservatively.
    if (!IsFromFunctionArgs(var.get())) return;

    // The verification fails in this case.
    SetFailure();
  }

  /// Status getter/setter
  //@{
  bool InThreadEnv() const { return in_thread_env_; }
  void EnterThreadEnv() { in_thread_env_ = true; }
  void ExitThreadEnv() { in_thread_env_ = false; }
  bool InProducerConsumer() const { return pc_ != nullptr; }
  const ProducerConsumer *GetCurrentProducerConsumer() const { return pc_; }
  void EnterProducerConsumer(const ProducerConsumer *pc) { this->pc_ = pc; }
  /// Leave the current producer/consumer and fall back to \p enclosing
  /// (nullptr when there is no enclosing one).
  void ExitProducerConsumer(const ProducerConsumer *enclosing = nullptr) {
    pc_ = enclosing;
  }
  void SetFailure() { failure_ = true; }
  //@}

  /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
  static bool IsGPUDevice(int dev_type) {
    return kDLGPU == dev_type || kDLOpenCL == dev_type ||
           kDLVulkan == dev_type || kDLMetal == dev_type ||
           kDLROCM == dev_type || kOpenGL == dev_type;
  }

  /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device.
  static bool IsFPGADevice(int dev_type) {
    return kDLSDAccel == dev_type || kDLAOCL == dev_type;
  }

 private:
  /// Status of visitor
  //@{
  bool in_thread_env_{false};             ///< Currently inside a thread env
  const ProducerConsumer *pc_{nullptr};   ///< Innermost enclosing producer/consumer
  bool failure_{false};  ///< If the verification fails (i.e. has illegal access)
  //@}
  LoweredFunc func_{nullptr};  ///< Function to be verified.
  int dev_type_{kDLCPU};       ///< Device type
  std::unordered_map<const Variable *, Expr> defs_;  ///< Variable definitions
};

}  // namespace

/*!
 * \brief Interface of VerifyMemory pass.
 * \param func The lowered function to verify.
 * \param device_type DLDeviceType/TVMDeviceExtType value of the target.
 * \return true if no illegal memory access was found, false otherwise.
 */
bool VerifyMemory(LoweredFunc func, int device_type) {
  MemoryAccessVerifier v(func, device_type);
  v.Run();
  return !v.Failed();
}

}  // namespace ir
}  // namespace tvm