/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file mps_utils.h
 * \brief Utilities for calling external MPS (Metal Performance Shaders)
 *        functions from the TVM runtime.
 */

#ifndef TVM_RUNTIME_CONTRIB_MPS_MPS_UTILS_H_
#define TVM_RUNTIME_CONTRIB_MPS_MPS_UTILS_H_

#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <vector>
#include "../../metal/metal_common.h"

namespace tvm {
namespace contrib {

/*! \brief Conversion from DLTensor (DLPack) data types to MPS data types. */
struct MPSType {
  /*!
   * \brief Map a DLPack data type to the corresponding MPSDataType.
   * \param dtype The DLPack data type to convert.
   * \return The MPSDataType matching \p dtype.
   * NOTE(review): handling of dtypes with no MPS equivalent is defined in the
   * implementation file — confirm before relying on it.
   */
  static MPSDataType DLTypeToMPSType(const DLDataType &dtype);
};  // struct MPSType

/*!
 * \brief Per-thread state for the MPS contrib: image allocation helpers,
 *        a handle to the Metal workspace, and a table of images.
 *
 * NOTE(review): ThreadLocal() is presumably backed by dmlc::ThreadLocalStore
 * (dmlc/thread_local.h is included by this header) — confirm in the
 * implementation file.
 */
struct MetalThreadEntry {
  MetalThreadEntry();
  ~MetalThreadEntry();

  /*!
   * \brief Allocate an MPSImage on a Metal device.
   * \param dev The Metal device to allocate the image on.
   * \param desc Descriptor of the image to create.
   * \return The allocated image.
   * NOTE(review): ownership/lifetime of the returned image (e.g. whether it
   * is recorded in img_table) is defined in the implementation — confirm.
   */
  MPSImage *AllocMPSImage(id<MTLDevice> dev, MPSImageDescriptor *desc);

  /*!
   * \brief Allocate an MPSTemporaryImage associated with a command buffer.
   * \param cb The command buffer the temporary image is created for.
   * \param desc Descriptor of the image to create.
   * \return The allocated temporary image.
   */
  MPSTemporaryImage *AllocTempImage(id<MTLCommandBuffer> cb,
                                    MPSImageDescriptor *desc);

  /*! \brief Metal workspace used by this entry; nullptr until assigned. */
  runtime::metal::MetalWorkspace *metal_api{nullptr};

  /*! \return The MetalThreadEntry for the calling thread (see note above). */
  static MetalThreadEntry *ThreadLocal();

  /*! \brief Images held by this entry. */
  std::vector<MPSImage *> img_table;
};  // MetalThreadEntry

}  // namespace contrib
}  // namespace tvm

#endif  // TVM_RUNTIME_CONTRIB_MPS_MPS_UTILS_H_