/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file cudnn_utils.cc
 * \brief Use external cudnn utils function
 */
#include "cudnn_utils.h"
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>


namespace tvm {
namespace contrib {

// CuDNN Data Type
/*!
 * \brief Map a DLPack data type onto the matching cuDNN data type.
 * \param dtype DLPack type descriptor; its code, bits and lanes fields are inspected.
 * \return The cuDNN enumerator for the supported (code, bits, lanes) combinations:
 *         int8x1, int32x1, int8x4, float32x1, float64x1 and float16x1.
 *
 * Every other combination, including any type code not listed in the switch,
 * is reported through LOG(FATAL) << "Unsupported type".
 */
cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) {
  switch (dtype.code) {
    case kDLInt:
      if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8;
      else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32;
      else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4;
      else
        LOG(FATAL) << "Unsupported type";
      break;
    case kDLUInt:
      LOG(FATAL) << "Unsupported type";
      break;
    case kDLFloat:
      if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT;
      else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE;
      else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF;
      else
        LOG(FATAL) << "Unsupported type";
      break;
    default:
      // Type codes outside the three handled above used to fall through
      // silently to the float fallback below; report them like the other
      // unsupported cases instead.
      LOG(FATAL) << "Unsupported type";
      break;
  }
  // Reached only when one of the LOG(FATAL) branches above falls through to
  // its break; keeps every control path returning a value.
  return CUDNN_DATA_FLOAT;
}

/*!
 * \brief Specialization returning the address of a constant zero for a cuDNN type.
 * \param type cuDNN data type selecting which zero constant is returned.
 * \return Pointer to a static float for FLOAT and HALF, to a static double for
 *         DOUBLE, to a static int for INT8, INT32 and INT8x4, and nullptr for
 *         any other type.
 */
template<>
const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) {
  static const float zero_f32 = 0;
  static const double zero_f64 = 0;
  static const int zero_i32 = 0;
  switch (type) {
    case CUDNN_DATA_FLOAT:
    case CUDNN_DATA_HALF:
      // HALF shares the float constant.
      return static_cast<const void*>(&zero_f32);
    case CUDNN_DATA_DOUBLE:
      return static_cast<const void*>(&zero_f64);
    case CUDNN_DATA_INT8:
    case CUDNN_DATA_INT32:
    case CUDNN_DATA_INT8x4:
      // All integer variants share the int constant.
      return static_cast<const void*>(&zero_i32);
    default:
      return nullptr;
  }
}

/*!
 * \brief Specialization returning the address of a constant one for a cuDNN type.
 * \param type cuDNN data type selecting which one-valued constant is returned.
 * \return Pointer to a static float for FLOAT and HALF, to a static double for
 *         DOUBLE, to a static int for INT8, INT32 and INT8x4, and nullptr for
 *         any other type.
 */
template<>
const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
  static const float one_f32 = 1.f;
  static const double one_f64 = 1.f;
  static const int one_i32 = 1;
  switch (type) {
    case CUDNN_DATA_FLOAT:
    case CUDNN_DATA_HALF:
      // HALF shares the float constant.
      return static_cast<const void*>(&one_f32);
    case CUDNN_DATA_DOUBLE:
      return static_cast<const void*>(&one_f64);
    case CUDNN_DATA_INT8:
    case CUDNN_DATA_INT32:
    case CUDNN_DATA_INT8x4:
      // All integer variants share the int constant.
      return static_cast<const void*>(&one_i32);
    default:
      return nullptr;
  }
}

// CuDNNThreadEntry

/*!
 * \brief Set up the cuDNN state held by this entry.
 *
 * Reads the stream from runtime::CUDAThreadEntry::ThreadLocal(), obtains the
 * DeviceAPI pointer by calling the "device_api.gpu" global function, creates
 * the cuDNN handle, binds the handle to that stream, and passes the device
 * API on to conv_entry.
 */
CuDNNThreadEntry::CuDNNThreadEntry() {
  auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
  auto func = runtime::Registry::Get("device_api.gpu");
  // Registry::Get yields a null pointer for a name that is not registered;
  // report that explicitly rather than dereferencing it below.
  if (func == nullptr) {
    LOG(FATAL) << "Cannot find global function device_api.gpu in the registry";
  }
  void *ret = (*func)();
  cuda_api = static_cast<runtime::DeviceAPI*>(ret);
  CUDNN_CALL(cudnnCreate(&handle));
  CUDNN_CALL(cudnnSetStream(handle, stream));
  conv_entry.cuda_api = cuda_api;
}

/*!
 * \brief Destroy the cuDNN handle owned by this entry.
 */
CuDNNThreadEntry::~CuDNNThreadEntry() {
  CUDNN_CALL(cudnnDestroy(this->handle));
}

/*! \brief Store type that hands out CuDNNThreadEntry instances. */
using CuDNNThreadStore = dmlc::ThreadLocalStore<CuDNNThreadEntry>;

/*!
 * \brief Fetch the CuDNNThreadEntry provided by CuDNNThreadStore.
 * \return Pointer obtained from CuDNNThreadStore::Get().
 */
CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() {
  CuDNNThreadEntry* entry = CuDNNThreadStore::Get();
  return entry;
}

// ConvEntry

/*!
 * \brief Create the cuDNN descriptors used by this entry: the convolution
 *        descriptor, the filter descriptor, and the input and output tensor
 *        descriptors.
 */
ConvEntry::ConvEntry() {
  // Convolution and filter descriptors.
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&this->conv_desc));
  CUDNN_CALL(cudnnCreateFilterDescriptor(&this->filter_desc));
  // Tensor descriptors for the two ends of the convolution.
  CUDNN_CALL(cudnnCreateTensorDescriptor(&this->input_desc));
  CUDNN_CALL(cudnnCreateTensorDescriptor(&this->output_desc));
}

/*!
 * \brief Destroy the descriptors created by the constructor, then release
 *        the workspace through CleanWorkspace().
 */
ConvEntry::~ConvEntry() {
  // Filter and convolution descriptors.
  CUDNN_CALL(cudnnDestroyFilterDescriptor(this->filter_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(this->conv_desc));
  // Input and output tensor descriptors.
  CUDNN_CALL(cudnnDestroyTensorDescriptor(this->input_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(this->output_desc));
  this->CleanWorkspace();
}

/*!
 * \brief Grow the workspace so that it holds at least wsize bytes.
 * \param wsize Requested workspace size.
 *
 * Nothing happens when the recorded workspace_size already covers wsize.
 * Otherwise any existing workspace is released via CleanWorkspace() and a new
 * one of the requested size is obtained from cuda_api->AllocWorkspace.
 */
void ConvEntry::UpdateWorkspace(const size_t wsize) {
  // Guard: the current workspace is already large enough.
  if (workspace_size >= wsize) {
    return;
  }
  if (workspace != nullptr) {
    CleanWorkspace();
  }
  workspace_size = wsize;
  workspace = cuda_api->AllocWorkspace(ctx, workspace_size);
}

/*!
 * \brief Release the current workspace, if any, and reset the recorded size.
 *
 * The workspace pointer is cleared after it is handed to
 * cuda_api->FreeWorkspace, so calling this function again does not pass the
 * same pointer to FreeWorkspace a second time.
 */
void ConvEntry::CleanWorkspace() {
  if (workspace != nullptr) {
    cuda_api->FreeWorkspace(ctx, workspace);
    // Drop the stale pointer so later null checks see that nothing is held.
    workspace = nullptr;
  }
  workspace_size = 0;
}

}  // namespace contrib
}  // namespace tvm