/*! * Copyright (c) 2017 by Contributors * \file Use external mps utils function */ #ifndef TVM_CONTRIB_MPS_MPS_UTILS_H_ #define TVM_CONTRIB_MPS_MPS_UTILS_H_ #import <MetalPerformanceShaders/MetalPerformanceShaders.h> #include <dmlc/logging.h> #include <dmlc/thread_local.h> #include <tvm/runtime/device_api.h> #include <tvm/runtime/registry.h> #include <tvm/runtime/util.h> #include <vector> #include "../../runtime/metal/metal_common.h" namespace tvm { namespace contrib { /*! breif Convert DLTensor type to MPS type */ struct MPSType { static MPSDataType DLTypeToMPSType(const DLDataType &dtype); }; // struct MPSType struct MetalThreadEntry { MetalThreadEntry(); ~MetalThreadEntry(); MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc); MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb, MPSImageDescriptor *desc); runtime::metal::MetalWorkspace *metal_api{nullptr}; static MetalThreadEntry *ThreadLocal(); std::vector<MPSImage *> img_table; }; // MetalThreadEntry } // namespace contrib } // namespace tvm #endif // TVM_CONTRIB_MPS_MPS_UTILS_H_