/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file codegen_metal.cc */ #include <vector> #include <string> #include <algorithm> #include "codegen_metal.h" #include "../build_common.h" #include "../../runtime/metal/metal_module.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { namespace codegen { void CodeGenMetal::InitFuncState(LoweredFunc f) { CodeGenC::InitFuncState(f); // analyze the data; for (Var arg : f->args) { if (arg.dtype().is_handle()) { alloc_storage_scope_[arg.get()] = "global"; } } } CodeGenMetal::CodeGenMetal() { decl_stream << "#include <metal_stdlib>\n"; decl_stream << "using namespace metal;\n\n"; decl_stream << "union __TVMArgUnion {\n" << " int v_int;\n" << "};\n\n"; } void CodeGenMetal::AddFunction(LoweredFunc f) { // clear previous generated state. this->InitFuncState(f); // skip the first underscore, so SSA variable starts from _1 GetUniqueName("_"); // add to alloc buffer type. for (const auto & kv : f->handle_data_type) { RegisterHandleType(kv.first.get(), kv.second.dtype()); } // Function header. this->stream << "kernel void " << f->name << "(\n"; // Buffer arguments size_t num_buffer = 0; for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) { Var v = f->args[i]; if (!v.dtype().is_handle()) break; stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); CHECK(it != alloc_storage_scope_.end()); PrintStorageScope(it->second, stream); stream << ' '; if (handle_data_type_.count(v.get())) { PrintType(handle_data_type_.at(v.get()), stream); stream << "*"; } else { PrintType(v.dtype(), stream); } stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. size_t nargs = f->args.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { std::string arg_buf_type = f->name + "_args_t"; stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < f->args.size(); ++i) { Var v = f->args[i]; CHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; if (v.dtype().bits() == 32) { decl_stream << " "; PrintType(v.dtype(), decl_stream); decl_stream << " " << vid << ";\n"; vref << varg << "." << vid; } else { // For non 32bit type, ref through arg union. decl_stream << " __TVMArgUnion " << vid << ";\n"; vref << varg << "." << vid << ".v_"; PrintType(v.dtype(), vref); } var_idmap_[v.get()] = vref.str(); } decl_stream << "};\n\n"; } // Setup the thread group info. CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; for (IterVar iv : f->thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); work_dim = std::max(work_dim, scope.dim_index + 1); } if (work_dim != 0) { // use ushort by default for now stream << " "; PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); stream << " blockIdx [[threadgroup_position_in_grid]],\n"; stream << " "; PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); stream << " threadIdx [[thread_position_in_threadgroup]]\n"; } // bind thread axis for (IterVar iv : f->thread_axis) { CHECK(!var_idmap_.count(iv->var.get())); std::string vname = iv->thread_tag; if (work_dim <= 1) { vname = vname.substr(0, iv->thread_tag.length() - 2); } var_idmap_[iv->var.get()] = CastFromTo(vname, DataType::UInt(thread_index_bits_), iv->var.dtype()); } // the function scope. stream << ") {\n"; int func_scope = this->BeginScope(); this->PrintStmt(f->body); this->EndScope(func_scope); this->PrintIndent(); this->stream << "}\n\n"; } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(thread_index_bits_), iv->var.dtype()); } void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { CHECK_EQ(lanes, 1) << "do not yet support vector types"; os << "void*"; return; } if (t == DataType::Bool()) { os << "bool"; return; } bool fail = false; if (t.is_float()) { switch (t.bits()) { case 16: os << "half"; break; case 32: os << "float"; break; default: fail = true; break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { os << 'u'; } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. os << "int"; return; } switch (t.bits()) { case 8: os << "char"; break; case 16: os << "short"; break; case 32: os << "int"; break; case 1: os << "bool"; break; default: fail = true; break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } } LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; } void CodeGenMetal::PrintStorageSync(const CallNode* op) { const std::string& sync = op->args[0].as<StringImmNode>()->value; if (sync == "warp") { this->PrintIndent(); this->stream << "simdgroup_barrier(mem_flags::mem_threadgroup);\n"; } else if (sync == "shared") { this->PrintIndent(); this->stream << "threadgroup_barrier(mem_flags::mem_threadgroup);\n"; } else if (sync == "global") { LOG(FATAL) << "global barrier not supported"; } } void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" << " = " << value << ";\n"; } void CodeGenMetal::PrintStorageScope( const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { os << "device"; } else if (scope == "shared") { os << "threadgroup"; } else { os << "thread"; } } void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); PrintType(op->dtype, os); os << "("; for (int i = 0; i < op->lanes; ++i) { if (i != 0) os << ", "; os << v; } os << ')'; } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->is_intrinsic(CallNode::reinterpret)) { // generate as_type<TYPE>(ARG) os << "(as_type<"; this->PrintType(op->dtype, os); os << ">("; this->PrintExpr(op->args[0], os); os << "))"; } else { CodeGenC::VisitExpr_(op, os); } } runtime::Module BuildMetal(Array<LoweredFunc> funcs) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenMetal cg; cg.Init(output_ssa); for (LoweredFunc f : funcs) { cg.AddFunction(f); } std::string code = cg.Finish(); std::string fmt = "metal"; std::string source = ""; if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { source = code; code = (*f)(code).operator std::string(); fmt = "metallib"; } return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); } TVM_REGISTER_GLOBAL("codegen.build_metal") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = BuildMetal(args[0]); }); } // namespace codegen } // namespace tvm