/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file metal_common.h
 * \brief Metal common header
 */
#ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_
#define TVM_RUNTIME_METAL_METAL_COMMON_H_

#import <Metal/MTLBuffer.h>
#import <Metal/MTLCommandQueue.h>
#import <Metal/MTLCommandBuffer.h>
#import <Metal/MTLBlitCommandEncoder.h>
#import <Metal/MTLDevice.h>
#import <Metal/MTLLibrary.h>

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#include <mutex>
#include <string>
#include <vector>
#include <memory>
#include "../workspace_pool.h"

namespace tvm {
namespace runtime {
namespace metal {
/*!
 * \brief Process global Metal workspace.
 *
 * Implements the DeviceAPI interface for Metal and owns the lists of
 * devices and command queues that are looked up by device_id.
 */
class MetalWorkspace final : public DeviceAPI {
 public:
  /*! \brief The Metal devices; indexed by device_id in GetDevice. */
  std::vector<id<MTLDevice> > devices;
  /*! \brief The command queues; indexed by device_id in GetCommandQueue. */
  std::vector<id<MTLCommandQueue> > queues;
  /*!
   * \brief Warp size constants.
   * NOTE(review): presumably one entry per device -- confirm against Init().
   */
  std::vector<int> warp_size;
  /*! \brief Whether the workspace is initialized. */
  bool initialized_{false};
  /*! \brief The mutex for initialization. */
  std::mutex mutex;

  /*! \brief Destructor. */
  ~MetalWorkspace();

  /*!
   * \brief Get the command queue for the given context.
   *
   * Checks that the context is a Metal context and that its device_id
   * is a valid index into queues; a failed CHECK reports the offending id.
   *
   * \param ctx The context to look up.
   * \return The command queue stored at index ctx.device_id.
   */
  id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kDLMetal);
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return queues[ctx.device_id];
  }

  /*!
   * \brief Get the device for the given context.
   *
   * Checks that the context is a Metal context and that its device_id
   * is a valid index into devices; a failed CHECK reports the offending id.
   *
   * \param ctx The context to look up.
   * \return The device stored at index ctx.device_id.
   */
  id<MTLDevice> GetDevice(TVMContext ctx) {
    CHECK_EQ(ctx.device_type, kDLMetal);
    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
        << "Invalid Metal device_id=" << ctx.device_id;
    return devices[ctx.device_id];
  }

  /*!
   * \brief Initialize the workspace.
   *
   * This function returns nothing.
   * NOTE(review): presumably guarded by mutex / initialized_ so repeated
   * calls are harmless -- confirm in the implementation file.
   */
  void Init();

  // override device API
  void SetDevice(TVMContext ctx) final;
  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
                       DLDataType type_hint) final;
  void FreeDataSpace(TVMContext ctx, void* ptr) final;
  void CopyDataFromTo(const void* from,
                      size_t from_size,
                      void* to,
                      size_t to_size,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
                      DLDataType type_hint,
                      TVMStreamHandle stream) final;
  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
  void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
  void FreeWorkspace(TVMContext ctx, void* data) final;

  /*!
   * \brief Get the global workspace.
   * \return Reference to the shared pointer holding the global MetalWorkspace.
   */
  static const std::shared_ptr<MetalWorkspace>& Global();
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
 public:
  /*! \brief The current context */
  TVMContext context;
  /*!
   * \brief The shared buffer used for copy.
   * NOTE(review): presumably one slot per device_id -- confirm in GetTempBuffer.
   */
  std::vector<id<MTLBuffer> > temp_buffer_;
  /*! \brief workspace pool */
  WorkspacePool pool;

  /*!
   * \brief Constructor.
   *
   * Builds the workspace pool on top of the global MetalWorkspace and
   * sets the current context to Metal device 0.
   */
  MetalThreadEntry()
      : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
    context.device_id = 0;
    context.device_type = static_cast<DLDeviceType>(kDLMetal);
  }

  /*! \brief Destructor. */
  ~MetalThreadEntry();

  /*!
   * \brief Get temp buffer with at least size under ctx.
   * \param ctx The context the buffer is requested for.
   * \param size The minimum size requested for the buffer.
   * \return The temporary buffer.
   */
  id<MTLBuffer> GetTempBuffer(TVMContext ctx, size_t size);

  /*!
   * \brief Get the thread local entry.
   * \return Pointer to the MetalThreadEntry.
   */
  static MetalThreadEntry* ThreadLocal();
};
}  // namespace metal
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_METAL_METAL_COMMON_H_