/*!
 *  Copyright (c) 2017 by Contributors
 * \file rocm_device_api.cc
 * \brief GPU specific API
 */
#include <tvm/runtime/device_api.h>

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <hsa/hsa.h>
#include "rocm_common.h"

namespace tvm {
namespace runtime {

class ROCMDeviceAPI final : public DeviceAPI {
 public:
  void SetDevice(TVMContext ctx) final {
    ROCM_CALL(hipSetDevice(ctx.device_id));
  }
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
    int value = 0;
    switch (kind) {
      case kExist: {
        if (hsa_init() == HSA_STATUS_SUCCESS) {
          int dev;
          ROCM_CALL(hipGetDeviceCount(&dev));
          value = dev > ctx.device_id ? 1 : 0;
          hsa_shut_down();
        } else {
          value = 0;
        }
        break;
      }
      case kMaxThreadsPerBlock: {
        value = 1024;
        break;
      }
      case kWarpSize: {
        value = 64;
        break;
      }
45
      case kMaxSharedMemoryPerBlock: return;
46
      case kComputeVersion: {
47 48 49 50
        hipDeviceProp_t prop;
        ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
        *rv = prop.gcnArch;
        return;
51
      }
52
      case kDeviceName: return;
53 54
      case kMaxClockRate: return;
      case kMultiProcessorCount: return;
55
      case kMaxThreadDimensions: return;
56 57 58
    }
    *rv = value;
  }
59 60 61 62
  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       TVMType type_hint) final {
63 64 65 66
    ROCM_CALL(hipSetDevice(ctx.device_id));
    CHECK_EQ(256 % alignment, 0U)
        << "ROCM space is aligned at 256 bytes";
    void *ret;
67
    ROCM_CALL(hipMalloc(&ret, nbytes));
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
    return ret;
  }

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    ROCM_CALL(hipSetDevice(ctx.device_id));
    ROCM_CALL(hipFree(ptr));
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
83
                      TVMType type_hint,
84 85 86 87
                      TVMStreamHandle stream) final {
    hipStream_t hip_stream = static_cast<hipStream_t>(stream);
    from = static_cast<const char*>(from) + from_offset;
    to = static_cast<char*>(to) + to_offset;
88
    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
89 90 91 92 93 94 95 96
      ROCM_CALL(hipSetDevice(ctx_from.device_id));
      if (ctx_from.device_id == ctx_to.device_id) {
        GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
      } else {
        hipMemcpyPeerAsync(to, ctx_to.device_id,
                            from, ctx_from.device_id,
                            size, hip_stream);
      }
97
    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
98 99
      ROCM_CALL(hipSetDevice(ctx_from.device_id));
      GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
100
    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
      ROCM_CALL(hipSetDevice(ctx_to.device_id));
      GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
    } else {
      LOG(FATAL) << "expect copy from/to GPU or between GPU";
    }
  }

  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
    ROCM_CALL(hipSetDevice(ctx.device_id));
    ROCM_CALL(hipStreamSynchronize(static_cast<hipStream_t>(stream)));
  }

  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
    ROCMThreadEntry::ThreadLocal()
        ->stream = static_cast<hipStream_t>(stream);
  }

118
  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
    return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
  }

  static const std::shared_ptr<ROCMDeviceAPI>& Global() {
    static std::shared_ptr<ROCMDeviceAPI> inst =
        std::make_shared<ROCMDeviceAPI>();
    return inst;
  }

 private:
  static void GPUCopy(const void* from,
                      void* to,
                      size_t size,
                      hipMemcpyKind kind,
                      hipStream_t stream) {
    if (stream != 0) {
      ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
    } else {
      ROCM_CALL(hipMemcpy(to, from, size, kind));
    }
  }
};

/*! \brief Thread-local store that holds the ROCMThreadEntry objects. */
using ROCMThreadStore = dmlc::ThreadLocalStore<ROCMThreadEntry>;

/*!
 * \brief Construct the thread entry, binding its workspace pool to the
 *        kDLROCM device type and the shared ROCMDeviceAPI instance.
 */
ROCMThreadEntry::ROCMThreadEntry()
    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
}

/*! \brief Fetch the ROCMThreadEntry held by ROCMThreadStore. */
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
  ROCMThreadEntry* const entry = ROCMThreadStore::Get();
  return entry;
}

// Registers "device_api.rocm": the body hands back the shared ROCMDeviceAPI
// instance as an opaque void* handle.
TVM_REGISTER_GLOBAL("device_api.rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    // Go through the DeviceAPI base pointer before erasing the type.
    DeviceAPI* const api = ROCMDeviceAPI::Global().get();
    *rv = static_cast<void*>(api);
  });

}  // namespace runtime
}  // namespace tvm