/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_common.h
 * \brief Metal common header
 */
#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_
#define TVM_RUNTIME_METAL_METAL_COMMON_H_

#import <Metal/MTLBuffer.h>
#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLCommandBuffer.h>
#import <Metal/MTLBlitCommandEncoder.h>
#import <Metal/MTLDevice.h>
#import <Metal/MTLLibrary.h>

#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "../workspace_pool.h"

namespace tvm {
namespace runtime {
namespace metal {
/*!
 * \brief Process global Metal workspace.
 */
class MetalWorkspace final : public DeviceAPI {
 public:
  /*! \brief The Metal devices; looked up by ctx.device_id in GetDevice. */
  std::vector<id<MTLDevice> > devices;
  /*! \brief The command queues; looked up by ctx.device_id in GetCommandQueue. */
  std::vector<id<MTLCommandQueue> > queues;
  /*!
   * \brief Warp size constants.
   * \note NOTE(review): presumably one entry per device - confirm against Init().
   */
  std::vector<int> warp_size;
  /*! \brief Whether the workspace is initialized. */
  bool initialized_{false};
  /*! \brief The mutex for initialization. */
  std::mutex mutex;

  /*! \brief Destructor (defined out of line). */
  ~MetalWorkspace();

  /*!
   * \brief Get the command queue for a given context.
   *
   * CHECK-fails when ctx.device_type is not kMetal or when ctx.device_id
   * is negative or not smaller than queues.size().
   *
   * \param ctx The context whose device_id selects the queue.
   * \return The entry of queues at index ctx.device_id.
   */
  id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kMetal);
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return queues[ctx.device_id];
  }

  /*!
   * \brief Get the device for a given context.
   *
   * CHECK-fails when ctx.device_type is not kMetal or when ctx.device_id
   * is negative or not smaller than devices.size().
   *
   * \param ctx The context whose device_id selects the device.
   * \return The entry of devices at index ctx.device_id.
   */
  id<MTLDevice> GetDevice(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kMetal);
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return devices[ctx.device_id];
  }

  /*!
   * \brief Initialize the workspace (defined out of line).
   *
   * The function is declared void, so it reports nothing to the caller.
   * \note NOTE(review): presumably guarded by `mutex` / `initialized_` so
   *       that repeated calls are harmless - confirm in the implementation.
   */
  void Init();

  // DeviceAPI overrides; all of them are implemented out of line.
  void SetDevice(TVMContext ctx) final;
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
  void FreeDataSpace(TVMContext ctx, void* ptr) final;
  void CopyDataFromTo(const void* from,
                      size_t from_size,
                      void* to,
                      size_t to_size,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      TVMStreamHandle stream) final;
  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
  void* AllocWorkspace(TVMContext ctx, size_t size) final;
  void FreeWorkspace(TVMContext ctx, void* data) final;

  /*!
   * \brief Get the global workspace.
   * \return Reference to the shared pointer that holds the global instance.
   */
  static const std::shared_ptr<MetalWorkspace>& Global();
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
 public:
  /*! \brief The current context */
  TVMContext context;
  /*! \brief The shared buffer used for copy. */
  std::vector<id<MTLBuffer> > temp_buffer_;
  /*! \brief workspace pool */
  WorkspacePool pool;

  /*!
   * \brief Constructor.
   *
   * Builds the workspace pool for device type kMetal on top of the global
   * MetalWorkspace and sets the current context to Metal device 0.
   */
  MetalThreadEntry()
      : pool(static_cast<DLDeviceType>(kMetal), MetalWorkspace::Global()) {
    context.device_id = 0;
    context.device_type = static_cast<DLDeviceType>(kMetal);
  }

  /*! \brief Destructor (defined out of line). */
  ~MetalThreadEntry();

  /*!
   * \brief Get temp buffer with at least size under ctx.
   * \param ctx The context the buffer is requested for.
   * \param size The minimum size of the buffer.
   * \return The temp buffer.
   */
  id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);

  /*!
   * \brief Get the thread local entry.
   * \return Pointer to a MetalThreadEntry.
   * \note NOTE(review): presumably one instance per calling thread, going by
   *       the name - confirm in the implementation.
   */
  static MetalThreadEntry* ThreadLocal();
};
}  // namespace metal
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_METAL_METAL_COMMON_H_