cuda_device_api.cc 7.92 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24
/*!
 *  Copyright (c) 2017 by Contributors
 * \file cuda_device_api.cc
 * \brief GPU specific API
 */
25
#include <tvm/runtime/device_api.h>
26

27
#include <dmlc/thread_local.h>
28
#include <tvm/runtime/registry.h>
29
#include <cuda.h>
30
#include <cuda_runtime.h>
31
#include "cuda_common.h"
32 33 34 35

namespace tvm {
namespace runtime {

36
class CUDADeviceAPI final : public DeviceAPI {
37
 public:
38 39
  void SetDevice(TVMContext ctx) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
40
  }
41
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
42
    int value = 0;
43 44 45 46
    switch (kind) {
      case kExist:
        value = (
            cudaDeviceGetAttribute(
47
                &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
48 49
            == cudaSuccess);
        break;
50
      case kMaxThreadsPerBlock: {
51
        CUDA_CALL(cudaDeviceGetAttribute(
52
            &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
53 54 55 56
        break;
      }
      case kWarpSize: {
        CUDA_CALL(cudaDeviceGetAttribute(
57
            &value, cudaDevAttrWarpSize, ctx.device_id));
58 59
        break;
      }
60 61 62 63 64
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
        break;
      }
65 66 67 68 69 70 71 72 73 74 75
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
        os << value << ".";
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
        os << value;
        *rv = os.str();
        return;
      }
76
      case kDeviceName: {
77 78 79 80
        std::string name(256, 0);
        CUDA_DRIVER_CALL(cuDeviceGetName(&name[0], name.size(), ctx.device_id));
        name.resize(strlen(name.c_str()));
        *rv = std::move(name);
81 82
        return;
      }
83 84 85 86 87 88 89 90 91 92
      case kMaxClockRate: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrClockRate, ctx.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
        break;
      }
93 94 95 96 97 98 99 100 101 102 103 104 105 106
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));

        std::stringstream ss;  // use json string to return multiple int values;
        ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
107 108 109
    }
    *rv = value;
  }
110 111 112 113
  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       TVMType type_hint) final {
114 115 116 117
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CHECK_EQ(256 % alignment, 0U)
        << "CUDA space is aligned at 256 bytes";
    void *ret;
118
    CUDA_CALL(cudaMalloc(&ret, nbytes));
119 120 121 122 123 124 125 126 127
    return ret;
  }

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaFree(ptr));
  }

  void CopyDataFromTo(const void* from,
128
                      size_t from_offset,
129
                      void* to,
130
                      size_t to_offset,
131 132 133
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
134
                      TVMType type_hint,
135 136
                      TVMStreamHandle stream) final {
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
137 138
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
139
    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
140 141 142 143 144 145 146 147
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
      } else {
        cudaMemcpyPeerAsync(to, ctx_to.device_id,
                            from, ctx_from.device_id,
                            size, cu_stream);
      }
148
    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
149 150
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
151
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
152 153 154 155 156 157 158
      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPU";
    }
  }

159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
  TVMStreamHandle CreateStream(TVMContext ctx) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t retval;
    CUDA_CALL(cudaStreamCreate(&retval));
    return static_cast<TVMStreamHandle>(retval);
  }

  void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    CUDA_CALL(cudaStreamDestroy(cu_stream));
  }

  void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
    cudaEvent_t evt;
    CUDA_CALL(cudaEventCreate(&evt));
    CUDA_CALL(cudaEventRecord(evt, src_stream));
    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(cudaEventDestroy(evt));
  }

183 184 185 186 187
  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  }

188 189 190 191 192
  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
    CUDAThreadEntry::ThreadLocal()
        ->stream = static_cast<cudaStream_t>(stream);
  }

193
  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
194 195 196 197 198 199 200 201 202 203 204 205 206
    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<CUDADeviceAPI>& Global() {
    static std::shared_ptr<CUDADeviceAPI> inst =
        std::make_shared<CUDADeviceAPI>();
    return inst;
  }

207 208 209 210 211 212 213 214 215 216 217 218 219 220
 private:
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      cudaMemcpyKind kind,
                      cudaStream_t stream) {
    if (stream != 0) {
      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
    } else {
      CUDA_CALL(cudaMemcpy(to, from, size, kind));
    }
  }
};

221 222
// Storage behind CUDAThreadEntry::ThreadLocal(); presumably holds one
// CUDAThreadEntry per host thread (per dmlc::ThreadLocalStore) — confirm.
using CUDAThreadStore = dmlc::ThreadLocalStore<CUDAThreadEntry>;

223
/*!
 * \brief Construct the per-thread CUDA state.
 *
 * The workspace pool is built for kDLGPU and given the process-wide
 * CUDADeviceAPI instance (presumably the API it allocates through).
 */
CUDAThreadEntry::CUDAThreadEntry()
    : pool(kDLGPU, CUDADeviceAPI::Global()) {}

227 228 229 230
/*!
 * \brief Return the calling thread's CUDAThreadEntry, as provided by
 *        CUDAThreadStore (a dmlc::ThreadLocalStore).
 */
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  CUDAThreadEntry* entry = CUDAThreadStore::Get();
  return entry;
}

231
// Register the global "device_api.gpu". It ignores its arguments and returns
// a raw pointer to the process-wide CUDADeviceAPI instance. The instance is
// owned by CUDADeviceAPI::Global(), so the pointer stays valid.
TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* api = CUDADeviceAPI::Global().get();
    *rv = static_cast<void*>(api);
  });

}  // namespace runtime
}  // namespace tvm