/*!
 *  Copyright (c) 2017 by Contributors
 * \file metal_common.h
 * \brief Metal common header
 */
#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_
#define TVM_RUNTIME_METAL_METAL_COMMON_H_

#import <Metal/MTLBuffer.h>
#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLCommandBuffer.h>
#import <Metal/MTLBlitCommandEncoder.h>
#import <Metal/MTLDevice.h>
#import <Metal/MTLLibrary.h>

#include <tvm/runtime/config.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
// <memory> is needed for std::shared_ptr used by MetalWorkspace::Global();
// include it directly instead of relying on a transitive include.
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "../workspace_pool.h"

namespace tvm {
namespace runtime {
namespace metal {
/*!
 * \brief Process global Metal workspace.
 */
class MetalWorkspace final : public DeviceAPI {
 public:
  // The Metal devices known to this workspace, indexed by device_id.
  std::vector<id<MTLDevice> > devices;
  // The command queues, indexed by device_id (see GetCommandQueue).
  std::vector<id<MTLCommandQueue> > queues;
  // Warp size constant; presumably one entry per device — TODO confirm.
  std::vector<int> warp_size;
  // Whether the workspace has been initialized.
  bool initialized_{false};
  // The mutex guarding initialization.
  std::mutex mutex;
  // Destructor
  ~MetalWorkspace();
  /*!
   * \brief Get the command queue for the given context.
   * \param ctx The context; must be a Metal context with a device_id in
   *        [0, queues.size()), otherwise a CHECK fails.
   * \return The command queue at queues[ctx.device_id].
   */
  id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kMetal);
    CHECK(ctx.device_id >= 0 &&
          static_cast<size_t>(ctx.device_id) < queues.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return queues[ctx.device_id];
  }
  /*!
   * \brief Get the device for the given context.
   * \param ctx The context; must be a Metal context with a device_id in
   *        [0, devices.size()), otherwise a CHECK fails.
   * \return The device at devices[ctx.device_id].
   */
  id<MTLDevice> GetDevice(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kMetal);
    CHECK(ctx.device_id >= 0 &&
          static_cast<size_t>(ctx.device_id) < devices.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return devices[ctx.device_id];
  }
  /*!
   * \brief Initialize the workspace.
   *
   * Declared void: it reports nothing to the caller. NOTE(review): the
   * initialized_ flag and mutex members presumably make repeated calls
   * safe and cheap — confirm in the implementation file.
   */
  void Init();
  // override device API
  void SetDevice(TVMContext ctx) final;
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx,
                       size_t size,
                       size_t alignment) final;
  void FreeDataSpace(TVMContext ctx, void* ptr) final;
  void CopyDataFromTo(const void* from,
                      size_t from_size,
                      void* to,
                      size_t to_size,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      TVMStreamHandle stream) final;
  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
  void* AllocWorkspace(TVMContext ctx, size_t size) final;
  void FreeWorkspace(TVMContext ctx, void* data) final;
  /*!
   * \brief Get the process global workspace.
   * \return A reference to the shared global MetalWorkspace instance.
   */
  static const std::shared_ptr<MetalWorkspace>& Global();
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
 public:
  /*! \brief The current context; defaults to Metal device 0. */
  TVMContext context;
  /*!
   * \brief The shared buffers used for copy.
   * NOTE(review): presumably one buffer per device_id — confirm in
   * GetTempBuffer's implementation.
   */
  std::vector<id<MTLBuffer> > temp_buffer_;
  /*! \brief Workspace pool backed by the global MetalWorkspace. */
  WorkspacePool pool;
  // constructor: bind the pool to the global workspace and default the
  // current context to Metal device 0.
  MetalThreadEntry()
      : pool(static_cast<DLDeviceType>(kMetal), MetalWorkspace::Global()) {
    context.device_id = 0;
    context.device_type = static_cast<DLDeviceType>(kMetal);
  }
  ~MetalThreadEntry();
  /*!
   * \brief Get a temp buffer of at least `size` bytes for the given context.
   * \param ctx The context the buffer belongs to.
   * \param size The minimum required size in bytes.
   * \return The temp buffer.
   */
  id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);
  /*!
   * \brief Get the calling thread's MetalThreadEntry.
   * \return Pointer to the thread-local entry.
   */
  static MetalThreadEntry* ThreadLocal();
};
}  // namespace metal
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_METAL_METAL_COMMON_H_