/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/runtime/vm/memory_manager.cc
 * \brief Allocate and manage memory for the runtime.
 */
#include <utility>
#include <memory>
#include "memory_manager.h"
#include "naive_allocator.h"
#include "pooled_allocator.h"

namespace tvm {
namespace runtime {
namespace vm {

34 35
static void BufferDeleter(Object* obj) {
  auto* ptr = static_cast<NDArray::Container*>(obj);
36 37 38 39 40 41 42 43
  CHECK(ptr->manager_ctx != nullptr);
  Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
  MemoryManager::Global()->GetAllocator(buffer->ctx)->
      Free(*(buffer));
  delete buffer;
  delete ptr;
}

44 45
void StorageObj::Deleter(Object* obj) {
  auto* ptr = static_cast<NDArray::Container*>(obj);
46 47 48 49 50 51 52 53 54 55 56 57 58 59
  // When invoking AllocNDArray we don't own the underlying allocation
  // and should not delete the buffer, but instead let it be reclaimed
  // by the storage object's destructor.
  //
  // We did bump the reference count by 1 to keep alive the StorageObj
  // allocation in case this NDArray is the sole owner.
  //
  // We decrement the object allowing for the buffer to release our
  // reference count from allocation.
  StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx);
  storage->DecRef();
  delete ptr;
}

60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
inline void VerifyDataType(DLDataType dtype) {
  CHECK_GE(dtype.lanes, 1);
  if (dtype.code == kDLFloat) {
    CHECK_EQ(dtype.bits % 8, 0);
  } else {
    // allow uint1 as a special flag for bool.
    if (dtype.bits == 1 && dtype.code == kDLUInt) return;
    CHECK_EQ(dtype.bits % 8, 0);
  }
  CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}

/*!
 * \brief Alignment, in bytes, to request when allocating storage for \p arr.
 *
 * Uses the size of one full element (whole bytes per lane times the lane
 * count), but never less than kAllocAlignment. Sub-byte types such as the
 * 1-bit bool flag give an element size of 0 and so fall back to
 * kAllocAlignment.
 *
 * \param arr The tensor whose dtype determines the alignment.
 * \return The alignment in bytes.
 */
inline size_t GetDataAlignment(const DLTensor& arr) {
  const size_t element_bytes = (arr.dtype.bits / 8) * arr.dtype.lanes;
  return element_bytes < kAllocAlignment ? kAllocAlignment : element_bytes;
}

78 79 80 81
NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype) {
  // TODO(@jroesch): generalize later to non-overlapping allocations.
  CHECK_EQ(offset, 0u);
  VerifyDataType(dtype);
82 83

  // crtical zone: allocate header, cannot throw
84
  NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx);
85 86

  container->SetDeleter(StorageObj::Deleter);
87 88 89 90
  size_t needed_size = GetDataSize(container->dl_tensor);
  this->IncRef();
  container->manager_ctx = reinterpret_cast<void*>(this);
  container->dl_tensor.data = this->buffer.data;
91 92 93 94 95 96 97 98
  NDArray ret(GetObjectPtr<Object>(container));

  // RAII in effect, now run the check.
  // TODO(@jroesch): generalize later to non-overlapping allocations.
  CHECK(needed_size == this->buffer.size)
    << "size mistmatch required " << needed_size << " found " << this->buffer.size;

  return ret;
99 100
}

101 102 103 104 105 106 107 108
MemoryManager* MemoryManager::Global() {
  static MemoryManager memory_manager;
  return &memory_manager;
}

/*!
 * \brief Get the allocator for a device context, creating it on first use.
 *
 * Lookup and insertion happen under mu_. A context seen for the first time
 * gets a NaiveAllocator. The MemoryManager keeps ownership via the stored
 * std::unique_ptr; the returned pointer is non-owning.
 *
 * \param ctx The device context.
 * \return Non-owning pointer to the allocator registered for \p ctx.
 */
Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
  std::lock_guard<std::mutex> lock(mu_);
  auto it = allocators_.find(ctx);
  if (it == allocators_.end()) {
    DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "("
               << ctx.device_id << ")";
    std::unique_ptr<Allocator> alloc(new NaiveAllocator(ctx));
    // Reuse the iterator returned by emplace rather than looking ctx up again.
    it = allocators_.emplace(ctx, std::move(alloc)).first;
  }
  return it->second.get();
}

/*!
 * \brief Allocate a fresh NDArray backed by a buffer from Alloc().
 *
 * The buffer is requested with the alignment given by GetDataAlignment. The
 * resulting Buffer record is stored in the container's manager_ctx, and
 * BufferDeleter frees it when the NDArray's last reference goes away.
 *
 * The header and the Buffer record are held by std::unique_ptr until
 * ownership is handed off. If Alloc() throws, they are therefore released
 * instead of leaked.
 *
 * \param shape The shape of the array.
 * \param dtype The element type; validated with VerifyDataType.
 * \param ctx The device context the array lives on.
 * \return The newly allocated NDArray.
 */
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
  VerifyDataType(dtype);
  // `shape` is not used again, so move it into the header instead of copying.
  std::unique_ptr<NDArray::Container> container(
      new NDArray::Container(nullptr, std::move(shape), dtype, ctx));
  container->SetDeleter(BufferDeleter);
  size_t size = GetDataSize(container->dl_tensor);
  size_t alignment = GetDataAlignment(container->dl_tensor);
  std::unique_ptr<Buffer> buffer(new Buffer);
  *buffer = this->Alloc(size, alignment, dtype);

  // Ownership hand-off:
  //  - the Buffer record goes to the container (freed by BufferDeleter);
  //  - the container goes to the returned NDArray.
  container->dl_tensor.data = buffer->data;
  container->manager_ctx = static_cast<void*>(buffer.release());
  return NDArray(GetObjectPtr<Object>(container.release()));
}

}  // namespace vm
}  // namespace runtime
}  // namespace tvm