micro_device_api.cc 6.51 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file micro_device_api.cc
 */

#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>

#include <cstdint>
#include <memory>
#include <vector>

#include "../workspace_pool.h"
#include "micro_session.h"

namespace tvm {
namespace runtime {
/*!
 * \brief device API for uTVM micro devices
 */
class MicroDeviceAPI final : public DeviceAPI {
 public:
  /*! \brief constructor */
  MicroDeviceAPI() { }

  void SetDevice(TVMContext ctx) final {}

  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
    if (kind == kExist) {
      *rv = 1;
    }
  }

  void* AllocDataSpace(TVMContext ctx,
                       size_t nbytes,
                       size_t alignment,
51
                       DLDataType type_hint) final {
52
    ObjectPtr<MicroSession>& session = MicroSession::Current();
53 54 55 56 57 58 59 60 61 62 63
    void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>();
    CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap";
    MicroDevSpace* dev_space = new MicroDevSpace();
    dev_space->data = data;
    dev_space->session = session;
    return static_cast<void*>(dev_space);
  }

  void FreeDataSpace(TVMContext ctx, void* ptr) final {
    MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(ptr);
    dev_space->session->FreeInSection(
64
      SectionKind::kHeap, DevPtr(reinterpret_cast<std::uintptr_t>(dev_space->data)));
65 66 67 68 69 70 71 72 73 74
    delete dev_space;
  }

  void CopyDataFromTo(const void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
                      size_t size,
                      TVMContext ctx_from,
                      TVMContext ctx_to,
75
                      DLDataType type_hint,
76 77 78 79 80 81 82 83
                      TVMStreamHandle stream) final {
    std::tuple<int, int> type_from_to(ctx_from.device_type, ctx_to.device_type);
    if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) {
      // Copying from the device to the device.

      MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
      MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
      CHECK(from_space->session == to_space->session)
84 85 86
          << "attempt to copy data between different micro sessions ("
          << from_space->session.get()
          << " != " << to_space->session.get() << ")";
87 88
      CHECK(ctx_from.device_id == ctx_to.device_id)
        << "can only copy between the same micro device";
89
      ObjectPtr<MicroSession>& session = from_space->session;
90 91
      const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();

92 93
      DevPtr from_dev_addr = GetDevLoc(from_space, from_offset);
      DevPtr to_dev_addr = GetDevLoc(to_space, to_offset);
94 95

      std::vector<uint8_t> buffer(size);
96 97
      lld->Read(from_dev_addr, static_cast<void*>(buffer.data()), size);
      lld->Write(to_dev_addr, static_cast<void*>(buffer.data()), size);
98 99 100 101
    } else if (type_from_to == std::make_tuple(kDLMicroDev, kDLCPU)) {
      // Reading from the device.

      MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
102
      ObjectPtr<MicroSession>& session = from_space->session;
103 104
      const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();

105
      DevPtr from_dev_addr = GetDevLoc(from_space, from_offset);
106
      void* to_host_ptr = GetHostLoc(to, to_offset);
107
      lld->Read(from_dev_addr, to_host_ptr, size);
108 109 110 111
    } else if (type_from_to == std::make_tuple(kDLCPU, kDLMicroDev)) {
      // Writing to the device.

      MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
112
      ObjectPtr<MicroSession>& session = to_space->session;
113 114 115
      const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();

      void* from_host_ptr = GetHostLoc(from, from_offset);
116 117
      DevPtr to_dev_addr = GetDevLoc(to_space, to_offset);
      lld->Write(to_dev_addr, from_host_ptr, size);
118 119 120 121 122 123 124 125
    } else {
      LOG(FATAL) << "Expect copy from/to micro device or between micro device\n";
    }
  }

  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
  }

126
  void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
127
    ObjectPtr<MicroSession>& session = MicroSession::Current();
128 129 130 131 132 133 134 135 136 137 138

    void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>();
    CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace";
    MicroDevSpace* dev_space = new MicroDevSpace();
    dev_space->data = data;
    dev_space->session = session;
    return static_cast<void*>(dev_space);
  }

  void FreeWorkspace(TVMContext ctx, void* data) final {
    MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(data);
139
    ObjectPtr<MicroSession>& session = dev_space->session;
140
    session->FreeInSection(SectionKind::kWorkspace,
141
                           DevPtr(reinterpret_cast<std::uintptr_t>(dev_space->data)));
142 143 144 145 146 147 148 149 150 151 152 153 154
    delete dev_space;
  }

  /*!
   * \brief obtain a global singleton of MicroDeviceAPI
   * \return global shared pointer to MicroDeviceAPI
   */
  static const std::shared_ptr<MicroDeviceAPI>& Global() {
    static std::shared_ptr<MicroDeviceAPI> inst = std::make_shared<MicroDeviceAPI>();
    return inst;
  }

 private:
155 156
  DevPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) {
    return DevPtr(reinterpret_cast<std::uintptr_t>(dev_space->data) + offset);
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
  }

  void* GetHostLoc(const void* ptr, size_t offset) {
    return reinterpret_cast<void*>(reinterpret_cast<std::uintptr_t>(ptr) + offset);
  }
};

// Publish the MicroDeviceAPI singleton under "device_api.micro_dev" so it can be
// looked up by name (e.g. from the Python frontend); the result is the
// DeviceAPI base pointer passed back as an untyped handle.
TVM_REGISTER_GLOBAL("device_api.micro_dev")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = static_cast<void*>(static_cast<DeviceAPI*>(MicroDeviceAPI::Global().get()));
  });
}  // namespace runtime
}  // namespace tvm