/*!
 *  Copyright (c) 2017 by Contributors
 * \file mps_utils.h
 * \brief Utility declarations for using external MPS functions
 */

#ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_CONTRIB_MPS_MPS_UTILS_H_

#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <vector>
#include "../../runtime/metal/metal_common.h"

namespace tvm {
namespace contrib {

/*! \brief Convert DLTensor type to MPS type */
struct MPSType {
  /*!
   * \brief Map a DLDataType to the corresponding MPSDataType.
   * \param dtype The DLDataType to convert.
   * \return The MPSDataType chosen for dtype.
   * \note Declaration only; how unsupported dtypes are handled is decided by
   *       the definition, which is not visible in this header.
   */
  static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
};  // struct MPSType

/*!
 * \brief State used by the MPS contrib functions: a Metal workspace handle
 *        and a table of MPSImage pointers.
 * \note Judging by the name and ThreadLocal(), one instance per thread is
 *       intended -- TODO confirm against the definition.
 */
struct MetalThreadEntry {
  /*! \brief Constructor (defined out of line). */
  MetalThreadEntry();
  /*! \brief Destructor (defined out of line). */
  ~MetalThreadEntry();
  /*!
   * \brief Obtain an MPSImage for the given device and descriptor.
   * \param dev The Metal device.
   * \param desc The image descriptor.
   * \return Pointer to the MPSImage.
   * \note Presumably the returned image is also recorded in img_table --
   *       verify against the definition.
   */
  MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);
  /*!
   * \brief Obtain an MPSTemporaryImage for the given command buffer and
   *        descriptor.
   * \param cb The Metal command buffer.
   * \param desc The image descriptor.
   * \return Pointer to the MPSTemporaryImage.
   */
  MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
                                    MPSImageDescriptor *desc);
  /*! \brief Metal workspace handle; nullptr until assigned. */
  runtime::metal::MetalWorkspace *metal_api{nullptr};
  /*!
   * \brief Accessor for a MetalThreadEntry instance.
   * \return Pointer to the entry.
   * \note Presumably returns the calling thread's own entry -- TODO confirm.
   */
  static MetalThreadEntry *ThreadLocal();
  /*! \brief Table of MPSImage pointers held by this entry. */
  std::vector<MPSImage *> img_table;
};  // struct MetalThreadEntry

}  // namespace contrib
}  // namespace tvm

#endif  // TVM_CONTRIB_MPS_MPS_UTILS_H_