/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda_device_api.cc
 * \brief GPU specific API
 */
#include <tvm/runtime/device_api.h>

#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstring>
#include <sstream>
#include "cuda_common.h"

namespace tvm {
namespace runtime {

class CUDADeviceAPI final : public DeviceAPI {
 public:
  void SetDevice(TVMContext ctx) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
  }

  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
    int value = 0;
    switch (kind) {
      case kExist:
        value = (
            cudaDeviceGetAttribute(
                &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
            == cudaSuccess);
        break;
      case kMaxThreadsPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
        break;
      }
      case kWarpSize: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrWarpSize, ctx.device_id));
        break;
      }
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
        break;
      }
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
        os << value << ".";
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
        os << value;
        *rv = os.str();
        return;
      }
      case kDeviceName: {
        std::string name(256, 0);
        CUDA_DRIVER_CALL(cuDeviceGetName(&name[0], name.size(), ctx.device_id));
        name.resize(strlen(name.c_str()));
        *rv = std::move(name);
        return;
      }
      case kMaxClockRate: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrClockRate, ctx.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
        break;
      }
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));

        // use a json string to return multiple int values
        std::stringstream ss;
        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
      case kGcnArch: return;
    }
    *rv = value;
  }

  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       DLDataType type_hint) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CHECK_EQ(256 % alignment, 0U)
        << "CUDA space is aligned at 256 bytes";
    void* ret;
    CUDA_CALL(cudaMalloc(&ret, nbytes));
    return ret;
  }

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaFree(ptr));
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      DLDataType type_hint,
                      TVMStreamHandle stream) final {
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
      } else {
        CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
                                      from, ctx_from.device_id,
                                      size, cu_stream));
      }
    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPU";
    }
  }

  TVMStreamHandle CreateStream(TVMContext ctx) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t retval;
    CUDA_CALL(cudaStreamCreate(&retval));
    return static_cast<TVMStreamHandle>(retval);
  }

  void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    CUDA_CALL(cudaStreamDestroy(cu_stream));
  }

  void SyncStreamFromTo(TVMContext ctx,
                        TVMStreamHandle event_src,
                        TVMStreamHandle event_dst) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
    cudaEvent_t evt;
    CUDA_CALL(cudaEventCreate(&evt));
    CUDA_CALL(cudaEventRecord(evt, src_stream));
    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(cudaEventDestroy(evt));
  }

  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  }

  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
    CUDAThreadEntry::ThreadLocal()
        ->stream = static_cast<cudaStream_t>(stream);
  }

  void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<CUDADeviceAPI>& Global() {
    static std::shared_ptr<CUDADeviceAPI> inst =
        std::make_shared<CUDADeviceAPI>();
    return inst;
  }

 private:
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      cudaMemcpyKind kind,
                      cudaStream_t stream) {
    if (stream != 0) {
      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
    } else {
      CUDA_CALL(cudaMemcpy(to, from, size, kind));
    }
  }
};

typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;

CUDAThreadEntry::CUDAThreadEntry()
    : pool(kDLGPU, CUDADeviceAPI::Global()) {
}

CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  return CUDAThreadStore::Get();
}

TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* ptr = CUDADeviceAPI::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace runtime
}  // namespace tvm
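// Illustrative usage sketch (not part of the upstream file): how client code
// could reach this device API through the global registry. The lookup name
// "device_api.gpu" matches the TVM_REGISTER_GLOBAL call above, and the packed
// function returns the CUDADeviceAPI singleton as a bare void* handle; the
// TVMContext value `ctx` below is assumed to be supplied by the caller.
//
//   const tvm::runtime::PackedFunc* f =
//       tvm::runtime::Registry::Get("device_api.gpu");
//   if (f != nullptr) {
//     void* handle = (*f)();
//     auto* api = static_cast<tvm::runtime::DeviceAPI*>(handle);
//     api->SetDevice(ctx);  // ctx.device_type == kDLGPU
//   }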