Unverified Commit 4683c3f5 by Thierry Moreau Committed by GitHub

[VTA] HW sources refactor (#5188)

* refactor

* path udpate
parent 0e2701b2
...@@ -7,3 +7,6 @@ ...@@ -7,3 +7,6 @@
[submodule "3rdparty/rang"] [submodule "3rdparty/rang"]
path = 3rdparty/rang path = 3rdparty/rang
url = https://github.com/agauniyal/rang url = https://github.com/agauniyal/rang
[submodule "3rdparty/vta-hw"]
path = 3rdparty/vta-hw
url = https://github.com/apache/incubator-tvm-vta
Subproject commit db65157208ec8fabb7b548c94596211b9db04190
...@@ -29,7 +29,7 @@ ifndef DLPACK_PATH ...@@ -29,7 +29,7 @@ ifndef DLPACK_PATH
endif endif
ifndef VTA_HW_PATH ifndef VTA_HW_PATH
VTA_HW_PATH = $(ROOTDIR)/vta/vta-hw VTA_HW_PATH = $(ROOTDIR)/3rdparty/vta-hw
endif endif
INCLUDE_FLAGS = -Iinclude -I$(DLPACK_PATH)/include -I$(DMLC_CORE_PATH)/include INCLUDE_FLAGS = -Iinclude -I$(DLPACK_PATH)/include -I$(DMLC_CORE_PATH)/include
......
...@@ -20,7 +20,7 @@ find_program(PYTHON NAMES python python3 python3.6) ...@@ -20,7 +20,7 @@ find_program(PYTHON NAMES python python3 python3.6)
# Throw error if VTA_HW_PATH is not set # Throw error if VTA_HW_PATH is not set
if(NOT DEFINED ENV{VTA_HW_PATH}) if(NOT DEFINED ENV{VTA_HW_PATH})
set(VTA_HW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/vta/vta-hw) set(VTA_HW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/vta-hw)
else() else()
set(VTA_HW_PATH $ENV{VTA_HW_PATH}) set(VTA_HW_PATH $ENV{VTA_HW_PATH})
endif() endif()
......
...@@ -21,7 +21,7 @@ VTA Configuration ...@@ -21,7 +21,7 @@ VTA Configuration
The VTA stack incorporates both a hardware accelerator stack and The VTA stack incorporates both a hardware accelerator stack and
a TVM based software stack. a TVM based software stack.
VTA incorporates flexibility out of the box: by modifying the VTA incorporates flexibility out of the box: by modifying the
``vta/vta-hw/config/vta_config.json`` high-level configuration file, ``3rdparty/vta-hw/config/vta_config.json`` high-level configuration file,
the user can change the shape of the tensor intrinsic, the user can change the shape of the tensor intrinsic,
clock frequency, pipelining, data type width, and on-chip buffer sizes. clock frequency, pipelining, data type width, and on-chip buffer sizes.
......
...@@ -53,17 +53,17 @@ HLS Hardware Source Organization ...@@ -53,17 +53,17 @@ HLS Hardware Source Organization
The VTA design is currently specified in Vivado HLS C++, which is only supported The VTA design is currently specified in Vivado HLS C++, which is only supported
by Xilinx toolchains. by Xilinx toolchains.
The VTA hardware sources are contained under ``vta/vta-hw/hardware/xilinx/sources``: The VTA hardware sources are contained under ``3rdparty/vta-hw/hardware/xilinx/sources``:
- ``vta.cc`` contains the definitions for each VTA module, as well as a top - ``vta.cc`` contains the definitions for each VTA module, as well as a top
level behavioral model for the top-level VTA design. level behavioral model for the top-level VTA design.
- ``vta.h`` contains type definitions using Xilinx ``ap_int`` types, and - ``vta.h`` contains type definitions using Xilinx ``ap_int`` types, and
function prototypes declarations. function prototypes declarations.
In addition preprocessor macros are defined under ``vta/vta-hw/include/vta/hw_spec.h``. In addition preprocessor macros are defined under ``3rdparty/vta-hw/include/vta/hw_spec.h``.
Much of these macro definitions are derived from the parameters listed in the Much of these macro definitions are derived from the parameters listed in the
``vta/vta-hw/config/vta_config.json`` file. ``3rdparty/vta-hw/config/vta_config.json`` file.
The json file is processed by ``vta/vta-hw/config/vta_config.py`` to produce a string of The json file is processed by ``3rdparty/vta-hw/config/vta_config.py`` to produce a string of
compile flags that define the preprocessor macros. compile flags that define the preprocessor macros.
That string is used by the makefile in order to set those high-level That string is used by the makefile in order to set those high-level
parameters in both the HLS hardware synthesis compiler, and the C++ parameters in both the HLS hardware synthesis compiler, and the C++
...@@ -220,7 +220,7 @@ Microarchitectural Overview ...@@ -220,7 +220,7 @@ Microarchitectural Overview
--------------------------- ---------------------------
We describe the modules that compose the VTA design. We describe the modules that compose the VTA design.
The module definitions are contained in ``vta/vta-hw/hardware/xilinx/sources/vta.cc``. The module definitions are contained in ``3rdparty/vta-hw/hardware/xilinx/sources/vta.cc``.
Fetch Module Fetch Module
~~~~~~~~~~~~ ~~~~~~~~~~~~
......
...@@ -32,7 +32,7 @@ For a quick and easy start, checkout the [Docker Guide](https://tvm.apache.org/d ...@@ -32,7 +32,7 @@ For a quick and easy start, checkout the [Docker Guide](https://tvm.apache.org/d
You'll need to set the following paths to use VTA: You'll need to set the following paths to use VTA:
```bash ```bash
export TVM_PATH=<path to TVM root> export TVM_PATH=<path to TVM root>
export VTA_HW_PATH=$TVM_PATH/vta/vta-hw export VTA_HW_PATH=$TVM_PATH/3rdparty/vta-hw
``` ```
The VTA functional simulation library needs to be enabled when building TVM. The VTA functional simulation library needs to be enabled when building TVM.
...@@ -66,7 +66,7 @@ You are invited to try out our [VTA programming tutorials](https://tvm.apache.or ...@@ -66,7 +66,7 @@ You are invited to try out our [VTA programming tutorials](https://tvm.apache.or
### Advanced Configuration (optional) ### Advanced Configuration (optional)
VTA is a generic configurable deep learning accelerator. VTA is a generic configurable deep learning accelerator.
The configuration is specified by `vta_config.json` under `vta/vta-hw/config`. The configuration is specified by `vta_config.json` under `3rdparty/vta-hw/config`.
This file provides an architectural specification of the VTA accelerator to parameterize the TVM compiler stack and the VTA hardware stack. This file provides an architectural specification of the VTA accelerator to parameterize the TVM compiler stack and the VTA hardware stack.
The VTA configuration file also specifies the TVM compiler target. The VTA configuration file also specifies the TVM compiler target.
...@@ -76,7 +76,7 @@ To do so, ...@@ -76,7 +76,7 @@ To do so,
```bash ```bash
cd <tvm root> cd <tvm root>
vim vta/vta-hw/config/vta_config.json vim 3rdparty/vta-hw/config/vta_config.json
# edit vta_config.json # edit vta_config.json
make make
``` ```
...@@ -134,7 +134,7 @@ mkdir build ...@@ -134,7 +134,7 @@ mkdir build
cp cmake/config.cmake build/. cp cmake/config.cmake build/.
echo 'set(USE_VTA_FPGA ON)' >> build/config.cmake echo 'set(USE_VTA_FPGA ON)' >> build/config.cmake
# Copy pynq specific configuration # Copy pynq specific configuration
cp vta/vta-hw/config/pynq_sample.json vta/vta-hw/config/vta_config.json cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json
cd build cd build
cmake .. cmake ..
make runtime vta -j2 make runtime vta -j2
...@@ -168,7 +168,7 @@ In addition, you'll need to edit the `vta_config.json` file on the host to indic ...@@ -168,7 +168,7 @@ In addition, you'll need to edit the `vta_config.json` file on the host to indic
```bash ```bash
# On the Host-side # On the Host-side
cd <tvm root> cd <tvm root>
cp vta/vta-hw/config/pynq_sample.json vta/vta-hw/config/vta_config.json cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json
``` ```
This time again, we will run the 2D convolution testbench. This time again, we will run the 2D convolution testbench.
...@@ -359,11 +359,11 @@ For this custom VTA bitstream compilation exercise, we'll change the frequency o ...@@ -359,11 +359,11 @@ For this custom VTA bitstream compilation exercise, we'll change the frequency o
* Set the `HW_FREQ` field to `142`. The Pynq board supports 100, 142, 167 and 200MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violation and thus faulty hardware execution. * Set the `HW_FREQ` field to `142`. The Pynq board supports 100, 142, 167 and 200MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violation and thus faulty hardware execution.
* Set the `HW_CLK_TARGET` to `6`. This parameters refers to the target clock period in nano seconds for HLS - a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142MHz clock would require a 7ns target, but we intentionally lower the clock target to 6ns to more aggressively pipeline our design. * Set the `HW_CLK_TARGET` to `6`. This parameters refers to the target clock period in nano seconds for HLS - a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142MHz clock would require a 7ns target, but we intentionally lower the clock target to 6ns to more aggressively pipeline our design.
Bitstream generation is driven by a top-level `Makefile` under `<tvm root>/vta/vta-hw/hardware/xilinx/`. Bitstream generation is driven by a top-level `Makefile` under `<tvm root>/3rdparty/vta-hw/hardware/xilinx/`.
If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter: If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter:
```bash ```bash
cd <tvm root>/vta/vta-hw/hardware/xilinx cd <tvm root>/3rdparty/vta-hw/hardware/xilinx
make ip MODE=sim make ip MODE=sim
``` ```
...@@ -371,7 +371,7 @@ If you just want to generate the HLS-based VTA IP cores without launching the en ...@@ -371,7 +371,7 @@ If you just want to generate the HLS-based VTA IP cores without launching the en
```bash ```bash
make ip make ip
``` ```
You'll be able to view the HLS synthesis reports under `<tvm root>/vta/vta-hw/build/hardware/xilinx/hls/` `<configuration>/<block>/solution0/syn/report/<block>_csynth.rpt` You'll be able to view the HLS synthesis reports under `<tvm root>/3rdparty/vta-hw/build/hardware/xilinx/hls/` `<configuration>/<block>/solution0/syn/report/<block>_csynth.rpt`
> Note: The `<configuration>` name is a string that summarizes the VTA configuration parameters listed in the `vta_config.json`. The `<block>` name refers to the specific module (or HLS function) that compose the high-level VTA pipeline. > Note: The `<configuration>` name is a string that summarizes the VTA configuration parameters listed in the `vta_config.json`. The `<block>` name refers to the specific module (or HLS function) that compose the high-level VTA pipeline.
Finally to run the full hardware compilation and generate the VTA bitstream, run: Finally to run the full hardware compilation and generate the VTA bitstream, run:
...@@ -383,20 +383,20 @@ make ...@@ -383,20 +383,20 @@ make
This process is lengthy, and can take around up to an hour to complete depending on your machine's specs. This process is lengthy, and can take around up to an hour to complete depending on your machine's specs.
We recommend setting the `VTA_HW_COMP_THREADS` variable in the Makefile to take full advantage of all the cores on your development machine. We recommend setting the `VTA_HW_COMP_THREADS` variable in the Makefile to take full advantage of all the cores on your development machine.
Once the compilation completes, the generated bitstream can be found under `<tvm root>/vta/vta-hw/build/hardware/xilinx/vivado/<configuration>/export/vta.bit`. Once the compilation completes, the generated bitstream can be found under `<tvm root>/3rdparty/vta-hw/build/hardware/xilinx/vivado/<configuration>/export/vta.bit`.
### Chisel-based Custom VTA Bitstream Compilation for DE10-Nano ### Chisel-based Custom VTA Bitstream Compilation for DE10-Nano
Similar to the HLS-based design, high-level hardware parameters in Chisel-based design are listed in the VTA configuration file [Configs.scala](https://github.com/apache/incubator-tvm/blob/master/vta/vta-hw/hardware/chisel/src/main/scala/core/Configs.scala), and they can be customized by the user. Similar to the HLS-based design, high-level hardware parameters in Chisel-based design are listed in the VTA configuration file [Configs.scala](https://github.com/apache/incubator-tvm/blob/master/3rdparty/vta-hw/hardware/chisel/src/main/scala/core/Configs.scala), and they can be customized by the user.
For Intel FPGA, bitstream generation is driven by a top-level `Makefile` under `<tvmroot>/vta/vta-hw/hardware/intel`. For Intel FPGA, bitstream generation is driven by a top-level `Makefile` under `<tvm root>/3rdparty/vta-hw/hardware/intel`.
If you just want to generate the Chisel-based VTA IP core for the DE10-Nano board without compiling the design for the FPGA hardware, enter: If you just want to generate the Chisel-based VTA IP core for the DE10-Nano board without compiling the design for the FPGA hardware, enter:
```bash ```bash
cd <tvmroot>/vta/vta-hw/hardware/intel cd <tvm root>/3rdparty/vta-hw/hardware/intel
make ip make ip
``` ```
Then you'll be able to locate the generated verilog file at `<tvmroot>/vta/vta-hw/build/hardware/intel/chisel/<configuration>/VTA.DefaultDe10Config.v`. Then you'll be able to locate the generated verilog file at `<tvm root>/3rdparty/vta-hw/build/hardware/intel/chisel/<configuration>/VTA.DefaultDe10Config.v`.
If you would like to run the full hardware compilation for the `de10nano` board: If you would like to run the full hardware compilation for the `de10nano` board:
```bash ```bash
...@@ -405,14 +405,14 @@ make ...@@ -405,14 +405,14 @@ make
This process might be a bit lengthy, and might take up to half an hour to complete depending on the performance of your PC. The Quartus Prime software would automatically detect the number of cores available on your PC and try to utilize all of them to perform such process. This process might be a bit lengthy, and might take up to half an hour to complete depending on the performance of your PC. The Quartus Prime software would automatically detect the number of cores available on your PC and try to utilize all of them to perform such process.
Once the compilation completes, the generated bistream can be found under `<tvmroot>vtay/vta-hw/build/hardware/intel/quartus/<configuration>/export/vta.rbf`. You can also open the Quartus project file (.qpf) available at `<tvmroot>/vta/vta-hw/build/hardware/intel/quartus/<configuration>/de10_nano_top.qpf` to look around the generated reports. Once the compilation completes, the generated bistream can be found under `<tvm root>/3rdparty/vta-hw/build/hardware/intel/quartus/<configuration>/export/vta.rbf`. You can also open the Quartus project file (.qpf) available at `<tvm root>/3rdparty/vta-hw/build/hardware/intel/quartus/<configuration>/de10_nano_top.qpf` to look around the generated reports.
### Use the Custom Bitstream ### Use the Custom Bitstream
We can program the new VTA FPGA bitstream by setting the bitstream path of the `vta.program_fpga()` function in the tutorial examples, or in the `test_program_rpc.py` script. We can program the new VTA FPGA bitstream by setting the bitstream path of the `vta.program_fpga()` function in the tutorial examples, or in the `test_program_rpc.py` script.
```python ```python
vta.program_fpga(remote, bitstream="<tvm root>/vta/vta-hw/build/hardware/xilinx/vivado/<configuration>/export/vta.bit") vta.program_fpga(remote, bitstream="<tvm root>/3rdparty/vta-hw/build/hardware/xilinx/vivado/<configuration>/export/vta.bit")
``` ```
Instead of downloading a pre-built bitstream from the VTA bitstream repository, TVM will instead use the new bitstream you just generated, which is a VTA design clocked at a higher frequency. Instead of downloading a pre-built bitstream from the VTA bitstream repository, TVM will instead use the new bitstream you just generated, which is a VTA design clocked at a higher frequency.
......
...@@ -15,4 +15,5 @@ ...@@ -15,4 +15,5 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
export VTA_HW_PATH=`pwd`/3rdparty/vta-hw
cd $1 && cmake .. && make $2 && cd .. cd $1 && cmake .. && make $2 && cd ..
...@@ -20,6 +20,8 @@ set -e ...@@ -20,6 +20,8 @@ set -e
set -u set -u
export LD_LIBRARY_PATH="lib:${LD_LIBRARY_PATH:-}" export LD_LIBRARY_PATH="lib:${LD_LIBRARY_PATH:-}"
# NOTE: important to use abspath, when VTA is enabled.
export VTA_HW_PATH=`pwd`/3rdparty/vta-hw
# Remove existing testcases # Remove existing testcases
rm -f build/*_test rm -f build/*_test
......
...@@ -21,7 +21,7 @@ set -u ...@@ -21,7 +21,7 @@ set -u
export TVM_PATH=`pwd` export TVM_PATH=`pwd`
export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/vta/python:${TVM_PATH}/topi/python export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/vta/python:${TVM_PATH}/topi/python
export VTA_HW_PATH=`pwd`/vta/vta-hw export VTA_HW_PATH=`pwd`/3rdparty/vta-hw
# cleanup pycache # cleanup pycache
find . -type f -path "*.pyc" | xargs rm -f find . -type f -path "*.pyc" | xargs rm -f
......
...@@ -21,7 +21,7 @@ set -u ...@@ -21,7 +21,7 @@ set -u
export TVM_PATH=`pwd` export TVM_PATH=`pwd`
export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/vta/python:${TVM_PATH}/topi/python export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/vta/python:${TVM_PATH}/topi/python
export VTA_HW_PATH=`pwd`/vta/vta-hw export VTA_HW_PATH=`pwd`/3rdparty/vta-hw
# cleanup pycache # cleanup pycache
find . -type f -path "*.pyc" | xargs rm -f find . -type f -path "*.pyc" | xargs rm -f
......
...@@ -28,7 +28,7 @@ from . import intrin ...@@ -28,7 +28,7 @@ from . import intrin
def get_vta_hw_path(): def get_vta_hw_path():
"""Get the VTA HW path.""" """Get the VTA HW path."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
vta_hw_default = os.path.abspath(os.path.join(curr_path, "../../vta-hw")) vta_hw_default = os.path.abspath(os.path.join(curr_path, "../../../3rdparty/vta-hw"))
VTA_HW_PATH = os.getenv('VTA_HW_PATH', vta_hw_default) VTA_HW_PATH = os.getenv('VTA_HW_PATH', vta_hw_default)
return os.path.abspath(VTA_HW_PATH) return os.path.abspath(VTA_HW_PATH)
......
...@@ -181,7 +181,7 @@ def compile_network(env, target, model, start_pack, stop_pack): ...@@ -181,7 +181,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0') tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0')
tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190)) tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
# Load VTA parameters from the vta/vta-hw/config/vta_config.json file # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env() env = vta.get_env()
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device. # This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
......
...@@ -68,7 +68,7 @@ assert tvm.runtime.enabled("rpc") ...@@ -68,7 +68,7 @@ assert tvm.runtime.enabled("rpc")
# ------------------------------------- # -------------------------------------
# Execute on CPU vs. VTA, and define the model. # Execute on CPU vs. VTA, and define the model.
# Load VTA parameters from the vta/vta-hw/config/vta_config.json file # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env() env = vta.get_env()
# Set ``device=arm_cpu`` to run inference on the CPU # Set ``device=arm_cpu`` to run inference on the CPU
......
...@@ -111,7 +111,7 @@ names = [x.strip() for x in content] ...@@ -111,7 +111,7 @@ names = [x.strip() for x in content]
# -------------------------------------- # --------------------------------------
# Execute on CPU vs. VTA, and define the model. # Execute on CPU vs. VTA, and define the model.
# Load VTA parameters from the vta/vta-hw/config/vta_config.json file # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env() env = vta.get_env()
# Set ``device=arm_cpu`` to run inference on the CPU # Set ``device=arm_cpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA. # or ``device=vta`` to run inference on the FPGA.
......
...@@ -43,7 +43,7 @@ from tvm import rpc ...@@ -43,7 +43,7 @@ from tvm import rpc
from tvm.contrib import util from tvm.contrib import util
from vta.testing import simulator from vta.testing import simulator
# Load VTA parameters from the vta/vta-hw/config/vta_config.json file # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env() env = vta.get_env()
# We read the Pynq RPC host IP address and port number from the OS environment # We read the Pynq RPC host IP address and port number from the OS environment
......
...@@ -47,7 +47,7 @@ from tvm import rpc ...@@ -47,7 +47,7 @@ from tvm import rpc
from tvm.contrib import util from tvm.contrib import util
from vta.testing import simulator from vta.testing import simulator
# Load VTA parameters from the vta/vta-hw/config/vta_config.json file # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env() env = vta.get_env()
# We read the Pynq RPC host IP address and port number from the OS environment # We read the Pynq RPC host IP address and port number from the OS environment
......
...@@ -46,7 +46,7 @@ from tvm import rpc ...@@ -46,7 +46,7 @@ from tvm import rpc
from tvm.contrib import util from tvm.contrib import util
from vta.testing import simulator from vta.testing import simulator
# Load VTA parameters from the vta/vta-hw/config/vta_config.json file # Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env() env = vta.get_env()
# We read the Pynq RPC host IP address and port number from the OS environment # We read the Pynq RPC host IP address and port number from the OS environment
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tsim C CXX)
if(NOT DEFINED ENV{TVM_PATH})
message(ERROR "Make sure to set TVM_PATH in your environment")
endif()
if(NOT DEFINED ENV{VTA_HW_PATH})
message(ERROR "Make sure to set VTA_HW_PATH in your environment")
endif()
include_directories("$ENV{TVM_PATH}/include")
include_directories("$ENV{TVM_PATH}/3rdparty/dlpack/include")
include_directories("$ENV{TVM_PATH}/3rdparty/dmlc-core/include")
include_directories("$ENV{VTA_HW_PATH}/src/dpi")
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11")
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
endif()
file(GLOB TSIM_SW_SRC src/driver.cc)
list(APPEND TSIM_SW_SRC $ENV{VTA_HW_PATH}/src/vmem/virtual_memory.cc)
list(APPEND TSIM_SW_SRC $ENV{VTA_HW_PATH}/src/dpi/module.cc)
add_library(sw SHARED ${TSIM_SW_SRC})
target_include_directories(sw PRIVATE $ENV{VTA_HW_PATH}/include $ENV{VTA_HW_PATH}/src)
if(APPLE)
set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
export PYTHONPATH:=$(abspath .)/python:$(PYTHONPATH)
BUILD_NAME = build
build_dir = $(abspath .)/$(BUILD_NAME)
default: chisel driver serial parallel
serial:
python3 tests/python/chisel_accel.py serial
parallel:
python3 tests/python/chisel_accel.py parallel
driver: | $(build_dir)
cd $(build_dir) && cmake .. && make
$(build_dir):
mkdir -p $@
chisel:
make -C hardware/chisel
clean:
-rm -rf $(build_dir)
make -C hardware/chisel clean
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
VTA TSIM Application
======================
Prior to this application, please take a look at `<vta-hw-root>/apps/tsim_example` for installation
This is an application that performs Bit Serial Multiplication for GEMM utilizing TSIM.
**Bit Serial Multiplication for GEMM:**
General Matrix Multiplications (GEMM), are mostly calculated by repeatly calculating the dot product for each pair of vectors.
The dot product is calculated by summing every product of the vector pair.
We approach this operation with slicing and shifting, like how basic multiplication works, each vector elements before we accumulate them.
We can sufficiently reduce the cycles required to perform a gemm given that the data bit width is small. This GEMM application uses TSIM for future accerlerator prototypes.
* Test Chisel3 backend with bit serial GEMM
* Go to `<vta-hw-root>/apps/gemm`
* Run `make`
* If you have already compiled chisel backend (i.e. ran `make`)
* Bit Serial test with another input set, run `make serial`
* Bit parallel test with another input set, run `make parallel`
* Some steps for creating your own custom TSIM application
* Go to `<vta-hw-root>/apps/gemm`
* Create custom circuit within `./hardware/chisel/src/scala.main/accel/Compute.scala`
* Map the according Registers in `./hardware/chisel/src/scala.main/accel/RegFile.scala`
* Create your test script
* Map the registers in `./src/driver.cc` and link it with both `RegFile.scala` and the test script
* Understanding of `<vta-hw-root>/apps/tsim_example`, which performs add by one to a vector, is highly encouraged to create a more complex application
* Some pointers
* Chisel3 tests in `<vta-hw-root>/apps/gemm/tests/python`
* Chisel3 accelerator backend `<vta-hw-root>/apps/gemm/hardware/chisel`
* Software C++ driver (backend) that handles the accelerator `<vta-hw-root>/apps/gemm/src/driver.cc`
* Software Python driver (frontend) that handles the accelerator `<vta-hw-root>/apps/gemm/python/accel`
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
TOP = TestAccel
BUILD_NAME = build
USE_TRACE = 1
LIBNAME = libhw
vta_dir = $(abspath ../../../../)
tvm_dir = $(abspath ../../../../../../)
build_dir = $(abspath .)/$(BUILD_NAME)
verilator_build_dir = $(build_dir)/verilator
chisel_build_dir = $(build_dir)/chisel
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd
cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
# The following is to be consistent with cmake
ifeq ($(shell uname), Darwin)
lib_path = $(build_dir)/$(LIBNAME).dylib
else
lib_path = $(build_dir)/$(LIBNAME).so
endif
default: lib
lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(verilator_build_dir)/V$(TOP).cpp
$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v
verilator $(verilator_opt) $<
verilog: $(chisel_build_dir)/$(TOP).v
$(chisel_build_dir)/$(TOP).v: install_vta_package
sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)'
install_vta_package:
cd $(vta_dir)/hardware/chisel && sbt publishLocal
clean:
-rm -rf $(build_dir) target project/target project/project
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
name := "accel"
version := "0.1.0-SNAPSHOT"
organization := "edu.washington.cs"
def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// If we're building with Scala > 2.11, enable the compile option
// switch to support our anonymous Bundle definitions:
// https://github.com/scala/bug/issues/10047
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
case _ => Seq(
"-Xsource:2.11",
"-language:reflectiveCalls",
"-language:implicitConversions",
"-deprecation",
"-Xlint",
"-Ywarn-unused",
)
}
}
}
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// Scala 2.12 requires Java 8. We continue to generate
// Java 7 compatible code for Scala 2.11
// for compatibility with old clients.
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
Seq("-source", "1.7", "-target", "1.7")
case _ =>
Seq("-source", "1.8", "-target", "1.8")
}
}
}
scalaVersion := "2.11.12"
resolvers ++= Seq(
Resolver.sonatypeRepo("snapshots"),
Resolver.sonatypeRepo("releases"))
libraryDependencies ++= Seq(
"edu.berkeley.cs" %% "chisel3" % "3.1.7",
"edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT",
)
scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
javacOptions ++= javacOptionsVersion(scalaVersion.value)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
sbt.version = 1.3.2
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
logLevel := Level.Warn
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import vta.dpi._
/** Add-by-one accelerator.
*
* ___________ ___________
* | | | |
* | HostDPI | <--> | RegFile | <->|
* |_________| |_________| |
* |
* ___________ ___________ |
* | | | | |
* | MemDPI | <--> | Compute | <->|
* |_________| |_________|
*
*/
case class AccelConfig() {
val nCtrl = 1
val nECnt = 1
val nVals = 4
val nPtrs = 3
val regBits = 32
val ptrBits = 2*regBits
}
class Accel extends Module {
val io = IO(new Bundle {
val host = new VTAHostDPIClient
val mem = new VTAMemDPIMaster
})
implicit val config = AccelConfig()
val rf = Module(new RegFile)
val ce = Module(new Compute)
rf.io.host <> io.host
io.mem <> ce.io.mem
ce.io.launch := rf.io.launch
rf.io.finish := ce.io.finish
rf.io.ecnt <> ce.io.ecnt
ce.io.vals <> rf.io.vals
ce.io.ptrs <> rf.io.ptrs
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
import vta.core._
import vta.util.config._
import vta.shell._
class TestConfig extends Config(new CoreConfig ++ new PynqConfig)
/** Compute
*
* Bit Slice GEMM:
*
* 1. Wait for launch to be asserted
* 2. Issue 1 read request for 8-bit value at inp1_baddr address (read matrix)
* 3. Wait for the value
* 4. Increment read-address for next value
* 5. Repeat until all inp1 data have been read
* 6. Issue 1 read request for 8-bit value at inp2_baddr address (read vector)
* 7. Wait for the value
* 8. Increment read-address for next value
* 9. Repeat until all inp2 data have been read
* 10. Wait for output to be calculated
* 11. Issue a write request for 8-byte value at out_baddr address
* 12. Increment write-address for next value to write
* 13. Check if counter (cntout) is equal to length to asser finish,
otherwise go to step 11
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
val launch = Input(Bool())
val finish = Output(Bool())
val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W)))
val vals = Input(Vec(config.nVals, UInt(config.regBits.W)))
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
implicit val p: Parameters = new TestConfig
val sIdle :: sReadAReq :: sReadAData :: sReadADone ::sReadBReq :: sReadBData :: sReadBDone :: sInpDone ::sWait:: sWriteReq :: sWriteData :: sWriteDone :: Nil = Enum(12)
val state = RegInit(sIdle)
val shift = io.vals(0)
val length = io.vals(1)
val rstAccum = io.vals(2)
val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W))
val mvc = Module(new MatrixVectorMultiplication)
val reg1 = Reg(chiselTypeOf(mvc.io.wgt.data.bits))
val reg2 = Reg(chiselTypeOf(mvc.io.inp.data.bits))
val cntwgt = Reg(UInt(config.regBits.W))
val cntinp = Reg(UInt(config.regBits.W))
val cntout = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
val accum = Module(new Accmulator(size = p(CoreKey).blockOut, accBits = p(CoreKey).accBits))
switch (state) {
is (sIdle) {
when (io.launch) {
state := sReadAReq
}
}
// Read
is (sReadAReq) {
state := sReadAData
}
is (sReadAData) {
when (io.mem.rd.valid) {
state := sReadADone
}
}
is (sReadADone) {
when (cntwgt === (length * length) - 1.U) {
state := sReadBReq
} .otherwise {
state := sReadAReq
}
}
is (sReadBReq) {
state := sReadBData
}
is (sReadBData) {
when (io.mem.rd.valid) {
state := sReadBDone
}
}
is (sReadBDone) {
when (cntinp === length-1.U) {
state := sInpDone
} .otherwise {
state := sReadBReq
}
}
// Both input is processed
is (sInpDone) {
state := sWait
}
// Wait for computation
is (sWait) {
when (accum.io.ready) {
state := sWriteReq
}
}
// Write
is (sWriteReq) {
state := sWriteData
}
is (sWriteData) {
state := sWriteDone
}
is (sWriteDone) {
when (cntout === (length - 1.U)) {
state := sIdle
} .otherwise {
state := sWriteReq
}
}
}
val last = state === sWriteDone && cntout === (length - 1.U)
// cycle counter
when (state === sIdle) {
cycles := 0.U
} .otherwise {
cycles := cycles + 1.U
}
io.ecnt(0).valid := last
io.ecnt(0).bits := cycles
// calculate next address
when (state === sIdle) {
raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1)
waddr := io.ptrs(2)
} .elsewhen (state === sReadADone) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U
} .elsewhen (state === sReadBDone) { // increment input array by 1-byte
raddr2 := raddr2 + 1.U
} .elsewhen (state === sWriteDone) {
waddr := waddr + 4.U // writing 4 bytes
}
// create request
io.mem.req.valid := state === sReadAReq | state === sReadBReq | state === sWriteReq
io.mem.req.opcode := state === sWriteReq
io.mem.req.len := 0.U // one-word-per-request
io.mem.req.addr := Mux(state === sReadAReq | state === sReadBReq, Mux(state === sReadAReq, raddr1, raddr2), waddr)
// read
when (state === sReadAData && io.mem.rd.valid) {
reg1(cntwgt/length)(cntwgt%length) := io.mem.rd.bits(7, 0)
}
when (state === sReadBData && io.mem.rd.valid) {
reg2(0)(cntinp) := io.mem.rd.bits(7, 0)
}
io.mem.rd.ready := state === sReadAData | state === sReadBData
mvc.io.inp.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.valid := state === sInpDone // 2 inputs have been processed
mvc.io.wgt.data.bits <> reg1
mvc.io.inp.data.bits <> reg2
// Modify when shift operation is supported
mvc.io.reset := false.B
mvc.io.acc_i.data.valid := true.B
for (i <- 0 until p(CoreKey).blockOut) {
mvc.io.acc_i.data.bits(0)(i) := 0.U
}
accum.io.in := mvc.io.acc_o.data.bits
accum.io.shift := shift
accum.io.clear := rstAccum
accum.io.valid := mvc.io.acc_o.data.valid
// write
io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := accum.io.sum(cntout)
// count read/write
when (state === sIdle) {
cntwgt := 0.U
cntinp := 0.U
cntout := 0.U
} .elsewhen (state === sReadADone) {
cntwgt := cntwgt + 1.U
} .elsewhen (state === sReadBDone) {
cntinp := cntinp + 1.U
} .elsewhen (state === sWriteDone) {
cntout := cntout + 1.U
}
io.finish := last // data has been added
}
// Shift operation until supported in MVM
class Accmulator(size: Int = 16, accBits: Int = 32) extends Module {
val io = IO(new Bundle {
val clear = Input(Bool())
val valid = Input(Bool())
val ready = Output(Bool())
val in = Input(Vec(1, Vec(size, (UInt(accBits.W)))))
val shift = Input(UInt(8.W))
val sum = Output(Vec(size, (UInt(accBits.W))))
})
val reg = RegInit(VecInit(Seq.fill(size)(0.U(accBits.W))))
for (i <- 0 until size) {
when (io.clear) {
reg(i) := 0.U
} .elsewhen(io.valid) {
reg(i) := reg(i) + (io.in(0)(i) << io.shift)
}
}
io.ready := RegNext(io.valid)
io.sum := reg
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Register File.
*
* Six 32-bit register file.
*
* -------------------------------
* Register description | addr
* -------------------------|-----
* Control status register | 0x00
* Cycle counter | 0x04
* Shift value | 0x08
* Vector length | 0x0c
* Reset Accumulator | 0x10
* Input1 pointer | 0x18
* Input2 pointer | 0x20
* Output pointer | 0x28
* -------------------------------
* ------------------------------
* Control status register | bit
* ------------------------------
* Launch | 0
* Finish | 1
* ------------------------------
*/
class RegFile(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
val launch = Output(Bool())
val finish = Input(Bool())
val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W))))
val vals = Output(Vec(config.nVals, UInt(config.regBits.W)))
val ptrs = Output(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val host = new VTAHostDPIClient
})
val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.host.req.valid && !io.host.req.opcode) {
state := sRead
}
}
is (sRead) {
state := sIdle
}
}
io.host.req.deq := state === sIdle & io.host.req.valid
val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl
val vo = eo + config.nECnt
val po = vo + config.nVals
when (io.finish) {
reg(0) := "b_10".U
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(0).U === io.host.req.addr) {
reg(0) := io.host.req.value
}
for (i <- 0 until config.nECnt) {
when (io.ecnt(i).valid) {
reg(eo + i) := io.ecnt(i).bits
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
reg(eo + i) := io.host.req.value
}
}
for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
when (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value
}
}
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
}
io.host.resp.valid := state === sRead
io.host.resp.bits := rdata
io.launch := reg(0)(0)
for (i <- 0 until config.nVals) {
io.vals(i) := reg(vo + i)
}
for (i <- 0 until config.nPtrs) {
io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package test
import chisel3._
import chisel3.experimental.MultiIOModule
import vta.dpi._
import accel._
/** VTA simulation shell.
*
* Instantiate Host and Memory DPI modules.
*
*/
class VTASimShell extends MultiIOModule {
val host = IO(new VTAHostDPIMaster)
val mem = IO(new VTAMemDPIClient)
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val mod_sim = Module(new VTASimDPI)
val mod_host = Module(new VTAHostDPI)
val mod_mem = Module(new VTAMemDPI)
mod_mem.io.clock := clock
mod_mem.io.reset := reset
mod_mem.io.dpi <> mem
mod_host.io.clock := clock
mod_host.io.reset := reset
host <> mod_host.io.dpi
mod_sim.io.clock := sim_clock
mod_sim.io.reset := reset
sim_wait := mod_sim.io.dpi_wait
}
/** Test accelerator.
*
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
class TestAccel extends MultiIOModule {
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel)
sim_shell.sim_clock := sim_clock
sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_accel.io.mem
vta_accel.io.host <> sim_shell.host
}
/** Generate TestAccel as top module */
object Elaborate extends App {
chisel3.Driver.execute(args, () => new TestAccel)
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import tsim
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import ctypes
import os.path as osp
from sys import platform
def get_ext():
"""Return shared library extension"""
return ".dylib" if platform == "darwin" else ".so"
def load_dll(dll):
"""Load shared library
Parameters
------------
dll : str
Path for shared library
Returns
------------
The shared library
"""
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_sw():
"""Load all software shared libraries"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)
def init(hw_backend):
"""Init hardware and software shared library for accelerator
Parameters
------------
hw_backend : str
Hardware backend can be verilog or chisel
"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
load_sw()
m = tvm.runtime.load_module(hw_lib, "vta-tsim")
f = tvm.get_global_func("tvm.vta.tsim.init")
f(m)
def load_module():
"""Return driver function"""
load_sw()
return tvm.get_global_func("tvm.vta.driver")
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
#include "vmem/virtual_memory.h"
namespace vta {
namespace driver {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}
DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}
static DPILoader* Global() {
static DPILoader inst;
return &inst;
}
// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
class Device {
public:
Device() {
loader_ = DPILoader::Global();
}
uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles;
uint32_t length = inp2->shape[0];
// 1 matrix 1 vector input
size_t size1 = (inp1->dtype.bits >> 3) * length * length;
size_t size2 = (inp2->dtype.bits >> 3) * length;
// 1 vector output
size_t size3 = (32 >> 3) * length;
inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3);
this->MemCopyFromHost(inp1_, inp1->data, size1);
this->MemCopyFromHost(inp2_, inp2->data, size2);
this->Init();
this->Launch(length, shiftVal, reset);
cycles = this->WaitForCompletion();
this->MemCopyToHost(out->data, out_, size3);
this->MemFree(inp1_);
this->MemFree(inp2_);
this->MemFree(out_);
return cycles;
}
private:
void Init() {
dpi_ = loader_->Get();
dpi_->SimResume();
}
void* MemAlloc(size_t size) {
void * addr = vta::vmem::VirtualMemoryManager::Global()->Alloc(size);
return reinterpret_cast<void*>(vta::vmem::VirtualMemoryManager::Global()->GetPhyAddr(addr));
}
void MemFree(void* buf) {
void * addr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(reinterpret_cast<uint64_t>(buf));
vta::vmem::VirtualMemoryManager::Global()->Free(addr);
}
vta_phy_addr_t MemGetPhyAddr(void* buf) {
return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
}
void MemCopyFromHost(void* dst, const void* src, size_t size) {
vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size);
}
void MemCopyToHost(void* dst, const void* src, size_t size) {
vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size);
}
void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal);
dpi_->WriteReg(0x0c, length); // tensor size
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x0);
if (reset == 1) {
dpi_->WriteReg(0x10, 0x1); // reset accumulator
dpi_->WriteReg(0x10, 0x0);
}
}
uint32_t WaitForCompletion() {
uint32_t i, val;
for (i = 0; i < wait_cycles_; i++) {
val = dpi_->ReadReg(0x00);
if (val == 2) break; // finish
}
val = dpi_->ReadReg(0x04);
dpi_->SimWait();
return val;
}
// wait cycles
uint32_t wait_cycles_{100000000};
// DPI loader
DPILoader* loader_{nullptr};
// DPI Module
DPIModuleNode* dpi_{nullptr};
// input vm ptr
void* inp1_{nullptr};
void* inp2_{nullptr};
// output vm ptr
void* out_{nullptr};
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[3];
Device dev_;
uint32_t cycles = dev_.Run(A, B, static_cast<int>(args[2]), C, static_cast<int>(args[4]));
*rv = static_cast<int>(cycles);
});
} // namespace driver
} // namespace vta
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import tsim
import sys
""" Vector Bit Slice and Pack Function
Parameters
----------
A : Vector to be sliced and packed
slice_width : slice width
Returns
---------
C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A
"""
def slice(A, slice_width):
assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
dtype = type(A[0])
row = 0
# currently only supports uint
if dtype is np.uint8: row = 8 // slice_width
elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width
else: raise ValueError("datatype currently not supported")
if (row >= 8):
dtype = 'uint' + str(row)
else:
dtype = 'uint8'
C = np.zeros((row, len(A))).astype(dtype) # sliced and transform
# create mask
slice_mask = 2**(slice_width)-1
# slice and pack
for x in range(len(A)):
for y in range(row):
C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
return C
def slice_mat(A, slice_width):
assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
dtype = type(A[0][0])
row = 0
# currently only supports uint
if dtype is np.uint8: row = 8 // slice_width
elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width
else: raise ValueError("datatype currently not supported")
if (row >= 8):
dtype = 'uint' + str(row)
else:
dtype = 'uint8'
# 3d array (bits, row, clmn)
C = np.zeros((row, A.shape[0], A.shape[1])).astype(dtype) # sliced and transform
# create mask
slice_mask = 2**(slice_width)-1
# slice and pack
for z in range(A.shape[0]):
C[:, z, :] = slice(A[z], slice_width)
return C
""" Matrix Multiplication Function
Parameters
----------
A : Matrix A
B: Matrix B
i_width : weight slice width
w_width : activation slice width
Returns
---------
C: result of A * B
"""
# A is a n*m matrix, B is a m*p matrix(not transposed yet)
def matrix_multiply(A, B, i_width, w_width):
assert A.shape[1] == B.shape[0], "can't perform multiplication"
BT = B.transpose()
cycles = 0
B_sliced = slice_mat(BT, w_width)
C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
for i in range(A.shape[0]):
A_sliced = slice(A[i], i_width)
test = test_accel(A_sliced, B_sliced, i_width, w_width)
C[i] = test[0]
cycles += test[1]
np.testing.assert_array_equal(C[i], compute(A_sliced, B_sliced, i_width, w_width))
print("PASS row " + str(i))
np.testing.assert_array_equal(C, np.matmul(A.astype('uint64'),B))
print("result: ")
print(C)
print("TEST PASSED, cycles: " + str(cycles))
return C
""" Software Verification Function
Parameter Dimesions
---------
A (bits, y) and B (bits, y, x) (transposed)
Takes 1 vector and 1 matrix input (sliced and packed)
Returns
---------
Resulting vector
"""
def compute(A, B, i_width, w_width):
assert A.shape[1] == B.shape[1], "sliced shape not match"
# reset hardware accumulator
accum = np.zeros(A.shape[1])
for x in range(A.shape[0]):
for y in range(B.shape[0]):
accum += np.matmul(A[x].astype('uint64'), B[y].transpose()) << np.uint64(x*i_width + y*w_width)
# get value from accumulator
return accum
"""Testing Function for Matrix Vector Multiplication"""
def test_accel(A, B, i_width, w_width):
assert A.shape[1] == B.shape[2], "sliced shape not match"
dtype = A.dtype
ctx = tvm.cpu(0)
f = tsim.load_module()
a_arr = []
b_arr = []
for i in range(A.shape[0]):
list_a = np.zeros(A.shape[1]).astype(dtype)
for j in range(A.shape[1]):
list_a[j] = A[i][j]
a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
for i in range(B.shape[0]):
# transpose
list_b = np.zeros((B.shape[2], B.shape[1])).astype(dtype)
for j in range(B.shape[2]):
for k in range(B.shape[1]):
list_b[j][k] = B[i][j][k]
b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
cycles = 0
accum = tvm.nd.array(np.zeros(A.shape[1]).astype("uint32"), ctx)
for i in range(len(a_arr)):
for j in range(len(b_arr)):
shift = np.uint8(i*i_width + j*w_width)
if i == 0 and j == 0:
cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(1)) # reset accumulator
else:
cycles += f(b_arr[j], a_arr[i], shift, accum, np.uint32(0)) # no reset
return (accum.asnumpy(), cycles)
""" Matrix Generator
Parameters
----------
dtype : String, datatype generated (supports only uint)
i_width : weight bit slices(needs to be less than actual bit width)
w_width : activation bit slices(needs to be less than actual bit width)
"""
def top_test(dtype, i_width, w_width):
# only supports positive values (up to 2**(bits-1))
rmax = 127
# (m,16) * (16,16) GEMM
rrow = np.random.randint(7) + 1
clmn = 16
A = np.random.randint(rmax, size=(rrow,clmn)).astype(dtype)
B = np.random.randint(rmax, size=(clmn,clmn)).astype(dtype)
print("A: " + str(A))
print("B: " + str(B))
# perform GEMM
matrix_multiply(A, B, i_width, w_width)
if __name__ == "__main__":
tsim.init("chisel")
for i in range(1):
# reg1 and reg2 bits in hardware/chisel/src/main/Compute.scala must be modified for slices greater than 8 bits
if sys.argv[1] == 'serial':
# generates a random uint8 GEMM with 2-bit(8/4) input and 4-bit(8/2) weight
top_test("uint8", 4, 2)
elif sys.argv[1] == 'parallel':
# generates a random uint8 GEMM with 8-bit input and 8-bit weight (bit parallel)
top_test('uint8', 8, 8)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tsim C CXX)
if(NOT DEFINED ENV{TVM_PATH})
message(ERROR "Make sure to set TVM_PATH in your environment")
endif()
if(NOT DEFINED ENV{VTA_HW_PATH})
message(ERROR "Make sure to set VTA_HW_PATH in your environment")
endif()
include_directories("$ENV{TVM_PATH}/include")
include_directories("$ENV{TVM_PATH}/3rdparty/dlpack/include")
include_directories("$ENV{TVM_PATH}/3rdparty/dmlc-core/include")
include_directories("$ENV{VTA_HW_PATH}/src/dpi")
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11")
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
endif()
file(GLOB TSIM_SW_SRC src/driver.cc)
list(APPEND TSIM_SW_SRC $ENV{VTA_HW_PATH}/src/vmem/virtual_memory.cc)
list(APPEND TSIM_SW_SRC $ENV{VTA_HW_PATH}/src/dpi/module.cc)
add_library(sw SHARED ${TSIM_SW_SRC})
target_include_directories(sw PRIVATE $ENV{VTA_HW_PATH}/include $ENV{VTA_HW_PATH}/src)
if(APPLE)
set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH)
BUILD_NAME = build
build_dir = $(abspath .)/$(BUILD_NAME)
default: run_verilog
run_verilog: verilog driver
python3 tests/python/verilog_accel.py
run_chisel: chisel driver
python3 tests/python/chisel_accel.py
driver: | $(build_dir)
cd $(build_dir) && cmake .. && make
$(build_dir):
mkdir -p $@
verilog:
make -C hardware/verilog
chisel:
make -C hardware/chisel
clean:
-rm -rf $(build_dir)
make -C hardware/chisel clean
make -C hardware/verilog clean
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
VTA TSIM Installation
======================
*TSIM* is a cycle-accurate hardware simulation environment that can be invoked and managed directly from TVM. It aims to enable cycle accurate simulation of deep learning accelerators including VTA.
This simulation environment can be used in both OSX and Linux.
There are two dependencies required to make *TSIM* works: [Verilator](https://www.veripool.org/wiki/verilator) and [sbt](https://www.scala-sbt.org/) for accelerators designed in [Chisel3](https://github.com/freechipsproject/chisel3).
## OSX Dependencies
Install `sbt` and `verilator` using [Homebrew](https://brew.sh/).
```bash
brew install verilator sbt
```
## Linux Dependencies
Add `sbt` to package manager (Ubuntu).
```bash
echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2EE0EA64E40A89B84B2DF73499E82A75642AC823
sudo apt-get update
```
Install `sbt` and `verilator`.
```bash
sudo apt install verilator sbt
```
Verilator version check
```bash
verilator --version
```
the supported version of Verilator should be at least 4.012,
if homebrew (OSX) or package-manager (Linux) does not support that version,
please install Verilator 4.012 or later from binary or source base on following
instruction of Verilator wiki.
https://www.veripool.org/projects/verilator/wiki/Installing
## Setup in TVM
1. Install `verilator` and `sbt` as described above
2. Get tvm `git clone https://github.com/apache/incubator-tvm.git tvm --recursive`
3. Build [tvm](https://docs.tvm.ai/install/from_source.html#build-the-shared-library)
## How to run VTA TSIM examples
There are two sample VTA accelerators, add-a-constant, designed in Chisel3 and Verilog to show how *TSIM* works.
The default target language for these two implementations is Verilog. The following instructions show
how to run both of them:
* Test Verilog backend
* Go to `<vta-hw-root>/apps/tsim_example`
* Run `make`
* Test Chisel3 backend
* Go to `<vta-hw-root>/apps/tsim_example`
* Run `make run_chisel`
* Some pointers
* Verilog and Chisel3 tests in `<vta-hw-root>/apps/tsim_example/tests/python`
* Verilog accelerator backend `<vta-hw-root>/apps/tsim_example/hardware/verilog`
* Chisel3 accelerator backend `<vta-hw-root>/apps/tsim_example/hardware/chisel`
* Software C++ driver (backend) that handles the accelerator `<vta-hw-root>/apps/tsim_example/src/driver.cc`
* Software Python driver (frontend) that handles the accelerator `<vta-hw-root>/apps/tsim_example/python/accel`
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
TOP = TestAccel
BUILD_NAME = build
USE_TRACE = 0
LIBNAME = libhw
vta_dir = $(abspath ../../../../)
tvm_dir = $(abspath ../../../../../../)
build_dir = $(abspath .)/$(BUILD_NAME)
verilator_build_dir = $(build_dir)/verilator
chisel_build_dir = $(build_dir)/chisel
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd
cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
# The following is to be consistent with cmake
ifeq ($(shell uname), Darwin)
lib_path = $(build_dir)/$(LIBNAME).dylib
else
lib_path = $(build_dir)/$(LIBNAME).so
endif
default: lint lib
lint:
cp $(vta_dir)/hardware/chisel/scalastyle-config.xml .
sbt scalastyle
lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(verilator_build_dir)/V$(TOP).cpp
$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v
verilator $(verilator_opt) $<
verilog: $(chisel_build_dir)/$(TOP).v
$(chisel_build_dir)/$(TOP).v: install_vta_package
sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)'
install_vta_package:
cd $(vta_dir)/hardware/chisel && sbt publishLocal
clean:
-rm -rf $(build_dir) target project/target project/project
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
name := "accel"
version := "0.1.0-SNAPSHOT"
organization := "edu.washington.cs"
def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// If we're building with Scala > 2.11, enable the compile option
// switch to support our anonymous Bundle definitions:
// https://github.com/scala/bug/issues/10047
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
case _ => Seq(
"-Xsource:2.11",
"-language:reflectiveCalls",
"-language:implicitConversions",
"-deprecation",
"-Xlint",
"-Ywarn-unused",
)
}
}
}
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// Scala 2.12 requires Java 8. We continue to generate
// Java 7 compatible code for Scala 2.11
// for compatibility with old clients.
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
Seq("-source", "1.7", "-target", "1.7")
case _ =>
Seq("-source", "1.8", "-target", "1.8")
}
}
}
scalaVersion := "2.11.12"
resolvers ++= Seq(
Resolver.sonatypeRepo("snapshots"),
Resolver.sonatypeRepo("releases"))
libraryDependencies ++= Seq(
"edu.berkeley.cs" %% "chisel3" % "3.1.7",
"edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT",
)
scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
javacOptions ++= javacOptionsVersion(scalaVersion.value)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
sbt.version = 1.3.2
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
logLevel := Level.Warn
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import vta.dpi._
/** Add-by-one accelerator.
*
* ___________ ___________
* | | | |
* | HostDPI | <--> | RegFile | <->|
* |_________| |_________| |
* |
* ___________ ___________ |
* | | | | |
* | MemDPI | <--> | Compute | <->|
* |_________| |_________|
*
*/
case class AccelConfig() {
val nCtrl = 1
val nECnt = 1
val nVals = 2
val nPtrs = 2
val regBits = 32
val ptrBits = 2 * regBits
}
class Accel extends Module {
val io = IO(new Bundle {
val host = new VTAHostDPIClient
val mem = new VTAMemDPIMaster
})
implicit val config = AccelConfig()
val rf = Module(new RegFile)
val ce = Module(new Compute)
rf.io.host <> io.host
io.mem <> ce.io.mem
ce.io.launch := rf.io.launch
rf.io.finish := ce.io.finish
rf.io.ecnt <> ce.io.ecnt
ce.io.vals <> rf.io.vals
ce.io.ptrs <> rf.io.ptrs
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Compute
*
* Add-by-one procedure:
*
* 1. Wait for launch to be asserted
* 2. Issue a read request for 8-byte value at inp_baddr address
* 3. Wait for the value
* 4. Issue a write request for 8-byte value at out_baddr address
* 5. Increment read-address and write-address for next value
* 6. Check if counter (cnt) is equal to length to assert finish,
* otherwise go to step 2.
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
val launch = Input(Bool())
val finish = Output(Bool())
val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W)))
val vals = Input(Vec(config.nVals, UInt(config.regBits.W)))
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
val sIdle :: sReadReq :: sReadData :: sWriteReq :: sWriteData :: Nil = Enum(5)
val state = RegInit(sIdle)
val const = io.vals(0)
val length = io.vals(1)
val cycles = RegInit(0.U(config.regBits.W))
val reg = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(UInt(config.regBits.W))
val raddr = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
switch(state) {
is(sIdle) {
when(io.launch) {
state := sReadReq
}
}
is(sReadReq) {
state := sReadData
}
is(sReadData) {
when(io.mem.rd.valid) {
state := sWriteReq
}
}
is(sWriteReq) {
state := sWriteData
}
is(sWriteData) {
when(cnt === (length - 1.U)) {
state := sIdle
}.otherwise {
state := sReadReq
}
}
}
val last = state === sWriteData && cnt === (length - 1.U)
// cycle counter
when(state === sIdle) {
cycles := 0.U
}.otherwise {
cycles := cycles + 1.U
}
io.ecnt(0).valid := last
io.ecnt(0).bits := cycles
// calculate next address
when(state === sIdle) {
raddr := io.ptrs(0)
waddr := io.ptrs(1)
}.elsewhen(state === sWriteData) { // increment by 8-bytes
raddr := raddr + 8.U
waddr := waddr + 8.U
}
// create request
io.mem.req.valid := state === sReadReq | state === sWriteReq
io.mem.req.opcode := state === sWriteReq
io.mem.req.len := 0.U // one-word-per-request
io.mem.req.addr := Mux(state === sReadReq, raddr, waddr)
// read
when(state === sReadData && io.mem.rd.valid) {
reg := io.mem.rd.bits + const
}
io.mem.rd.ready := state === sReadData
// write
io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := reg
// count read/write
when(state === sIdle) {
cnt := 0.U
}.elsewhen(state === sWriteData) {
cnt := cnt + 1.U
}
// done when read/write are equal to length
io.finish := last
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Register File.
*
* Six 32-bit register file.
*
* -------------------------------
* Register description | addr
* -------------------------|-----
* Control status register | 0x00
* Cycle counter | 0x04
* Constant value | 0x08
* Vector length | 0x0c
* Input pointer lsb | 0x10
* Input pointer msb | 0x14
* Output pointer lsb | 0x18
* Output pointer msb | 0x1c
* -------------------------------
*
* ------------------------------
* Control status register | bit
* ------------------------------
* Launch | 0
* Finish | 1
* ------------------------------
*/
class RegFile(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
val launch = Output(Bool())
val finish = Input(Bool())
val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W))))
val vals = Output(Vec(config.nVals, UInt(config.regBits.W)))
val ptrs = Output(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val host = new VTAHostDPIClient
})
val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle)
switch(state) {
is(sIdle) {
when(io.host.req.valid && !io.host.req.opcode) {
state := sRead
}
}
is(sRead) {
state := sIdle
}
}
io.host.req.deq := state === sIdle & io.host.req.valid
val nTotal = config.nCtrl + config.nECnt + config.nVals + (2 * config.nPtrs)
val reg =
Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl
val vo = eo + config.nECnt
val po = vo + config.nVals
when(io.finish) {
reg(0) := "b_10".U
}.elsewhen(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(0).U === io.host.req.addr) {
reg(0) := io.host.req.value
}
for (i <- 0 until config.nECnt) {
when(io.ecnt(i).valid) {
reg(eo + i) := io.ecnt(i).bits
}.elsewhen(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
reg(eo + i) := io.host.req.value
}
}
for (i <- 0 until (config.nVals + (2 * config.nPtrs))) {
when(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value
}
}
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
when(state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
}
io.host.resp.valid := state === sRead
io.host.resp.bits := rdata
io.launch := reg(0)(0)
for (i <- 0 until config.nVals) {
io.vals(i) := reg(vo + i)
}
for (i <- 0 until config.nPtrs) {
io.ptrs(i) := Cat(reg(po + (2 * i) + 1), reg(po + (2 * i)))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package test
import chisel3._
import chisel3.experimental.MultiIOModule
import vta.dpi._
import accel._
/** VTA simulation shell.
*
* Instantiate Host and Memory DPI modules.
*
*/
class VTASimShell extends MultiIOModule {
val host = IO(new VTAHostDPIMaster)
val mem = IO(new VTAMemDPIClient)
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val mod_sim = Module(new VTASimDPI)
val mod_host = Module(new VTAHostDPI)
val mod_mem = Module(new VTAMemDPI)
mod_mem.io.clock := clock
mod_mem.io.reset := reset
mod_mem.io.dpi <> mem
mod_host.io.clock := clock
mod_host.io.reset := reset
host <> mod_host.io.dpi
mod_sim.io.clock := sim_clock
mod_sim.io.reset := reset
sim_wait := mod_sim.io.dpi_wait
}
/** Test accelerator.
*
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
class TestAccel extends MultiIOModule {
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel)
sim_shell.sim_clock := sim_clock
sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_accel.io.mem
vta_accel.io.host <> sim_shell.host
}
/** Generate TestAccel as top module */
object Elaborate extends App {
chisel3.Driver.execute(args, () => new TestAccel)
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
TOP = TestAccel
BUILD_NAME = build
USE_TRACE = 0
LIBNAME = libhw
vta_dir = $(abspath ../../../../)
tvm_dir = $(abspath ../../../../../../)
build_dir = $(abspath .)/$(BUILD_NAME)
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP}
verilator_opt += -Mdir ${build_dir}
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP).h
cxx_flags += -I$(build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
cxx_files += $(wildcard $(build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
v_files = $(wildcard $(abspath .)/src/*.v $(vta_dir)/hardware/chisel/src/main/resources/verilog/*.v)
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(build_dir)/$(TOP).vcd
cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
# The following is to be consistent with cmake
ifeq ($(shell uname), Darwin)
lib_path = $(build_dir)/$(LIBNAME).dylib
else
lib_path = $(build_dir)/$(LIBNAME).so
endif
default: lib
lib: $(lib_path)
$(lib_path): $(build_dir)/V$(TOP).cpp
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(build_dir)/V$(TOP).cpp
$(build_dir)/V$(TOP).cpp: $(v_files) | $(build_dir)
verilator $(verilator_opt) $(v_files)
$(build_dir):
mkdir -p $@
clean:
-rm -rf $(build_dir)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Add-by-one accelerator.
*
* ___________ ___________
* | | | |
* | HostDPI | <--> | RegFile | <->|
* |_________| |_________| |
* |
* ___________ ___________ |
* | | | | |
* | MemDPI | <--> | Compute | <->|
* |_________| |_________|
*
*/
module Accel #
( parameter HOST_ADDR_BITS = 8,
parameter HOST_DATA_BITS = 32,
parameter MEM_LEN_BITS = 8,
parameter MEM_ADDR_BITS = 64,
parameter MEM_DATA_BITS = 64
)
(
input clock,
input reset,
input host_req_valid,
input host_req_opcode,
input [HOST_ADDR_BITS-1:0] host_req_addr,
input [HOST_DATA_BITS-1:0] host_req_value,
output host_req_deq,
output host_resp_valid,
output [HOST_DATA_BITS-1:0] host_resp_bits,
output mem_req_valid,
output mem_req_opcode,
output [MEM_LEN_BITS-1:0] mem_req_len,
output [MEM_ADDR_BITS-1:0] mem_req_addr,
output mem_wr_valid,
output [MEM_DATA_BITS-1:0] mem_wr_bits,
input mem_rd_valid,
input [MEM_DATA_BITS-1:0] mem_rd_bits,
output mem_rd_ready
);
logic launch;
logic finish;
logic event_counter_valid;
logic [HOST_DATA_BITS-1:0] event_counter_value;
logic [HOST_DATA_BITS-1:0] constant;
logic [HOST_DATA_BITS-1:0] length;
logic [MEM_ADDR_BITS-1:0] inp_baddr;
logic [MEM_ADDR_BITS-1:0] out_baddr;
RegFile #
(
.MEM_ADDR_BITS(MEM_ADDR_BITS),
.HOST_ADDR_BITS(HOST_ADDR_BITS),
.HOST_DATA_BITS(HOST_DATA_BITS)
)
rf
(
.clock (clock),
.reset (reset),
.host_req_valid (host_req_valid),
.host_req_opcode (host_req_opcode),
.host_req_addr (host_req_addr),
.host_req_value (host_req_value),
.host_req_deq (host_req_deq),
.host_resp_valid (host_resp_valid),
.host_resp_bits (host_resp_bits),
.launch (launch),
.finish (finish),
.event_counter_valid (event_counter_valid),
.event_counter_value (event_counter_value),
.constant (constant),
.length (length),
.inp_baddr (inp_baddr),
.out_baddr (out_baddr)
);
Compute #
(
.MEM_LEN_BITS(MEM_LEN_BITS),
.MEM_ADDR_BITS(MEM_ADDR_BITS),
.MEM_DATA_BITS(MEM_DATA_BITS),
.HOST_DATA_BITS(HOST_DATA_BITS)
)
comp
(
.clock (clock),
.reset (reset),
.mem_req_valid (mem_req_valid),
.mem_req_opcode (mem_req_opcode),
.mem_req_len (mem_req_len),
.mem_req_addr (mem_req_addr),
.mem_wr_valid (mem_wr_valid),
.mem_wr_bits (mem_wr_bits),
.mem_rd_valid (mem_rd_valid),
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready),
.launch (launch),
.finish (finish),
.event_counter_valid (event_counter_valid),
.event_counter_value (event_counter_value),
.constant (constant),
.length (length),
.inp_baddr (inp_baddr),
.out_baddr (out_baddr)
);
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Compute
*
* Add-by-one procedure:
*
* 1. Wait for launch to be asserted
* 2. Issue a read request for 8-byte value at inp_baddr address
* 3. Wait for the value
* 4. Issue a write request for 8-byte value at out_baddr address
* 5. Increment read-address and write-address for next value
* 6. Check if counter (cnt) is equal to length to assert finish,
* otherwise go to step 2.
*/
module Compute #
(
parameter MEM_LEN_BITS = 8,
parameter MEM_ADDR_BITS = 64,
parameter MEM_DATA_BITS = 64,
parameter HOST_DATA_BITS = 32
)
(
input clock,
input reset,
output mem_req_valid,
output mem_req_opcode,
output [MEM_LEN_BITS-1:0] mem_req_len,
output [MEM_ADDR_BITS-1:0] mem_req_addr,
output mem_wr_valid,
output [MEM_DATA_BITS-1:0] mem_wr_bits,
input mem_rd_valid,
input [MEM_DATA_BITS-1:0] mem_rd_bits,
output mem_rd_ready,
input launch,
output finish,
output event_counter_valid,
output [HOST_DATA_BITS-1:0] event_counter_value,
input [HOST_DATA_BITS-1:0] constant,
input [HOST_DATA_BITS-1:0] length,
input [MEM_ADDR_BITS-1:0] inp_baddr,
input [MEM_ADDR_BITS-1:0] out_baddr
);
typedef enum logic [2:0] {IDLE,
READ_REQ,
READ_DATA,
WRITE_REQ,
WRITE_DATA} state_t;
state_t state_n, state_r;
logic [31:0] cnt;
logic [MEM_DATA_BITS-1:0] data;
logic [MEM_ADDR_BITS-1:0] raddr;
logic [MEM_ADDR_BITS-1:0] waddr;
always_ff @(posedge clock) begin
if (reset) begin
state_r <= IDLE;
end else begin
state_r <= state_n;
end
end
always_comb begin
state_n = IDLE;
case (state_r)
IDLE: begin
if (launch) begin
state_n = READ_REQ;
end
end
READ_REQ: begin
state_n = READ_DATA;
end
READ_DATA: begin
if (mem_rd_valid) begin
state_n = WRITE_REQ;
end else begin
state_n = READ_DATA;
end
end
WRITE_REQ: begin
state_n = WRITE_DATA;
end
WRITE_DATA: begin
if (cnt == (length - 1'b1)) begin
state_n = IDLE;
end else begin
state_n = READ_REQ;
end
end
default: begin
end
endcase
end
logic last;
assign last = (state_r == WRITE_DATA) & (cnt == (length - 1'b1));
// cycle counter
logic [HOST_DATA_BITS-1:0] cycle_counter;
always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin
cycle_counter <= '0;
end else begin
cycle_counter <= cycle_counter + 1'b1;
end
end
assign event_counter_valid = last;
assign event_counter_value = cycle_counter;
// calculate next address
always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin
raddr <= inp_baddr;
waddr <= out_baddr;
end else if (state_r == WRITE_DATA) begin
raddr <= raddr + 'd8;
waddr <= waddr + 'd8;
end
end
// create request
assign mem_req_valid = (state_r == READ_REQ) | (state_r == WRITE_REQ);
assign mem_req_opcode = state_r == WRITE_REQ;
assign mem_req_len = 'd0; // one-word-per-request
assign mem_req_addr = (state_r == READ_REQ)? raddr : waddr;
// read
always_ff @(posedge clock) begin
if ((state_r == READ_DATA) & mem_rd_valid) begin
data <= mem_rd_bits + {32'd0, constant};
end
end
assign mem_rd_ready = state_r == READ_DATA;
// write
assign mem_wr_valid = state_r == WRITE_DATA;
assign mem_wr_bits = data;
// count read/write
always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin
cnt <= 'd0;
end else if (state_r == WRITE_DATA) begin
cnt <= cnt + 1'b1;
end
end
// done when read/write are equal to length
assign finish = last;
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Register File.
*
* Six 32-bit register file.
*
* -------------------------------
* Register description | addr
* -------------------------|-----
* Control status register | 0x00
* Cycle counter | 0x04
* Constant value | 0x08
* Vector length | 0x0c
* Input pointer lsb | 0x10
* Input pointer msb | 0x14
* Output pointer lsb | 0x18
* Output pointer msb | 0x1c
* -------------------------------
* ------------------------------
* Control status register | bit
* ------------------------------
* Launch | 0
* Finish | 1
* ------------------------------
*/
module RegFile #
(parameter MEM_ADDR_BITS = 64,
parameter HOST_ADDR_BITS = 8,
parameter HOST_DATA_BITS = 32
)
(
input clock,
input reset,
input host_req_valid,
input host_req_opcode,
input [HOST_ADDR_BITS-1:0] host_req_addr,
input [HOST_DATA_BITS-1:0] host_req_value,
output host_req_deq,
output host_resp_valid,
output [HOST_DATA_BITS-1:0] host_resp_bits,
output launch,
input finish,
input event_counter_valid,
input [HOST_DATA_BITS-1:0] event_counter_value,
output [HOST_DATA_BITS-1:0] constant,
output [HOST_DATA_BITS-1:0] length,
output [MEM_ADDR_BITS-1:0] inp_baddr,
output [MEM_ADDR_BITS-1:0] out_baddr
);
localparam NUM_REG = 8;
typedef enum logic {IDLE, READ} state_t;
state_t state_n, state_r;
always_ff @(posedge clock) begin
if (reset) begin
state_r <= IDLE;
end else begin
state_r <= state_n;
end
end
always_comb begin
state_n = IDLE;
case (state_r)
IDLE: begin
if (host_req_valid & ~host_req_opcode) begin
state_n = READ;
end
end
READ: begin
state_n = IDLE;
end
endcase
end
assign host_req_deq = (state_r == IDLE) ? host_req_valid : 1'b0;
logic [HOST_DATA_BITS-1:0] rf [NUM_REG-1:0];
genvar i;
for (i = 0; i < NUM_REG; i++) begin
logic wen = (state_r == IDLE)? host_req_valid & host_req_opcode & i*4 == host_req_addr : 1'b0;
if (i == 0) begin
always_ff @(posedge clock) begin
if (reset) begin
rf[i] <= 'd0;
end else if (finish) begin
rf[i] <= 'd2;
end else if (wen) begin
rf[i] <= host_req_value;
end
end
end else if (i == 1) begin
always_ff @(posedge clock) begin
if (reset) begin
rf[i] <= 'd0;
end else if (event_counter_valid) begin
rf[i] <= event_counter_value;
end else if (wen) begin
rf[i] <= host_req_value;
end
end
end else begin
always_ff @(posedge clock) begin
if (reset) begin
rf[i] <= 'd0;
end else if (wen) begin
rf[i] <= host_req_value;
end
end
end
end
logic [HOST_DATA_BITS-1:0] rdata;
always_ff @(posedge clock) begin
if (reset) begin
rdata <= 'd0;
end else if ((state_r == IDLE) & host_req_valid & ~host_req_opcode) begin
if (host_req_addr == 'h00) begin
rdata <= rf[0];
end else if (host_req_addr == 'h04) begin
rdata <= rf[1];
end else if (host_req_addr == 'h08) begin
rdata <= rf[2];
end else if (host_req_addr == 'h0c) begin
rdata <= rf[3];
end else if (host_req_addr == 'h10) begin
rdata <= rf[4];
end else if (host_req_addr == 'h14) begin
rdata <= rf[5];
end else if (host_req_addr == 'h18) begin
rdata <= rf[6];
end else if (host_req_addr == 'h1c) begin
rdata <= rf[7];
end else begin
rdata <= 'd0;
end
end
end
assign host_resp_valid = (state_r == READ);
assign host_resp_bits = rdata;
assign launch = rf[0][0];
assign constant = rf[2];
assign length = rf[3];
assign inp_baddr = {rf[5], rf[4]};
assign out_baddr = {rf[7], rf[6]};
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Test accelerator.
*
* Instantiate host/memory DPI modules and connect them to the accelerator.
*
*/
module TestAccel
(
input clock,
input reset,
input sim_clock,
output sim_wait
);
localparam HOST_ADDR_BITS = 8;
localparam HOST_DATA_BITS = 32;
logic host_req_valid;
logic host_req_opcode;
logic [HOST_ADDR_BITS-1:0] host_req_addr;
logic [HOST_DATA_BITS-1:0] host_req_value;
logic host_req_deq;
logic host_resp_valid;
logic [HOST_DATA_BITS-1:0] host_resp_bits;
localparam MEM_LEN_BITS = 8;
localparam MEM_ADDR_BITS = 64;
localparam MEM_DATA_BITS = 64;
logic mem_req_valid;
logic mem_req_opcode;
logic [MEM_LEN_BITS-1:0] mem_req_len;
logic [MEM_ADDR_BITS-1:0] mem_req_addr;
logic mem_wr_valid;
logic [MEM_DATA_BITS-1:0] mem_wr_bits;
logic mem_rd_valid;
logic [MEM_DATA_BITS-1:0] mem_rd_bits;
logic mem_rd_ready;
VTASimDPI sim
(
.clock (sim_clock),
.reset (reset),
.dpi_wait (sim_wait)
);
VTAHostDPI host
(
.clock (clock),
.reset (reset),
.dpi_req_valid (host_req_valid),
.dpi_req_opcode (host_req_opcode),
.dpi_req_addr (host_req_addr),
.dpi_req_value (host_req_value),
.dpi_req_deq (host_req_deq),
.dpi_resp_valid (host_resp_valid),
.dpi_resp_bits (host_resp_bits)
);
VTAMemDPI mem
(
.clock (clock),
.reset (reset),
.dpi_req_valid (mem_req_valid),
.dpi_req_opcode (mem_req_opcode),
.dpi_req_len (mem_req_len),
.dpi_req_addr (mem_req_addr),
.dpi_wr_valid (mem_wr_valid),
.dpi_wr_bits (mem_wr_bits),
.dpi_rd_valid (mem_rd_valid),
.dpi_rd_bits (mem_rd_bits),
.dpi_rd_ready (mem_rd_ready)
);
Accel #
(
.HOST_ADDR_BITS(HOST_ADDR_BITS),
.HOST_DATA_BITS(HOST_DATA_BITS),
.MEM_LEN_BITS(MEM_LEN_BITS),
.MEM_ADDR_BITS(MEM_ADDR_BITS),
.MEM_DATA_BITS(MEM_DATA_BITS)
)
accel
(
.clock (clock),
.reset (reset),
.host_req_valid (host_req_valid),
.host_req_opcode (host_req_opcode),
.host_req_addr (host_req_addr),
.host_req_value (host_req_value),
.host_req_deq (host_req_deq),
.host_resp_valid (host_resp_valid),
.host_resp_bits (host_resp_bits),
.mem_req_valid (mem_req_valid),
.mem_req_opcode (mem_req_opcode),
.mem_req_len (mem_req_len),
.mem_req_addr (mem_req_addr),
.mem_wr_valid (mem_wr_valid),
.mem_wr_bits (mem_wr_bits),
.mem_rd_valid (mem_rd_valid),
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready)
);
endmodule
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from . import tsim
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import ctypes
import os.path as osp
from sys import platform
def get_ext():
"""Return shared library extension"""
return ".dylib" if platform == "darwin" else ".so"
def load_dll(dll):
"""Load shared library
Parameters
------------
dll : str
Path for shared library
Returns
------------
The shared library
"""
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_sw():
"""Load all software shared libraries"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)
def init(hw_backend):
"""Init hardware and software shared library for accelerator
Parameters
------------
hw_backend : str
Hardware backend can be verilog or chisel
"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
load_sw()
m = tvm.runtime.load_module(hw_lib, "vta-tsim")
f = tvm.get_global_func("tvm.vta.tsim.init")
f(m)
def load_module():
"""Return driver function"""
load_sw()
return tvm.get_global_func("tvm.vta.driver")
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
#include "vmem/virtual_memory.h"
namespace vta {
namespace driver {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}
DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}
static DPILoader* Global() {
static DPILoader inst;
return &inst;
}
// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
class Device {
public:
Device() {
loader_ = DPILoader::Global();
}
uint32_t Run(uint32_t c, DLTensor* a, DLTensor* b) {
uint32_t cycles;
uint32_t len = a->shape[0];
size_t size = (a->dtype.bits >> 3) * len;
a_ = this->MemAlloc(size);
b_ = this->MemAlloc(size);
this->MemCopyFromHost(a_, a->data, size);
this->Init();
this->Launch(c, len);
cycles = this->WaitForCompletion();
this->MemCopyToHost(b->data, b_, size);
this->MemFree(a_);
this->MemFree(b_);
return cycles;
}
private:
void Init() {
dpi_ = loader_->Get();
dpi_->SimResume();
}
void* MemAlloc(size_t size) {
void * addr = vta::vmem::VirtualMemoryManager::Global()->Alloc(size);
return reinterpret_cast<void*>(vta::vmem::VirtualMemoryManager::Global()->GetPhyAddr(addr));
}
void MemFree(void* buf) {
void * addr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(reinterpret_cast<uint64_t>(buf));
vta::vmem::VirtualMemoryManager::Global()->Free(addr);
}
vta_phy_addr_t MemGetPhyAddr(void* buf) {
return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
}
void MemCopyFromHost(void* dst, const void* src, size_t size) {
vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size);
}
void MemCopyToHost(void* dst, const void* src, size_t size) {
vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size);
}
void Launch(uint32_t c, uint32_t len) {
dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, len);
dpi_->WriteReg(0x10, this->MemGetPhyAddr(a_));
dpi_->WriteReg(0x14, 0);
dpi_->WriteReg(0x18, this->MemGetPhyAddr(b_));
dpi_->WriteReg(0x1c, 0);
dpi_->WriteReg(0x00, 0x1); // launch
}
uint32_t WaitForCompletion() {
uint32_t i, val;
for (i = 0; i < wait_cycles_; i++) {
val = dpi_->ReadReg(0x00);
if (val == 2) break; // finish
}
val = dpi_->ReadReg(0x04);
dpi_->SimWait();
return val;
}
// wait cycles
uint32_t wait_cycles_{100000000};
// DPI loader
DPILoader* loader_{nullptr};
// DPI Module
DPIModuleNode* dpi_{nullptr};
// input vm ptr
void* a_{nullptr};
// output vm ptr
void* b_{nullptr};
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Device dev_;
DLTensor* A = args[0];
DLTensor* B = args[1];
uint32_t c = static_cast<int>(args[2]);
uint32_t cycles = dev_.Run(c, A, B);
*rv = static_cast<int>(cycles);
});
} // namespace driver
} // namespace vta
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import tsim
def test_accel():
rmax = 64
dtype = "uint64"
n = np.random.randint(1, rmax)
c = np.random.randint(0, rmax)
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype(dtype), ctx)
b = tvm.nd.array(np.zeros(n).astype(dtype), ctx)
f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
tsim.init("chisel")
for i in range(10):
test_accel()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
import tsim
def test_accel():
rmax = 64
dtype = "uint64"
n = np.random.randint(1, rmax)
c = np.random.randint(0, rmax)
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype(dtype), ctx)
b = tvm.nd.array(np.zeros(n).astype(dtype), ctx)
f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
tsim.init("verilog")
for i in range(10):
test_accel()
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
# VTA Configuration
Each VTA runtime/hardware configuration is specified by vta_config.json file.
You can copy the vta_config.json to tvm project root and modify the configuration
before you type make.
The config is going to affect the behavior of python package as well as
the hardware runtime build.
{
"TARGET" : "de10nano",
"HW_VER" : "0.0.1",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_BATCH" : 0,
"LOG_BLOCK" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
{
"TARGET" : "sim",
"HW_VER" : "0.0.1",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_BATCH" : 0,
"LOG_BLOCK" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
{
"TARGET" : "pynq",
"HW_VER" : "0.0.1",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_BATCH" : 0,
"LOG_BLOCK" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
{
"TARGET" : "tsim",
"HW_VER" : "0.0.1",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_BATCH" : 0,
"LOG_BLOCK" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
{
"TARGET" : "ultra96",
"HW_VER" : "0.0.1",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_BATCH" : 0,
"LOG_BLOCK" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
{
"TARGET" : "sim",
"HW_VER" : "0.0.1",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_BATCH" : 0,
"LOG_BLOCK" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 18,
"LOG_ACC_BUFF_SIZE" : 17
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""VTA config tool"""
import os
import sys
import json
import argparse
def pkg_config(cfg):
"""Returns PkgConfig pkg config object."""
pkg_config_py = os.path.join(
os.path.dirname(os.path.abspath(os.path.expanduser(__file__))),
"pkg_config.py"
)
libpkg = {"__file__": pkg_config_py}
exec(compile(open(pkg_config_py, "rb").read(), pkg_config_py, "exec"), libpkg, libpkg)
PkgConfig = libpkg["PkgConfig"]
return PkgConfig(cfg)
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument("--use-cfg", type=str, default="",
help="path to the config json")
parser.add_argument("--cflags", action="store_true",
help="print the cflags")
parser.add_argument("--defs", action="store_true",
help="print the macro defs")
parser.add_argument("--sources", action="store_true",
help="print the source file paths")
parser.add_argument("--update", action="store_true",
help="Print out the json option.")
parser.add_argument("--ldflags", action="store_true",
help="print the ldflags")
parser.add_argument("--cfg-json", action="store_true",
help="print all the config json")
parser.add_argument("--save-cfg-json", type=str, default="",
help="save config json to file")
parser.add_argument("--target", action="store_true",
help="print the target")
parser.add_argument("--cfg-str", action="store_true",
help="print the configuration string")
parser.add_argument("--get-inp-mem-banks", action="store_true",
help="returns number of input memory banks")
parser.add_argument("--get-inp-mem-width", action="store_true",
help="returns input memory read/write port width")
parser.add_argument("--get-inp-mem-depth", action="store_true",
help="returns input memory depth")
parser.add_argument("--get-inp-mem-axi-ratio", action="store_true",
help="returns ratio between input element width and axi width")
parser.add_argument("--get-wgt-mem-banks", action="store_true",
help="returns number of weight memory banks")
parser.add_argument("--get-wgt-mem-width", action="store_true",
help="returns weight memory read/write port width")
parser.add_argument("--get-wgt-mem-depth", action="store_true",
help="returns weight memory depth")
parser.add_argument("--get-wgt-mem-axi-ratio", action="store_true",
help="returns ratio between weight element width and axi width")
parser.add_argument("--get-out-mem-banks", action="store_true",
help="returns number of output memory banks")
parser.add_argument("--get-out-mem-width", action="store_true",
help="returns output memory read/write port width")
parser.add_argument("--get-out-mem-depth", action="store_true",
help="returns output memory depth")
parser.add_argument("--get-out-mem-axi-ratio", action="store_true",
help="returns ratio between output element width and axi width")
parser.add_argument("--get-axi-cache-bits", action="store_true",
help="returns AXI system ARCACHE/AWCACHE hardcoded bit value")
parser.add_argument("--get-axi-prot-bits", action="store_true",
help="returns AXI system ARPROT/AWPROT hardcoded bit value")
parser.add_argument("--get-ip-reg-map-range", action="store_true",
help="returns ip register map address range")
parser.add_argument("--get-fetch-base-addr", action="store_true",
help="returns fetch module base address")
parser.add_argument("--get-load-base-addr", action="store_true",
help="returns load module base address")
parser.add_argument("--get-compute-base-addr", action="store_true",
help="returns compute module base address")
parser.add_argument("--get-store-base-addr", action="store_true",
help="returns store module base address")
parser.add_argument("--get-fpga-dev", action="store_true",
help="returns FPGA device target")
parser.add_argument("--get-fpga-family", action="store_true",
help="returns FPGA device family")
parser.add_argument("--get-fpga-freq", action="store_true",
help="returns FPGA frequency")
parser.add_argument("--get-fpga-per", action="store_true",
help="returns HLS target clock period")
args = parser.parse_args()
if len(sys.argv) == 1:
parser.print_help()
return
# Path to vta config
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
path_list = [
"vta_config.json", os.path.join(curr_path, "vta_config.json")
]
if args.use_cfg:
path_list = [args.use_cfg]
ok_path_list = [p for p in path_list if os.path.exists(p)]
if not ok_path_list:
raise RuntimeError("Cannot find config in %s" % str(path_list))
cfg = json.load(open(ok_path_list[0]))
pkg = pkg_config(cfg)
if args.target:
print(pkg.TARGET)
if args.defs:
print(" ".join(pkg.macro_defs))
if args.sources:
print(" ".join(pkg.lib_source))
if args.cflags:
cflags_str = " ".join(pkg.cflags)
if pkg.TARGET == "pynq":
cflags_str += " -DVTA_TARGET_PYNQ"
elif pkg.TARGET == "de10nano":
cflags_str += " -DVTA_TARGET_DE10_NANO"
elif pkg.TARGET == "ultra96":
cflags_str += " -DVTA_TARGET_ULTRA96"
print(cflags_str)
if args.ldflags:
print(" ".join(pkg.ldflags))
if args.cfg_json:
print(pkg.cfg_json)
if args.save_cfg_json:
with open(args.save_cfg_json, "w") as fo:
fo.write(pkg.cfg_json)
if args.cfg_str:
print(pkg.TARGET + "_" + pkg.bitstream)
if args.get_inp_mem_banks:
print(pkg.inp_mem_banks)
if args.get_inp_mem_width:
print(pkg.inp_mem_width)
if args.get_inp_mem_depth:
print(pkg.inp_mem_depth)
if args.get_inp_mem_axi_ratio:
print(pkg.inp_mem_axi_ratio)
if args.get_wgt_mem_banks:
print(pkg.wgt_mem_banks)
if args.get_wgt_mem_width:
print(pkg.wgt_mem_width)
if args.get_wgt_mem_depth:
print(pkg.wgt_mem_depth)
if args.get_wgt_mem_axi_ratio:
print(pkg.wgt_mem_axi_ratio)
if args.get_out_mem_banks:
print(pkg.out_mem_banks)
if args.get_out_mem_width:
print(pkg.out_mem_width)
if args.get_out_mem_depth:
print(pkg.out_mem_depth)
if args.get_out_mem_axi_ratio:
print(pkg.out_mem_axi_ratio)
if args.get_axi_cache_bits:
print(pkg.axi_cache_bits)
if args.get_axi_prot_bits:
print(pkg.axi_prot_bits)
if args.get_ip_reg_map_range:
print(pkg.ip_reg_map_range)
if args.get_fetch_base_addr:
print(pkg.fetch_base_addr)
if args.get_load_base_addr:
print(pkg.load_base_addr)
if args.get_compute_base_addr:
print(pkg.compute_base_addr)
if args.get_store_base_addr:
print(pkg.store_base_addr)
if args.get_fpga_dev:
print(pkg.fpga_device)
if args.get_fpga_family:
print(pkg.fpga_family)
if args.get_fpga_freq:
print(pkg.fpga_freq)
if args.get_fpga_per:
print(pkg.fpga_per)
if __name__ == "__main__":
main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
CONFIG = DefaultDe10Config
TOP = VTA
TOP_TEST = Test
BUILD_NAME = build
# Set USE_TRACE = 1 to generate a trace during simulation.
USE_TRACE = 0
# With USE_TRACE = 1, default trace format is VCD.
# Set USE_TRACE_FST = 1 to use the FST format.
# Note that although FST is around two orders of magnitude smaller than VCD
# it is also currently much slower to produce (verilator limitation). But if
# you are low on disk space it may be your only option.
USE_TRACE_FST = 0
# With USE_TRACE = 1, USE_TRACE_DETAILED = 1 will generate traces that also
# include non-interface internal signal names starting with an underscore.
# This will significantly increase the trace size and should only be used
# on a per need basis for difficult debug problems.
USE_TRACE_DETAILED = 0
USE_THREADS = 0
VTA_LIBNAME = libvta_hw
UNITTEST_NAME = all
CXX = g++
# A debug build with DEBUG = 1 is useful to trace the simulation with a
# debugger.
DEBUG = 0
# With DEBUG = 1, SANITIZE = 1 turns on address sanitizing to verify that
# the verilator build is sane. To be used if you know what you are doing.
SANITIZE = 0
CXX_MAJOR := $(shell $(CXX) -dumpversion | sed 's/\..*//')
CXX_HAS_ALIGN_NEW := $(shell [ $(CXX_MAJOR) -ge 7 ] && echo true)
config_test = $(TOP_TEST)$(CONFIG)
ifndef TVM_PATH
TVM_PATH := $(abspath ../../../../)
endif
ifndef VTA_HW_PATH
VTA_HW_PATH := $(abspath ../../)
endif
verilator_build_dir = $(VTA_HW_PATH)/$(BUILD_NAME)/verilator
chisel_build_dir = $(VTA_HW_PATH)/$(BUILD_NAME)/chisel
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP_TEST}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)
ifeq ($(DEBUG), 0)
cxx_flags = -O2 -Wall -fvisibility=hidden
else
cxx_flags = -O0 -g -Wall
endif
cxx_flags += -std=c++11 -Wno-maybe-uninitialized
ifeq ($(CXX_HAS_ALIGN_NEW),true)
cxx_flags += -faligned-new
endif
cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP_TEST).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(VTA_HW_PATH)/include
cxx_flags += -I$(TVM_PATH)/include
cxx_flags += -I$(TVM_PATH)/3rdparty/dlpack/include
ld_flags = -fPIC -shared
ifeq ($(SANITIZE), 1)
ifeq ($(DEBUG), 1)
cxx_flags += -fno-omit-frame-pointer -fsanitize=address -fsanitize-recover=address
ld_flags += -fno-omit-frame-pointer -fsanitize=address -fsanitize-recover=address
endif
endif
cxx_objs = $(verilator_build_dir)/verilated.o $(verilator_build_dir)/verilated_dpi.o $(verilator_build_dir)/tsim_device.o
ifneq ($(USE_TRACE), 0)
cxx_flags += -DVM_TRACE=1
ifeq ($(USE_TRACE_FST), 1)
cxx_flags += -DVM_TRACE_FST
verilator_opt += --trace-fst
else
verilator_opt += --trace
endif
ifeq ($(USE_TRACE_DETAILED), 1)
verilator_opt += --trace-underscore --trace-structs
endif
ifeq ($(USE_TRACE_FST), 1)
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).fst
cxx_objs += $(verilator_build_dir)/verilated_fst_c.o
else
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd
cxx_objs += $(verilator_build_dir)/verilated_vcd_c.o
endif
else
cxx_flags += -DVM_TRACE=0
endif
ifneq ($(USE_THREADS), 0)
verilator_opt += --threads $(USE_THREADS)
cxx_flags += -DVL_THREADED
cxx_objs += $(verilator_build_dir)/verilated_threads.o
endif
VPATH = $(VERILATOR_INC_DIR):$(verilator_build_dir):$(VTA_HW_PATH)/hardware/dpi
# The following is to be consistent with cmake
ifeq ($(shell uname), Darwin)
lib_path = $(VTA_HW_PATH)/$(BUILD_NAME)/$(VTA_LIBNAME).dylib
cxx_flags += -isysroot /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
else
lib_path = $(VTA_HW_PATH)/$(BUILD_NAME)/$(VTA_LIBNAME).so
endif
default: lint lib
lint:
sbt scalastyle
lib: $(lib_path)
$(verilator_build_dir)/%.o: %.cpp
$(CXX) -fPIC $(cxx_flags) -c $^ -o $@
$(verilator_build_dir)/tsim_device.o: tsim_device.cc
$(CXX) -fPIC $(cxx_flags) -c $^ -o $@
$(lib_path): $(verilator_build_dir)/V$(TOP_TEST).cpp $(cxx_objs)
for f in $(shell find $(verilator_build_dir)/*.cpp); do \
$(CXX) -fPIC $(cxx_flags) -c $${f} -o $${f}.o ; \
done
$(CXX) $(ld_flags) $(cxx_flags) $(cxx_objs) $(patsubst %.cpp,%.cpp.o,$(shell find $(verilator_build_dir)/*.cpp)) -o $@
verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp
$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
verilator $(verilator_opt) $<
verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v
$(chisel_build_dir)/$(TOP).$(CONFIG).v:
sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)'
verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v:
sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)'
unittest:
sbt 'test:runMain unittest.Launcher $(UNITTEST_NAME)'
clean:
-rm -rf target project/target project/project test_run_dir
cleanall:
-rm -rf $(VTA_HW_PATH)/$(BUILD_NAME)/chisel
-rm -rf $(VTA_HW_PATH)/$(BUILD_NAME)/libvta_hw.so
-rm -rf $(VTA_HW_PATH)/$(BUILD_NAME)/libvta_hw.dylib
-rm -rf $(VTA_HW_PATH)/$(BUILD_NAME)/verilator
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
VTA in Chisel
===================================================
For contributors who wants to test a chisel module:
- You can add your test files in `src/test/scala/unitttest`
- Add your test name and tests to the `test` object in `src/test/scala/unitttest/Launcher.scala`
- Check out the provided sample test `mvm` which tests the MatrixVectorComputation module
in `src/main/scala/core/TensorGemm.scala`
- Running unit tests: `make test test_name=your_own test_name`
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
name := "vta"
version := "0.1.0-SNAPSHOT"
organization := "edu.washington.cs"
def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// If we're building with Scala > 2.11, enable the compile option
// switch to support our anonymous Bundle definitions:
// https://github.com/scala/bug/issues/10047
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
case _ => Seq(
"-Xsource:2.11",
"-language:reflectiveCalls",
"-language:implicitConversions",
"-deprecation",
"-Xlint",
"-Ywarn-unused",
)
}
}
}
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// Scala 2.12 requires Java 8. We continue to generate
// Java 7 compatible code for Scala 2.11
// for compatibility with old clients.
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
Seq("-source", "1.7", "-target", "1.7")
case _ =>
Seq("-source", "1.8", "-target", "1.8")
}
}
}
scalaVersion := "2.11.12"
resolvers ++= Seq(
Resolver.sonatypeRepo("snapshots"),
Resolver.sonatypeRepo("releases"))
val defaultVersions = Map(
"chisel3" -> "3.1.7",
"chisel-iotesters" -> "1.2.4"
)
libraryDependencies ++= Seq("chisel3","chisel-iotesters").map {
dep: String => "edu.berkeley.cs" %% dep % sys.props.getOrElse(dep + "Version", defaultVersions(dep)) }
scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
javacOptions ++= javacOptionsVersion(scalaVersion.value)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
sbt.version = 1.3.2
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
logLevel := Level.Warn
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
<scalastyle>
<name>Scalastyle standard configuration</name>
<check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="true">
<parameters>
<parameter name="maxFileLength"><![CDATA[800]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
<parameters>
<parameter name="header"><![CDATA[/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
<parameters>
<parameter name="maxLineLength"><![CDATA[120]]></parameter>
<parameter name="tabSize"><![CDATA[2]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="true">
<parameters>
<parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
<parameters>
<parameter name="maxParameters"><![CDATA[8]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="false">
<parameters>
<parameter name="ignore"><![CDATA[-1,0,1,2,3,4,8,16,32,64,128]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[println]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="true">
<parameters>
<parameter name="maxTypes"><![CDATA[30]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="true">
<parameters>
<parameter name="maximum"><![CDATA[10]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="false">
<parameters>
<parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
<parameter name="doubleLineAllowed"><![CDATA[false]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="true">
<parameters>
<parameter name="maxLength"><![CDATA[50]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="false">
<parameters>
<parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="true">
<parameters>
<parameter name="maxMethods"><![CDATA[30]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="false"></check>
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"></check>
<check level="error" class="org.scalastyle.file.IndentationChecker" enabled="true">
<parameters>
<parameter name="tabSize">2</parameter>
<parameter name="methodParamIndentSize">2</parameter>
<parameter name="classParamIndentSize">4</parameter>
</parameters>
</check>
</scalastyle>
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTAHostDPI #
( parameter ADDR_BITS = 8,
parameter DATA_BITS = 32
)
(
input clock,
input reset,
output logic dpi_req_valid,
output logic dpi_req_opcode,
output logic [ADDR_BITS-1:0] dpi_req_addr,
output logic [DATA_BITS-1:0] dpi_req_value,
input dpi_req_deq,
input dpi_resp_valid,
input [DATA_BITS-1:0] dpi_resp_bits
);
import "DPI-C" function void VTAHostDPI
(
output byte unsigned req_valid,
output byte unsigned req_opcode,
output byte unsigned req_addr,
output int unsigned req_value,
input byte unsigned req_deq,
input byte unsigned resp_valid,
input int unsigned resp_value
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
typedef logic [31:0] dpi32_t;
dpi1_t __reset;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_addr;
dpi32_t __req_value;
dpi8_t __req_deq;
dpi8_t __resp_valid;
dpi32_t __resp_bits;
// reset
always_ff @(posedge clock) begin
__reset <= reset;
end
// delaying outputs by one-cycle
// since verilator does not support delays
always_ff @(posedge clock) begin
dpi_req_valid <= dpi1_t ' (__req_valid);
dpi_req_opcode <= dpi1_t ' (__req_opcode);
dpi_req_addr <= __req_addr;
dpi_req_value <= __req_value;
end
assign __req_deq = dpi8_t ' (dpi_req_deq);
assign __resp_valid = dpi8_t ' (dpi_resp_valid);
assign __resp_bits = dpi_resp_bits;
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__req_valid = 0;
__req_opcode = 0;
__req_addr = 0;
__req_value = 0;
end
else begin
VTAHostDPI(
__req_valid,
__req_opcode,
__req_addr,
__req_value,
__req_deq,
__resp_valid,
__resp_bits);
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTAMemDPI #
( parameter LEN_BITS = 8,
parameter ADDR_BITS = 64,
parameter DATA_BITS = 64
)
(
input clock,
input reset,
input dpi_req_valid,
input dpi_req_opcode,
input [LEN_BITS-1:0] dpi_req_len,
input [ADDR_BITS-1:0] dpi_req_addr,
input dpi_wr_valid,
input [DATA_BITS-1:0] dpi_wr_bits,
output logic dpi_rd_valid,
output logic [DATA_BITS-1:0] dpi_rd_bits,
input dpi_rd_ready
);
import "DPI-C" function void VTAMemDPI
(
input byte unsigned req_valid,
input byte unsigned req_opcode,
input byte unsigned req_len,
input longint unsigned req_addr,
input byte unsigned wr_valid,
input longint unsigned wr_value,
output byte unsigned rd_valid,
output longint unsigned rd_value,
input byte unsigned rd_ready
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
typedef logic [31:0] dpi32_t;
typedef logic [63:0] dpi64_t;
dpi1_t __reset;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_len;
dpi64_t __req_addr;
dpi8_t __wr_valid;
dpi64_t __wr_value;
dpi8_t __rd_valid;
dpi64_t __rd_value;
dpi8_t __rd_ready;
always_ff @(posedge clock) begin
__reset <= reset;
end
// delaying outputs by one-cycle
// since verilator does not support delays
always_ff @(posedge clock) begin
dpi_rd_valid <= dpi1_t ' (__rd_valid);
dpi_rd_bits <= __rd_value;
end
assign __req_valid = dpi8_t ' (dpi_req_valid);
assign __req_opcode = dpi8_t ' (dpi_req_opcode);
assign __req_len = dpi_req_len;
assign __req_addr = dpi_req_addr;
assign __wr_valid = dpi8_t ' (dpi_wr_valid);
assign __wr_value = dpi_wr_bits;
assign __rd_ready = dpi8_t ' (dpi_rd_ready);
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__rd_valid = 0;
__rd_value = 0;
end
else begin
VTAMemDPI(
__req_valid,
__req_opcode,
__req_len,
__req_addr,
__wr_valid,
__wr_value,
__rd_valid,
__rd_value,
__rd_ready);
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTASimDPI
(
input clock,
input reset,
output logic dpi_wait
);
import "DPI-C" function void VTASimDPI
(
output byte unsigned sim_wait,
output byte unsigned sim_exit
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
dpi1_t __reset;
dpi8_t __wait;
dpi8_t __exit;
// reset
always_ff @(posedge clock) begin
__reset <= reset;
end
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__wait = 0;
__exit = 0;
end
else begin
VTASimDPI(
__wait,
__exit);
end
end
logic wait_reg;
always_ff @(posedge clock) begin
if (reset | __reset) begin
wait_reg <= 1'b0;
end else if (__wait == 1) begin
wait_reg <= 1'b1;
end else begin
wait_reg <= 1'b0;
end
end
assign dpi_wait = wait_reg;
always_ff @(posedge clock) begin
if (__exit == 1) begin
$finish;
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** Compute.
*
* The compute unit is in charge of the following:
* - Loading micro-ops from memory (loadUop module)
* - Loading biases (acc) from memory (tensorAcc module)
* - Compute ALU instructions (tensorAlu module)
* - Compute GEMM instructions (tensorGemm module)
*/
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val i_post = Vec(2, Input(Bool()))
val o_post = Vec(2, Output(Bool()))
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
val uop_baddr = Input(UInt(mp.addrBits.W))
val acc_baddr = Input(UInt(mp.addrBits.W))
val vme_rd = Vec(2, new VMEReadMaster)
val inp = new TensorMaster(tensorType = "inp")
val wgt = new TensorMaster(tensorType = "wgt")
val out = new TensorMaster(tensorType = "out")
val finish = Output(Bool())
val acc_wr_event = Output(Bool())
})
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)
val s = Seq.tabulate(2)(_ =>
Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
val loadUop = Module(new LoadUop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
val tensorGemm = Module(new TensorGemm)
val tensorAlu = Module(new TensorAlu)
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
// decode
val dec = Module(new ComputeDecode)
dec.io.inst := inst_q.io.deq.bits
val inst_type =
Cat(dec.io.isFinish,
dec.io.isAlu,
dec.io.isGemm,
dec.io.isLoadAcc,
dec.io.isLoadUop).asUInt
val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
val start = snext & sprev
val done =
MuxLookup(
inst_type,
false.B, // default
Array(
"h_01".U -> loadUop.io.done,
"h_02".U -> tensorAcc.io.done,
"h_04".U -> tensorGemm.io.done,
"h_08".U -> tensorAlu.io.done,
"h_10".U -> true.B // Finish
)
)
// control
switch(state) {
is(sIdle) {
when(start) {
when(dec.io.isSync) {
state := sSync
}.elsewhen(inst_type.orR) {
state := sExe
}
}
}
is(sSync) {
state := sIdle
}
is(sExe) {
when(done) {
state := sIdle
}
}
}
// instructions
inst_q.io.enq <> io.inst
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
// uop
loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd
io.acc_wr_event := tensorAcc.io.tensor.wr.valid
// gemm
tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
tensorGemm.io.inst := inst_q.io.deq.bits
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorGemm.io.inp <> io.inp
tensorGemm.io.wgt <> io.wgt
tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
// alu
tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
tensorAlu.io.inst := inst_q.io.deq.bits
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
// out
io.out.rd.idx <> Mux(dec.io.isGemm,
tensorGemm.io.out.rd.idx,
tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
// semaphore
s(0).io.spost := io.i_post(0)
s(1).io.spost := io.i_post(1)
s(0).io.swait := dec.io.pop_prev & (state === sIdle & start)
s(1).io.swait := dec.io.pop_next & (state === sIdle & start)
io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync))
// finish
io.finish := state === sExe & done & dec.io.isFinish
// debug
if (debug) {
// start
when(state === sIdle && start) {
when(dec.io.isSync) {
printf("[Compute] start sync\n")
}.elsewhen(dec.io.isLoadUop) {
printf("[Compute] start load uop\n")
}.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] start load acc\n")
}.elsewhen(dec.io.isGemm) {
printf("[Compute] start gemm\n")
}.elsewhen(dec.io.isAlu) {
printf("[Compute] start alu\n")
}.elsewhen(dec.io.isFinish) {
printf("[Compute] start finish\n")
}
}
// done
when(state === sSync) {
printf("[Compute] done sync\n")
}
when(state === sExe) {
when(done) {
when(dec.io.isLoadUop) {
printf("[Compute] done load uop\n")
}.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] done load acc\n")
}.elsewhen(dec.io.isGemm) {
printf("[Compute] done gemm\n")
}.elsewhen(dec.io.isAlu) {
printf("[Compute] done alu\n")
}.elsewhen(dec.io.isFinish) {
printf("[Compute] done finish\n")
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import vta.util.config._
/** CoreConfig.
*
* This is one supported configuration for VTA. This file will
* be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends.
*/
class CoreConfig extends Config((site, here, up) => {
case CoreKey =>
CoreParams(
batch = 1,
blockOut = 16,
blockIn = 16,
inpBits = 8,
wgtBits = 8,
uopBits = 32,
accBits = 32,
outBits = 8,
uopMemDepth = 2048,
inpMemDepth = 2048,
wgtMemDepth = 1024,
accMemDepth = 2048,
outMemDepth = 2048,
instQueueEntries = 512
)
})
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment