/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #include <tvm/runtime/module.h> #include <tvm/runtime/packed_func.h> #include <tvm/runtime/registry.h> #include <vta/dpi/module.h> #include <vta/dpi/tsim.h> #if defined(_WIN32) #include <windows.h> #else #include <dlfcn.h> #endif #include <mutex> #include <queue> #include <thread> #include <condition_variable> namespace vta { namespace dpi { using namespace tvm::runtime; typedef void* DeviceHandle; struct HostRequest { uint8_t opcode; uint8_t addr; uint32_t value; }; struct HostResponse { uint32_t value; }; struct MemResponse { uint8_t valid; uint64_t value; }; template <typename T> class ThreadSafeQueue { public: void Push(const T item) { std::lock_guard<std::mutex> lock(mutex_); queue_.push(std::move(item)); cond_.notify_one(); } void WaitPop(T* item) { std::unique_lock<std::mutex> lock(mutex_); cond_.wait(lock, [this]{return !queue_.empty();}); *item = std::move(queue_.front()); queue_.pop(); } bool TryPop(T* item, bool pop) { std::lock_guard<std::mutex> lock(mutex_); if (queue_.empty()) return false; *item = std::move(queue_.front()); if (pop) queue_.pop(); return true; } private: mutable std::mutex mutex_; std::queue<T> queue_; std::condition_variable cond_; }; class HostDevice { public: void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value); bool TryPopRequest(HostRequest* r, bool pop); void PushResponse(uint32_t value); void WaitPopResponse(HostResponse* r); void Exit(); uint8_t GetExitStatus(); private: uint8_t exit_{0}; mutable std::mutex mutex_; ThreadSafeQueue<HostRequest> req_; ThreadSafeQueue<HostResponse> resp_; }; class MemDevice { public: void SetRequest(uint8_t opcode, uint64_t addr, uint32_t len); MemResponse ReadData(uint8_t ready); void WriteData(uint64_t value); private: uint64_t* raddr_{0}; uint64_t* waddr_{0}; uint32_t rlen_{0}; uint32_t wlen_{0}; std::mutex mutex_; }; void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) { HostRequest r; r.opcode = opcode; r.addr = addr; r.value = value; req_.Push(r); } bool HostDevice::TryPopRequest(HostRequest* r, bool pop) { r->opcode = 0xad; r->addr = 0xad; r->value = 0xbad; return req_.TryPop(r, pop); } void HostDevice::PushResponse(uint32_t value) { HostResponse r; r.value = value; resp_.Push(r); } void HostDevice::WaitPopResponse(HostResponse* r) { resp_.WaitPop(r); } void HostDevice::Exit() { std::unique_lock<std::mutex> lock(mutex_); exit_ = 1; } uint8_t HostDevice::GetExitStatus() { std::unique_lock<std::mutex> lock(mutex_); return exit_; } void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) { std::lock_guard<std::mutex> lock(mutex_); if (opcode == 1) { wlen_ = len + 1; waddr_ = reinterpret_cast<uint64_t*>(addr); } else { rlen_ = len + 1; raddr_ = reinterpret_cast<uint64_t*>(addr); } } MemResponse MemDevice::ReadData(uint8_t ready) { std::lock_guard<std::mutex> lock(mutex_); MemResponse r; r.valid = rlen_ > 0; r.value = rlen_ > 0 ? *raddr_ : 0xdeadbeefdeadbeef; if (ready == 1 && rlen_ > 0) { raddr_++; rlen_ -= 1; } return r; } void MemDevice::WriteData(uint64_t value) { std::lock_guard<std::mutex> lock(mutex_); if (wlen_ > 0) { *waddr_ = value; waddr_++; wlen_ -= 1; } } class DPIModule final : public DPIModuleNode { public: ~DPIModule() { if (lib_handle_) Unload(); } const char* type_key() const final { return "vta-tsim"; } PackedFunc GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) final { if (name == "WriteReg") { return TypedPackedFunc<void(int, int)>( [this](int addr, int value){ this->WriteReg(addr, value); }); } else { LOG(FATAL) << "Member " << name << "does not exists"; return nullptr; } } void Init(const std::string& name) { Load(name); VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>( GetSymbol("VTADPIInit")); CHECK(finit != nullptr); finit(this, VTAHostDPI, VTAMemDPI); fvsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim")); CHECK(fvsim_ != nullptr); } void Launch(uint64_t max_cycles) { auto frun = [this, max_cycles]() { (*fvsim_)(max_cycles); }; vsim_thread_ = std::thread(frun); } void WriteReg(int addr, uint32_t value) { host_device_.PushRequest(1, addr, value); } uint32_t ReadReg(int addr) { uint32_t value; HostResponse* r = new HostResponse; host_device_.PushRequest(0, addr, 0); host_device_.WaitPopResponse(r); value = r->value; delete r; return value; } void Finish() { host_device_.Exit(); vsim_thread_.join(); } protected: VTADPISimFunc fvsim_; HostDevice host_device_; MemDevice mem_device_; std::thread vsim_thread_; void HostDPI(dpi8_t* exit, dpi8_t* req_valid, dpi8_t* req_opcode, dpi8_t* req_addr, dpi32_t* req_value, dpi8_t req_deq, dpi8_t resp_valid, dpi32_t resp_value) { HostRequest* r = new HostRequest; *exit = host_device_.GetExitStatus(); *req_valid = host_device_.TryPopRequest(r, req_deq); *req_opcode = r->opcode; *req_addr = r->addr; *req_value = r->value; if (resp_valid) { host_device_.PushResponse(resp_value); } delete r; } void MemDPI( dpi8_t req_valid, dpi8_t req_opcode, dpi8_t req_len, dpi64_t req_addr, dpi8_t wr_valid, dpi64_t wr_value, dpi8_t* rd_valid, dpi64_t* rd_value, dpi8_t rd_ready) { MemResponse r = mem_device_.ReadData(rd_ready); *rd_valid = r.valid; *rd_value = r.value; if (wr_valid) { mem_device_.WriteData(wr_value); } if (req_valid) { mem_device_.SetRequest(req_opcode, req_addr, req_len); } } static void VTAHostDPI( VTAContextHandle self, dpi8_t* exit, dpi8_t* req_valid, dpi8_t* req_opcode, dpi8_t* req_addr, dpi32_t* req_value, dpi8_t req_deq, dpi8_t resp_valid, dpi32_t resp_value) { static_cast<DPIModule*>(self)->HostDPI( exit, req_valid, req_opcode, req_addr, req_value, req_deq, resp_valid, resp_value); } static void VTAMemDPI( VTAContextHandle self, dpi8_t req_valid, dpi8_t req_opcode, dpi8_t req_len, dpi64_t req_addr, dpi8_t wr_valid, dpi64_t wr_value, dpi8_t* rd_valid, dpi64_t* rd_value, dpi8_t rd_ready) { static_cast<DPIModule*>(self)->MemDPI( req_valid, req_opcode, req_len, req_addr, wr_valid, wr_value, rd_valid, rd_value, rd_ready); } private: // Platform dependent handling. #if defined(_WIN32) // library handle HMODULE lib_handle_{nullptr}; // Load the library void Load(const std::string& name) { // use wstring version that is needed by LLVM. std::wstring wname(name.begin(), name.end()); lib_handle_ = LoadLibraryW(wname.c_str()); CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } void* GetSymbol(const char* name) { return reinterpret_cast<void*>( GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } void Unload() { FreeLibrary(lib_handle_); } #else // Library handle void* lib_handle_{nullptr}; // load the library void Load(const std::string& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " << dlerror(); } void* GetSymbol(const char* name) { return dlsym(lib_handle_, name); } void Unload() { dlclose(lib_handle_); } #endif }; Module DPIModuleNode::Load(std::string dll_name) { std::shared_ptr<DPIModule> n = std::make_shared<DPIModule>(); n->Init(dll_name); return Module(n); } TVM_REGISTER_GLOBAL("module.loadfile_vta-tsim") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = DPIModuleNode::Load(args[0]); }); } // namespace dpi } // namespace vta