/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file runtime/vm/pooled_allocator.h
 * \brief A pooling device-memory allocator for the TVM virtual machine runtime.
 */
#ifndef TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_
#define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_

#include <tvm/runtime/device_api.h>
#include <atomic>
#include <mutex>
#include <unordered_map>
#include <vector>

#include "memory_manager.h"

namespace tvm {
namespace runtime {
namespace vm {

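/*!
 * \brief A Buffer allocator that rounds each request up to a whole number of
 *  pages and recycles freed buffers instead of returning them to the device.
 *
 *  Freed buffers are kept in per-size free lists and handed back to later
 *  requests of the same rounded size. Device memory is only released back to
 *  the device when the allocator is destroyed.
 *
 *  A minimal usage sketch (hypothetical values, assuming a CPU context):
 *
 * \code
 *   TVMContext ctx{kDLCPU, 0};
 *   PooledAllocator allocator(ctx);
 *   Buffer buf = allocator.Alloc(1024, 64, {kDLFloat, 32, 1});
 *   allocator.Free(buf);  // returns buf to the pool, not to the device
 * \endcode
 */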
class PooledAllocator final : public Allocator {
 public:
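  /*! \brief The default page size, in bytes, to which requests are rounded up. */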
  static constexpr size_t kDefaultPageSize = 4096;

  explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
      : Allocator(), page_size_(page_size), used_memory_(0), ctx_(ctx) {}

  ~PooledAllocator() { ReleaseAll(); }

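  /*!
   * \brief Allocate a buffer, reusing a pooled buffer of the same rounded
   *  size when one is available.
   * \param nbytes The number of bytes requested.
   * \param alignment The alignment of the allocation.
   * \param type_hint The type hint for the device allocation.
   * \return A buffer whose size is nbytes rounded up to the page size.
   * \note alignment and type_hint only take effect when fresh device memory
   *  is allocated; a recycled buffer keeps its original properties.
   */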
  Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override {
    std::lock_guard<std::mutex> lock(mu_);
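    // Round the request up to a whole number of pages so that equal-sized
    // requests land in the same free list.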
    size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_;
    auto it = memory_pool_.find(size);
    if (it != memory_pool_.end() && !it->second.empty()) {
      auto& pool = it->second;
      auto ret = pool.back();
      pool.pop_back();
      return ret;
    }
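    // No pooled buffer of this size is available; allocate fresh device memory.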
    Buffer buf;
    buf.ctx = ctx_;
    buf.size = size;
    buf.data = DeviceAPI::Get(ctx_)->AllocDataSpace(ctx_, size, alignment, type_hint);
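    // Relaxed ordering suffices: the counter is only read for reporting.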
    used_memory_.fetch_add(size, std::memory_order_relaxed);
    DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B";
    return buf;
  }

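  /*!
   * \brief Return a buffer to the size-keyed pool for later reuse.
   *
   *  The underlying device memory is not released here: it stays allocated,
   *  and counted by UsedMemory(), until the allocator is destroyed.
   */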
  void Free(const Buffer& buffer) override {
    std::lock_guard<std::mutex> lock(mu_);
    // operator[] default-constructs an empty free list on first use.
    memory_pool_[buffer.size].push_back(buffer);
    DLOG(INFO) << "reclaim buffer " << buffer.size;
  }

  size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); }

 private:
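  /*!
   * \brief Free every pooled buffer back to the device and reset the usage
   *  counter. Called from the destructor.
   */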
  void ReleaseAll() {
    std::lock_guard<std::mutex> lock(mu_);
    for (auto const& it : memory_pool_) {
      auto const& pool = it.second;
      for (auto const& buf : pool) {
        DeviceAPI::Get(buf.ctx)->FreeDataSpace(buf.ctx, buf.data);
      }
    }
    memory_pool_.clear();
    used_memory_ = 0;
    DLOG(INFO) << "release all buffers";
  }

 private:
  /*! \brief The granularity to which allocation requests are rounded up. */
  size_t page_size_;
  /*! \brief Total bytes obtained from the device, whether in use or pooled. */
  std::atomic<size_t> used_memory_;
  /*! \brief Free buffers keyed by their page-rounded size. */
  std::unordered_map<size_t, std::vector<Buffer> > memory_pool_;
  /*! \brief Protects memory_pool_. */
  std::mutex mu_;
  /*! \brief The device context on which buffers are allocated. */
  TVMContext ctx_;
};

}  // namespace vm
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_