/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2018 by Contributors
 * \file pynq_driver.c
 * \brief VTA driver for Pynq board.
 */

#include <vta/driver.h>
#include <thread>
#include "pynq_driver.h"

/*!
 * \brief Allocate a buffer by forwarding to cma_alloc.
 * \param size Number of bytes requested.
 * \param cached Passed through unchanged to cma_alloc.
 * \return Whatever cma_alloc returns.
 */
void* VTAMemAlloc(size_t size, int cached) {
  return cma_alloc(size, cached);
}

/*!
 * \brief Release a buffer by forwarding to cma_free.
 * \param buf Buffer to release.
 */
void VTAMemFree(void* buf) {
  cma_free(buf);
}

/*!
 * \brief Look up the physical address of a buffer via cma_get_phy_addr.
 * \param buf Buffer to query.
 * \return Whatever cma_get_phy_addr returns.
 */
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
  return cma_get_phy_addr(buf);
}

/*!
 * \brief Flush a range by forwarding to xlnkFlushCache.
 * \param buf Address handed to xlnkFlushCache (reinterpreted as a pointer).
 * \param size Length of the range in bytes.
 */
void VTAFlushCache(vta_phy_addr_t buf, int size) {
  xlnkFlushCache(reinterpret_cast<void*>(buf), size);
}

/*!
 * \brief Invalidate a range by forwarding to xlnkInvalidateCache.
 * \param buf Address handed to xlnkInvalidateCache (reinterpreted as a pointer).
 * \param size Length of the range in bytes.
 */
void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
  xlnkInvalidateCache(reinterpret_cast<void*>(buf), size);
}

/*!
 * \brief Map a register range from VTA_PYNQ_DEV_MEM_PATH into this process.
 *
 * The mapping starts at the page boundary at or below \p addr and spans
 * \p length bytes plus the distance from that boundary to \p addr.
 *
 * \param addr Address of the register range within the device file.
 * \param length Number of bytes requested, counted from \p addr.
 * \return Pointer to the START OF THE PAGE containing \p addr (not to
 *         \p addr itself when \p addr is not page-aligned), or MAP_FAILED
 *         if the device file could not be opened or mmap failed.
 */
void *VTAMapRegister(uint32_t addr, size_t length) {
  // Align the base address with the pages
  const uint32_t virt_base = addr & ~(getpagesize() - 1);
  // Calculate base address offset w.r.t the base address
  const uint32_t virt_offset = addr - virt_base;
  // Open file and mmap. open() returns a signed descriptor (-1 on error),
  // so it must not be stored in an unsigned type.
  const int mmap_file = open(VTA_PYNQ_DEV_MEM_PATH, O_RDWR|O_SYNC);
  if (mmap_file < 0) {
    // Same value the caller would get from mmap on an invalid descriptor.
    return MAP_FAILED;
  }
  void* mapped = mmap(nullptr,
                      (length + virt_offset),
                      PROT_READ|PROT_WRITE,
                      MAP_SHARED,
                      mmap_file,
                      virt_base);
  // The mapping stays valid after the descriptor is closed; closing it here
  // keeps every call from leaking one file descriptor.
  close(mmap_file);
  return mapped;
}

/*!
 * \brief Unmap a region previously returned by VTAMapRegister.
 * \param vta Start of the mapping.
 * \param length Number of bytes to unmap.
 * \note Failure of munmap is only detected through assert, i.e. it is
 *       ignored when NDEBUG is defined.
 */
void VTAUnmapRegister(void *vta, size_t length) {
  // Unmap memory
  const int status = munmap(vta, length);
  assert(status == 0);
  // Keeps NDEBUG builds free of an unused-variable warning.
  static_cast<void>(status);
}

/*!
 * \brief Store a 32-bit value at a byte offset from a mapped base address.
 * \param base_addr Base of the mapped region.
 * \param offset Byte offset from \p base_addr.
 * \param val Value to store.
 */
void VTAWriteMappedReg(void* base_addr, uint32_t offset, uint32_t val) {
  // volatile access: the store must not be elided or merged by the compiler.
  *reinterpret_cast<volatile uint32_t*>(
      reinterpret_cast<char*>(base_addr) + offset) = val;
}

/*!
 * \brief Load a 32-bit value from a byte offset from a mapped base address.
 * \param base_addr Base of the mapped region.
 * \param offset Byte offset from \p base_addr.
 * \return The value read.
 */
uint32_t VTAReadMappedReg(void* base_addr, uint32_t offset) {
  // volatile access: every call performs a fresh load.
  return *reinterpret_cast<volatile uint32_t*>(
      reinterpret_cast<char*>(base_addr) + offset);
}

/*!
 * \brief Owner of the four VTA stage register mappings
 *        (fetch, load, compute, store).
 *
 * The mappings are created in the constructor and released in the
 * destructor. Copying is disabled because a copy would release the same
 * mappings a second time.
 */
class VTADevice {
 public:
  VTADevice() {
    // VTA stage handles
    vta_fetch_handle_ = VTAMapRegister(VTA_FETCH_ADDR, VTA_RANGE);
    vta_load_handle_ = VTAMapRegister(VTA_LOAD_ADDR, VTA_RANGE);
    vta_compute_handle_ = VTAMapRegister(VTA_COMPUTE_ADDR, VTA_RANGE);
    vta_store_handle_ = VTAMapRegister(VTA_STORE_ADDR, VTA_RANGE);
  }

  ~VTADevice() {
    // Close VTA stage handle
    VTAUnmapRegister(vta_fetch_handle_, VTA_RANGE);
    VTAUnmapRegister(vta_load_handle_, VTA_RANGE);
    VTAUnmapRegister(vta_compute_handle_, VTA_RANGE);
    VTAUnmapRegister(vta_store_handle_, VTA_RANGE);
  }

  // The handles are unmapped exactly once, by the single owning instance.
  VTADevice(const VTADevice&) = delete;
  VTADevice& operator=(const VTADevice&) = delete;

  /*!
   * \brief Program the stage registers, start the device and poll for
   *        completion.
   * \param insn_phy_addr Value written to the FETCH insns register.
   * \param insn_count Value written to the FETCH insn_count register.
   * \param wait_cycles Maximum number of polling iterations.
   * \return 0 if VTA_DONE was observed within \p wait_cycles polls,
   *         1 otherwise (including when \p wait_cycles is 0).
   */
  int Run(vta_phy_addr_t insn_phy_addr,
          uint32_t insn_count,
          uint32_t wait_cycles) {
    // NOTE: Register address map is derived from the auto-generated
    // driver files available under hardware/build/vivado/<design>/export/driver
    // FETCH @ 0x10 : Data signal of insn_count_V
    VTAWriteMappedReg(vta_fetch_handle_, 0x10, insn_count);
    // FETCH @ 0x18 : Data signal of insns_V
    VTAWriteMappedReg(vta_fetch_handle_, 0x18, insn_phy_addr);
    // LOAD @ 0x10 : Data signal of inputs_V
    VTAWriteMappedReg(vta_load_handle_, 0x10, 0);
    // LOAD @ 0x18 : Data signal of weight_V
    VTAWriteMappedReg(vta_load_handle_, 0x18, 0);
    // COMPUTE @ 0x20 : Data signal of uops_V
    VTAWriteMappedReg(vta_compute_handle_, 0x20, 0);
    // COMPUTE @ 0x28 : Data signal of biases_V
    VTAWriteMappedReg(vta_compute_handle_, 0x28, 0);
    // STORE @ 0x10 : Data signal of outputs_V
    VTAWriteMappedReg(vta_store_handle_, 0x10, 0);

    // VTA start
    VTAWriteMappedReg(vta_fetch_handle_, 0x0, VTA_START);
    VTAWriteMappedReg(vta_load_handle_, 0x0, VTA_AUTORESTART);
    VTAWriteMappedReg(vta_compute_handle_, 0x0, VTA_AUTORESTART);
    VTAWriteMappedReg(vta_store_handle_, 0x0, VTA_AUTORESTART);

    // Loop until the VTA is done
    unsigned t = 0;
    unsigned flag = 0;
    for (t = 0; t < wait_cycles; ++t) {
      flag = VTAReadMappedReg(vta_compute_handle_, 0x18);
      if (flag == VTA_DONE) break;
      // Give other threads a chance to run between polls.
      std::this_thread::yield();
    }
    // Report error if timeout
    return t < wait_cycles ? 0 : 1;
  }

 private:
  // VTA handles (register maps)
  void* vta_fetch_handle_{nullptr};
  void* vta_load_handle_{nullptr};
  void* vta_compute_handle_{nullptr};
  void* vta_store_handle_{nullptr};
};

/*!
 * \brief Create a VTADevice on the heap.
 * \return Handle to the new device; release it with VTADeviceFree.
 */
VTADeviceHandle VTADeviceAlloc() {
  return new VTADevice();
}

/*!
 * \brief Destroy a device created by VTADeviceAlloc.
 * \param handle Handle to destroy.
 */
void VTADeviceFree(VTADeviceHandle handle) {
  delete static_cast<VTADevice*>(handle);
}

/*!
 * \brief Forward to VTADevice::Run on the device behind \p handle.
 * \param handle Device handle obtained from VTADeviceAlloc.
 * \param insn_phy_addr Forwarded to VTADevice::Run.
 * \param insn_count Forwarded to VTADevice::Run.
 * \param wait_cycles Forwarded to VTADevice::Run.
 * \return The result of VTADevice::Run (0 on completion, 1 on timeout).
 */
int VTADeviceRun(VTADeviceHandle handle,
                 vta_phy_addr_t insn_phy_addr,
                 uint32_t insn_count,
                 uint32_t wait_cycles) {
  return static_cast<VTADevice*>(handle)->Run(
      insn_phy_addr, insn_count, wait_cycles);
}