/*!
 *  Copyright (c) 2017 by Contributors
 * \file cuda_device_api.cc
 * \brief GPU specific API
 */
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <dmlc/thread_local.h>
#include <cuda_runtime.h>

#include <string>
#include <sstream>

#include "cuda_common.h"

namespace tvm {
namespace runtime {

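/*!
 * \brief CUDA implementation of the TVM DeviceAPI interface.
 */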
class CUDADeviceAPI final : public DeviceAPI {
 public:
  void SetDevice(TVMContext ctx) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
  }
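  // Query one device attribute and store it in *rv.  Most results are plain
  // integers; kComputeVersion, kDeviceName and kMaxThreadDimensions are
  // returned as strings instead.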
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
    int value = 0;
    switch (kind) {
      case kExist:
        // Probe the device with an attribute query; if the call succeeds,
        // the device exists.
        value = (
            cudaDeviceGetAttribute(
                &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)
            == cudaSuccess);
        break;
      case kMaxThreadsPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id));
        break;
      }
      case kWarpSize: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrWarpSize, ctx.device_id));
        break;
      }
      case kMaxSharedMemoryPerBlock: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id));
        break;
      }
      case kComputeVersion: {
        std::ostringstream os;
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id));
        os << value << ".";
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id));
        os << value;
        *rv = os.str();
        return;
      }
      case kDeviceName: {
        cudaDeviceProp props;
        CUDA_CALL(cudaGetDeviceProperties(&props, ctx.device_id));
        *rv = std::string(props.name);
        return;
      }
      case kMaxClockRate: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrClockRate, ctx.device_id));
        break;
      }
      case kMultiProcessorCount: {
        CUDA_CALL(cudaDeviceGetAttribute(
            &value, cudaDevAttrMultiProcessorCount, ctx.device_id));
        break;
      }
      case kMaxThreadDimensions: {
        int dims[3];
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id));
        CUDA_CALL(cudaDeviceGetAttribute(
            &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));

        std::stringstream ss;  // Encode the three values as a JSON array string.
        ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
        *rv = ss.str();
        return;
      }
    }
    *rv = value;
  }
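  // Allocate device memory via cudaMalloc, which returns pointers aligned to
  // at least 256 bytes; hence only alignments that divide 256 are accepted.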
  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       TVMType type_hint) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CHECK_EQ(256 % alignment, 0U)
        << "CUDA space is aligned at 256 bytes";
    void *ret;
    CUDA_CALL(cudaMalloc(&ret, nbytes));
    return ret;
  }

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaFree(ptr));
  }

  // Copy `size` bytes between host and/or device buffers, honoring the byte
  // offsets.  The copy direction is derived from the device types of the
  // two contexts.
  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      TVMType type_hint,
                      TVMStreamHandle stream) final {
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
      } else {
        // Peer-to-peer copy between two distinct devices.
        CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id,
                                      from, ctx_from.device_id,
                                      size, cu_stream));
      }
    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
    } else {
      LOG(FATAL) << "expected copy from/to GPU or between GPUs";
    }
  }

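  // Create a CUDA stream on the given device.  The handle is owned by the
  // caller and must be released with FreeStream.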
  TVMStreamHandle CreateStream(TVMContext ctx) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t retval;
    CUDA_CALL(cudaStreamCreate(&retval));
    return static_cast<TVMStreamHandle>(retval);
  }

  void FreeStream(TVMContext ctx, TVMStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
    CUDA_CALL(cudaStreamDestroy(cu_stream));
  }

  // Order two streams: record an event on the source stream and make the
  // destination stream wait on it, without blocking the host.
  void SyncStreamFromTo(TVMContext ctx,
                        TVMStreamHandle event_src,
                        TVMStreamHandle event_dst) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
    cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
    cudaEvent_t evt;
    CUDA_CALL(cudaEventCreate(&evt));
    CUDA_CALL(cudaEventRecord(evt, src_stream));
    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
    CUDA_CALL(cudaEventDestroy(evt));
  }

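  // Block the host until all work queued on the given stream has finished.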
  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
    CUDA_CALL(cudaSetDevice(ctx.device_id));
    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
  }

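  // Record the stream in thread-local storage so later kernel launches
  // issued from this thread run on it.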
  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
    CUDAThreadEntry::ThreadLocal()
        ->stream = static_cast<cudaStream_t>(stream);
  }

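  // Workspace allocations are served from a per-thread memory pool, which
  // caches buffers to avoid a cudaMalloc/cudaFree pair on every call.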
  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<CUDADeviceAPI>& Global() {
    static std::shared_ptr<CUDADeviceAPI> inst =
        std::make_shared<CUDADeviceAPI>();
    return inst;
  }

 private:
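  // Issue an asynchronous copy when an explicit stream is given; otherwise
  // perform a blocking cudaMemcpy on the default stream.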
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      cudaMemcpyKind kind,
                      cudaStream_t stream) {
    if (stream != 0) {
      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
    } else {
      CUDA_CALL(cudaMemcpy(to, from, size, kind));
    }
  }
};

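/*! \brief Thread-local store that owns one CUDAThreadEntry per thread. */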
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;

CUDAThreadEntry::CUDAThreadEntry()
    : pool(kDLGPU, CUDADeviceAPI::Global()) {
}

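// Return the calling thread's CUDAThreadEntry, constructing it on first use.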
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
  return CUDAThreadStore::Get();
}

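// Register the CUDA device API in TVM's global registry under the key
// "device_api.gpu" so the runtime can retrieve it by name.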
TVM_REGISTER_GLOBAL("device_api.gpu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    DeviceAPI* ptr = CUDADeviceAPI::Global().get();
    *rv = static_cast<void*>(ptr);
  });

}  // namespace runtime
}  // namespace tvm