/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lower_device_storage_access.cc
 * \brief Lower the special device storage access.
 */
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/target/target_info.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>

#include <unordered_map>
#include <utility>

#include "../pass/ir_util.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

using runtime::StorageScope;
using runtime::StorageRank;

/*!
 * \brief Mutator that rewrites allocations and tvm_access_ptr calls for
 *        buffers whose storage scope carries a tag with registered MemoryInfo.
 *
 * The storage scope of each buffer is recorded when its storage_scope
 * attribute is visited; the Allocate and tvm_access_ptr visitors then consult
 * that record. Buffers without memory info are left as ordinary allocations
 * and their access pointers are turned into plain address offsets.
 */
class StorageAccessInfoLower : public StmtExprMutator {
 public:
  /*!
   * \brief Lower an allocation of special memory.
   *
   * If the buffer has memory info, the allocation is either replaced by an
   * Allocate whose new_expr is the memory's head address (free function
   * "nop"), or removed entirely when the memory has no head address.
   * Allocations of buffers without memory info are returned unchanged
   * (after visiting their children).
   *
   * \param op The allocate node.
   * \return The lowered statement.
   */
  Stmt VisitStmt_(const AllocateNode* op) final {
    // Visit children first so that nested accesses are already lowered.
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();

    auto it = storage_info_.find(op->buffer_var.get());
    if (it == storage_info_.end() || !it->second.info.defined()) {
      // Ordinary memory: keep the allocation as is.
      return stmt;
    }

    // Special memory: each such buffer may be allocated at most once.
    StorageEntry& entry = it->second;
    const MemoryInfo& info = entry.info;
    ++entry.alloc_count;
    CHECK_LE(entry.alloc_count, 1)
        << "Double allocation of " << entry.scope.to_string();

    if (info->head_address.defined()) {
      // Bind the buffer to the memory's head address instead of allocating.
      return AllocateNode::make(
          op->buffer_var, op->dtype, op->extents, op->condition,
          op->body, info->head_address, "nop");
    }
    // No head address: drop the allocation and keep only the body.
    return op->body;
  }

  /*!
   * \brief Record the storage scope (and memory info, if the scope is tagged)
   *        of the buffer named by a storage_scope attribute.
   *
   * Attributes with any other key are simply forwarded to the base mutator.
   *
   * \param op The attribute statement.
   * \return The statement produced by the base mutator.
   */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::storage_scope) {
      const VarNode* buf = op->node.as<VarNode>();
      // Cast once and verify it: the original dereferenced the result of
      // as<StringImmNode>() unchecked, which crashes on a non-string value.
      const StringImmNode* scope_str = op->value.as<StringImmNode>();
      CHECK(scope_str != nullptr)
          << "The value of a storage_scope attribute must be a StringImm";

      StorageEntry e;
      e.scope = StorageScope::make(scope_str->value);
      if (e.scope.tag.length() != 0) {
        // Tagged scopes must have memory info available.
        e.info = GetMemoryInfo(scope_str->value);
        CHECK(e.info.defined())
            << "Cannot find memory info of " << e.scope.to_string();
      }
      storage_info_[buf] = e;
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  /*!
   * \brief Intercept tvm_access_ptr intrinsic calls; every other call is
   *        handled by the base mutator.
   *
   * \param op The call node.
   * \return The rewritten expression.
   */
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      return MakeAccessPtr(op);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

 private:
  /*!
   * \brief Lower a tvm_access_ptr call.
   *
   * The call must have exactly five arguments; this function reads the data
   * type of args[0], the buffer variable in args[1] and the offset in args[2].
   *
   * \param op The tvm_access_ptr call.
   * \return A tagged access pointer when the buffer has memory info,
   *         otherwise an address offset into the buffer.
   */
  PrimExpr MakeAccessPtr(const CallNode* op) {
    // Visit the arguments first, then work on the mutated call.
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    CHECK_EQ(op->args.size(), 5U);

    DataType dtype = op->args[0].dtype();
    // Downcast verifies that args[1] is a Var, so its node pointer can be
    // used directly as the lookup key.
    Var buffer_var = Downcast<Var>(op->args[1]);
    PrimExpr offset = op->args[2];

    auto it = storage_info_.find(buffer_var.get());
    if (it != storage_info_.end() && it->second.info.defined()) {
      return MakeTaggedAccessPtr(
          op->dtype, buffer_var, dtype, offset, it->second.info);
    }
    // Ordinary memory: the pointer must be a handle; change to address_of.
    CHECK(op->dtype.is_handle());
    return AddressOffset(buffer_var, dtype, offset);
  }

  /*!
   * \brief Build the access pointer for a buffer living in tagged memory.
   *
   * \param ptr_type The result type of the original tvm_access_ptr call.
   * \param buffer_var The buffer variable being accessed.
   * \param dtype The element data type of the access.
   * \param offset The element offset into the buffer.
   * \param info The memory info of the buffer's storage scope.
   * \return An address offset when a handle is requested (requires the memory
   *         to have a head address); otherwise the offset converted from
   *         elements to memory units and cast to ptr_type.
   */
  PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
                               Var buffer_var,
                               DataType dtype,
                               PrimExpr offset,
                               const MemoryInfo& info) {
    if (ptr_type.is_handle()) {
      CHECK(info->head_address.defined())
          << buffer_var << " is not addressable.";
      return AddressOffset(buffer_var, dtype, offset);
    }
    int dtype_bits = dtype.bits() * dtype.lanes();
    // Guard the modulo/division below against a zero-width element type.
    CHECK_GT(dtype_bits, 0)
        << "Cannot compute access pointer for zero-width data type " << dtype;
    CHECK_EQ(info->unit_bits % dtype_bits, 0);
    // Number of elements per memory unit; the offset is expressed in units.
    return cast(ptr_type,
                tir::Simplify(offset / make_const(
                    offset.dtype(), info->unit_bits / dtype_bits)));
  }

  /*! \brief Per-buffer record gathered from storage_scope attributes. */
  struct StorageEntry {
    // The storage scope of the buffer.
    StorageScope scope;
    // The memory info if any (only looked up for tagged scopes).
    MemoryInfo info;
    // Number of allocations seen so far for this buffer.
    int alloc_count{0};
  };

  // The storage entry of each buffer, keyed by its variable node.
  std::unordered_map<const VarNode*, StorageEntry> storage_info_;
};

/*!
 * \brief Apply StorageAccessInfoLower to a statement.
 * \param stmt The statement to lower.
 * \return The lowered statement.
 */
Stmt LowerStorageAccessInfo(Stmt stmt) {
  return StorageAccessInfoLower()(std::move(stmt));
}

namespace transform {

/*!
 * \brief Create the PrimFunc pass that runs StorageAccessInfoLower over the
 *        body of each function.
 * \return The pass, named "tir.LowerDeviceStorageAccessInfo".
 */
Pass LowerDeviceStorageAccessInfo() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = StorageAccessInfoLower()(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(
      pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo")
.set_body_typed(LowerDeviceStorageAccessInfo);

}  // namespace transform
}  // namespace tir
}  // namespace tvm