Commit 124f9b7f by Luis Vega Committed by Thierry Moreau

[VTA][TSIM] update app example (#3343)

* add initial support to cycle counter to accelerator

* remove prints from c

* add event counter support to chisel tsim example

* make it more readable

* use a config class

* update driver

* add individual Makefile to chisel

* add rule for installing vta package

* add makefile for verilog backend

* update drivers

* update

* rename

* update README

* put default sim back

* set counter to zero
parent 2c41fd2f
...@@ -34,6 +34,10 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND ...@@ -34,6 +34,10 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
endif() endif()
# Module rules file(GLOB TSIM_SW_SRC src/driver.cc)
include(cmake/modules/hw.cmake) add_library(sw SHARED ${TSIM_SW_SRC})
include(cmake/modules/sw.cmake) target_include_directories(sw PRIVATE ${VTA_DIR}/include)
if(APPLE)
set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
...@@ -17,20 +17,32 @@ ...@@ -17,20 +17,32 @@
export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH) export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH)
BUILD_DIR = $(shell python3 config/config.py --get-build-name) BUILD_NAME = build
build_dir = $(abspath .)/$(BUILD_NAME)
default: cmake run default: verilog driver run_verilog
run_chisel: chisel driver
python3 tests/python/chisel_accel.py
.PHONY: cmake .PHONY: cmake
cmake: | $(BUILD_DIR) driver: | $(build_dir)
cd $(BUILD_DIR) && cmake .. && make cd $(build_dir) && cmake .. && make
$(BUILD_DIR): $(build_dir):
mkdir -p $@ mkdir -p $@
run: verilog:
python3 tests/python/add_by_one.py | grep PASS make -C hardware/verilog
chisel:
make -C hardware/chisel
run_verilog:
python3 tests/python/verilog_accel.py
clean: clean:
-rm -rf $(BUILD_DIR) -rm -rf $(build_dir)
make -C hardware/chisel clean
make -C hardware/verilog clean
...@@ -49,29 +49,25 @@ sudo apt install verilator sbt ...@@ -49,29 +49,25 @@ sudo apt install verilator sbt
## Setup in TVM ## Setup in TVM
1. Install `verilator` and `sbt` as described above 1. Install `verilator` and `sbt` as described above
2. Change `TARGET` to `tsim` in `<tvm-root>/tvm/vta/config/vta_config.json` 2. Build [tvm](https://docs.tvm.ai/install/from_source.html#build-the-shared-library)
3. Build [tvm](https://docs.tvm.ai/install/from_source.html#build-the-shared-library)
## How to run VTA TSIM examples ## How to run VTA TSIM examples
There are two sample VTA accelerators (add-by-one) designed in Chisel3 and Verilog to show how *TSIM* works. There are two sample VTA accelerators, add-a-constant, designed in Chisel3 and Verilog to show how *TSIM* works.
The default `TARGET` language for these two implementations is Verilog. The following instructions show The default `TARGET` language for these two implementations is Verilog. The following instructions show
how to run both of them: how to run both of them:
* Verilog add-by-one * Test Verilog backend
* Go to `<tvm-root>/vta/apps/tsim_example` * Go to `<tvm-root>/vta/apps/tsim_example`
* Run `make` to build and run add-by-one test * Run `make`
* Chisel3 add-by-one * Test Chisel3 backend
* Open `<tvm-root>/vta/apps/tsim_example/python/tsim/config.json` * Open `<tvm-root>/vta/apps/tsim_example`
* Change `TARGET` from `verilog` to `chisel` * Run `make run_chisel`
* Go to `tvm/vta/apps/tsim_example`
* Run `make` to build and run add-by-one test
* Some pointers * Some pointers
* Add-by-one test `<tvm-root>/vta/apps/tsim_example/tests/python/add_by_one.py` * Verilog and Chisel3 tests in `<tvm-root>/vta/apps/tsim_example/tests/python`
* Add-by-one accelerator in Verilog `<tvm-root>/vta/apps/tsim_example/hardware/verilog` * Verilog accelerator backend `<tvm-root>/vta/apps/tsim_example/hardware/verilog`
* Add-by-one accelerator in Chisel3 `<tvm-root>/vta/apps/tsim_example/hardware/chisel` * Chisel3 accelerator backend `<tvm-root>/vta/apps/tsim_example/hardware/chisel`
* Software driver that handles the accelerator `<tvm-root>/vta/apps/tsim_example/src/driver.cc` * Software C++ driver (backend) that handles the accelerator `<tvm-root>/vta/apps/tsim_example/src/driver.cc`
* Build cmake script for software library`<tvm-root>/vta/apps/tsim_example/cmake/modules/sw.cmake` * Software Python driver (frontend) that handles the accelerator `<tvm-root>/vta/apps/tsim_example/python/accel`
* Build cmake script for hardware library`<tvm-root>/vta/apps/tsim_example/cmake/modules/hw.cmake`
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(MSVC)
message(STATUS "[TSIM_HW] build is skipped in Windows..")
else()
find_program(PYTHON NAMES python python3 python3.6)
find_program(VERILATOR NAMES verilator)
if (VERILATOR AND PYTHON)
if (TSIM_TOP_NAME STREQUAL "")
message(FATAL_ERROR "[TSIM_HW] TSIM_TOP_NAME should be defined")
endif()
if (TSIM_BUILD_NAME STREQUAL "")
message(FATAL_ERROR "[TSIM_HW] TSIM_BUILD_NAME should be defined")
endif()
set(TSIM_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/config/config.py)
execute_process(COMMAND ${TSIM_CONFIG} --get-target OUTPUT_VARIABLE TSIM_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${TSIM_CONFIG} --get-top-name OUTPUT_VARIABLE TSIM_TOP_NAME OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${TSIM_CONFIG} --get-build-name OUTPUT_VARIABLE TSIM_BUILD_NAME OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${TSIM_CONFIG} --get-use-trace OUTPUT_VARIABLE TSIM_USE_TRACE OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND ${TSIM_CONFIG} --get-trace-name OUTPUT_VARIABLE TSIM_TRACE_NAME OUTPUT_STRIP_TRAILING_WHITESPACE)
set(TSIM_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/${TSIM_BUILD_NAME})
if (TSIM_TARGET STREQUAL "chisel")
find_program(SBT NAMES sbt)
if (SBT)
# Install Chisel VTA package for DPI modules
set(VTA_CHISEL_DIR ${VTA_DIR}/hardware/chisel)
execute_process(WORKING_DIRECTORY ${VTA_CHISEL_DIR}
COMMAND ${SBT} publishLocal RESULT_VARIABLE RETCODE)
if (NOT RETCODE STREQUAL "0")
message(FATAL_ERROR "[TSIM_HW] sbt failed to install VTA scala package")
endif()
# Chisel - Scala to Verilog compilation
set(TSIM_CHISEL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/chisel)
set(CHISEL_BUILD_DIR ${TSIM_BUILD_DIR}/chisel)
set(CHISEL_OPT "test:runMain test.Elaborate --target-dir ${CHISEL_BUILD_DIR} --top-name ${TSIM_TOP_NAME}")
execute_process(WORKING_DIRECTORY ${TSIM_CHISEL_DIR} COMMAND ${SBT} ${CHISEL_OPT} RESULT_VARIABLE RETCODE)
if (NOT RETCODE STREQUAL "0")
message(FATAL_ERROR "[TSIM_HW] sbt failed to compile from Chisel to Verilog.")
endif()
file(GLOB VERILATOR_RTL_SRC ${CHISEL_BUILD_DIR}/*.v)
else()
message(FATAL_ERROR "[TSIM_HW] sbt should be installed for Chisel")
endif() # sbt
elseif (TSIM_TARGET STREQUAL "verilog")
set(VTA_VERILOG_DIR ${VTA_DIR}/hardware/chisel/src/main/resources/verilog)
set(TSIM_VERILOG_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/verilog)
file(GLOB VERILATOR_RTL_SRC ${VTA_VERILOG_DIR}/*.v ${TSIM_VERILOG_DIR}/*.v)
else()
message(FATAL_ERROR "[TSIM_HW] target language can be only verilog or chisel...")
endif() # TSIM_TARGET
if (TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog")
# Check if tracing can be enabled
if (NOT TSIM_USE_TRACE STREQUAL "off")
message(STATUS "[TSIM_HW] Verilog enable tracing")
else()
message(STATUS "[TSIM_HW] Verilator disable tracing")
endif()
# Verilator - Verilog to C++ compilation
set(VERILATOR_BUILD_DIR ${TSIM_BUILD_DIR}/verilator)
set(VERILATOR_OPT +define+RANDOMIZE_GARBAGE_ASSIGN +define+RANDOMIZE_REG_INIT)
list(APPEND VERILATOR_OPT +define+RANDOMIZE_MEM_INIT --x-assign unique)
list(APPEND VERILATOR_OPT --output-split 20000 --output-split-cfuncs 20000)
list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_BUILD_DIR})
list(APPEND VERILATOR_OPT --cc ${VERILATOR_RTL_SRC})
if (NOT TSIM_USE_TRACE STREQUAL "off")
list(APPEND VERILATOR_OPT --trace)
endif()
execute_process(COMMAND ${VERILATOR} ${VERILATOR_OPT} RESULT_VARIABLE RETCODE)
if (NOT RETCODE STREQUAL "0")
message(FATAL_ERROR "[TSIM_HW] Verilator failed to compile Verilog to C++...")
endif()
# Build shared library (.so)
set(VTA_HW_DPI_DIR ${VTA_DIR}/hardware/dpi)
if (EXISTS /usr/local/share/verilator/include)
set(VERILATOR_INC_DIR /usr/local/share/verilator/include)
elseif (EXISTS /usr/share/verilator/include)
set(VERILATOR_INC_DIR /usr/share/verilator/include)
else()
message(FATAL_ERROR "[TSIM_HW] Verilator include directory not found")
endif()
set(VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated.cpp ${VERILATOR_INC_DIR}/verilated_dpi.cpp)
if (NOT TSIM_USE_TRACE STREQUAL "off")
list(APPEND VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated_vcd_c.cpp)
endif()
file(GLOB VERILATOR_GEN_SRC ${VERILATOR_BUILD_DIR}/*.cpp)
file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
if (NOT TSIM_USE_TRACE STREQUAL "off")
list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
else()
list(APPEND VERILATOR_DEF VM_TRACE=0)
endif()
target_compile_definitions(hw PRIVATE ${VERILATOR_DEF})
target_compile_options(hw PRIVATE -Wno-sign-compare -include V${TSIM_TOP_NAME}.h)
target_include_directories(hw PRIVATE ${VERILATOR_BUILD_DIR} ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd ${VTA_DIR}/include)
if(APPLE)
set_target_properties(hw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
endif() # TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog"
else()
message(STATUS "[TSIM_HW] could not find Python or Verilator, build is skipped...")
endif() # VERILATOR
endif() # MSVC
{
"TARGET" : "verilog",
"TOP_NAME" : "TestAccel",
"BUILD_NAME" : "build",
"USE_TRACE" : "off",
"TRACE_NAME" : "trace"
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os.path as osp
import sys
import json
import argparse
cur = osp.abspath(osp.dirname(__file__))
cfg = json.load(open(osp.join(cur, 'config.json')))
def main():
"""Main function"""
parser = argparse.ArgumentParser()
parser.add_argument("--get-target", action="store_true",
help="Get target language, i.e. verilog or chisel")
parser.add_argument("--get-top-name", action="store_true",
help="Get hardware design top name")
parser.add_argument("--get-build-name", action="store_true",
help="Get build folder name")
parser.add_argument("--get-use-trace", action="store_true",
help="Get use trace")
parser.add_argument("--get-trace-name", action="store_true",
help="Get trace filename")
args = parser.parse_args()
if len(sys.argv) == 1:
parser.print_help()
return
if args.get_target:
print(cfg['TARGET'])
if args.get_top_name:
print(cfg['TOP_NAME'])
if args.get_build_name:
print(cfg['BUILD_NAME'])
if args.get_use_trace:
print(cfg['USE_TRACE'])
if args.get_trace_name:
print(cfg['TRACE_NAME'])
if __name__ == "__main__":
main()
...@@ -15,5 +15,92 @@ ...@@ -15,5 +15,92 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
TOP = TestAccel
BUILD_NAME = build
USE_TRACE = 0
LIBNAME = libhw
vta_dir = $(abspath ../../../../)
tvm_dir = $(abspath ../../../../../)
build_dir = $(abspath .)/$(BUILD_NAME)
verilator_build_dir = $(build_dir)/verilator
chisel_build_dir = $(build_dir)/chisel
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd
cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
default: lib
lib: $(build_dir)/$(LIBNAME).so
$(build_dir)/$(LIBNAME).so: $(verilator_build_dir)/V$(TOP).cpp
echo $(cxx_files)
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(verilator_build_dir)/V$(TOP).cpp
$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v
verilator $(verilator_opt) $<
verilog: $(chisel_build_dir)/$(TOP).v
$(chisel_build_dir)/$(TOP).v: install_vta_package
sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)'
install_vta_package:
cd $(vta_dir)/hardware/chisel && sbt publishLocal
clean: clean:
-rm -rf target project/target project/project -rm -rf $(build_dir) target project/target project/project
...@@ -35,18 +35,28 @@ import vta.dpi._ ...@@ -35,18 +35,28 @@ import vta.dpi._
* |_________| |_________| * |_________| |_________|
* *
*/ */
case class AccelConfig() {
val nCtrl = 1
val nECnt = 1
val nVals = 2
val nPtrs = 2
val regBits = 32
val ptrBits = 2*regBits
}
class Accel extends Module { class Accel extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val host = new VTAHostDPIClient val host = new VTAHostDPIClient
val mem = new VTAMemDPIMaster val mem = new VTAMemDPIMaster
}) })
implicit val config = AccelConfig()
val rf = Module(new RegFile) val rf = Module(new RegFile)
val ce = Module(new Compute) val ce = Module(new Compute)
rf.io.host <> io.host rf.io.host <> io.host
io.mem <> ce.io.mem io.mem <> ce.io.mem
ce.io.launch := rf.io.launch ce.io.launch := rf.io.launch
rf.io.finish := ce.io.finish rf.io.finish := ce.io.finish
ce.io.length := rf.io.length rf.io.ecnt <> ce.io.ecnt
ce.io.inp_baddr := rf.io.inp_baddr ce.io.vals <> rf.io.vals
ce.io.out_baddr := rf.io.out_baddr ce.io.ptrs <> rf.io.ptrs
} }
...@@ -35,21 +35,24 @@ import vta.dpi._ ...@@ -35,21 +35,24 @@ import vta.dpi._
* 6. Check if counter (cnt) is equal to length to assert finish, * 6. Check if counter (cnt) is equal to length to assert finish,
* otherwise go to step 2. * otherwise go to step 2.
*/ */
class Compute extends Module { class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val launch = Input(Bool()) val launch = Input(Bool())
val finish = Output(Bool()) val finish = Output(Bool())
val length = Input(UInt(32.W)) val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W)))
val inp_baddr = Input(UInt(64.W)) val vals = Input(Vec(config.nVals, UInt(config.regBits.W)))
val out_baddr = Input(UInt(64.W)) val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster val mem = new VTAMemDPIMaster
}) })
val sIdle :: sReadReq :: sReadData :: sWriteReq :: sWriteData :: Nil = Enum(5) val sIdle :: sReadReq :: sReadData :: sWriteReq :: sWriteData :: Nil = Enum(5)
val state = RegInit(sIdle) val state = RegInit(sIdle)
val const = io.vals(0)
val length = io.vals(1)
val cycles = RegInit(0.U(config.regBits.W))
val reg = Reg(chiselTypeOf(io.mem.rd.bits)) val reg = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(chiselTypeOf(io.length)) val cnt = Reg(UInt(config.regBits.W))
val raddr = Reg(chiselTypeOf(io.inp_baddr)) val raddr = Reg(UInt(config.ptrBits.W))
val waddr = Reg(chiselTypeOf(io.out_baddr)) val waddr = Reg(UInt(config.ptrBits.W))
switch (state) { switch (state) {
is (sIdle) { is (sIdle) {
...@@ -69,7 +72,7 @@ class Compute extends Module { ...@@ -69,7 +72,7 @@ class Compute extends Module {
state := sWriteData state := sWriteData
} }
is (sWriteData) { is (sWriteData) {
when (cnt === (io.length - 1.U)) { when (cnt === (length - 1.U)) {
state := sIdle state := sIdle
} .otherwise { } .otherwise {
state := sReadReq state := sReadReq
...@@ -77,10 +80,22 @@ class Compute extends Module { ...@@ -77,10 +80,22 @@ class Compute extends Module {
} }
} }
val last = state === sWriteData && cnt === (length - 1.U)
// cycle counter
when (state === sIdle) {
cycles := 0.U
} .otherwise {
cycles := cycles + 1.U
}
io.ecnt(0).valid := last
io.ecnt(0).bits := cycles
// calculate next address // calculate next address
when (state === sIdle) { when (state === sIdle) {
raddr := io.inp_baddr raddr := io.ptrs(0)
waddr := io.out_baddr waddr := io.ptrs(1)
} .elsewhen (state === sWriteData) { // increment by 8-bytes } .elsewhen (state === sWriteData) { // increment by 8-bytes
raddr := raddr + 8.U raddr := raddr + 8.U
waddr := waddr + 8.U waddr := waddr + 8.U
...@@ -94,7 +109,7 @@ class Compute extends Module { ...@@ -94,7 +109,7 @@ class Compute extends Module {
// read // read
when (state === sReadData && io.mem.rd.valid) { when (state === sReadData && io.mem.rd.valid) {
reg := io.mem.rd.bits + 1.U reg := io.mem.rd.bits + const
} }
io.mem.rd.ready := state === sReadData io.mem.rd.ready := state === sReadData
...@@ -110,5 +125,5 @@ class Compute extends Module { ...@@ -110,5 +125,5 @@ class Compute extends Module {
} }
// done when read/write are equal to length // done when read/write are equal to length
io.finish := state === sWriteData && cnt === (io.length - 1.U) io.finish := last
} }
...@@ -31,11 +31,13 @@ import vta.dpi._ ...@@ -31,11 +31,13 @@ import vta.dpi._
* Register description | addr * Register description | addr
* -------------------------|----- * -------------------------|-----
* Control status register | 0x00 * Control status register | 0x00
* Length value register | 0x04 * Cycle counter | 0x04
* Input pointer lsb | 0x08 * Constant value | 0x08
* Input pointer msb | 0x0c * Vector length | 0x0c
* Output pointer lsb | 0x10 * Input pointer lsb | 0x10
* Output pointer msb | 0x14 * Input pointer msb | 0x14
* Output pointer lsb | 0x18
* Output pointer msb | 0x1c
* ------------------------------- * -------------------------------
* ------------------------------ * ------------------------------
...@@ -45,13 +47,13 @@ import vta.dpi._ ...@@ -45,13 +47,13 @@ import vta.dpi._
* Finish | 1 * Finish | 1
* ------------------------------ * ------------------------------
*/ */
class RegFile extends Module { class RegFile(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val launch = Output(Bool()) val launch = Output(Bool())
val finish = Input(Bool()) val finish = Input(Bool())
val length = Output(UInt(32.W)) val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W))))
val inp_baddr = Output(UInt(64.W)) val vals = Output(Vec(config.nVals, UInt(config.regBits.W)))
val out_baddr = Output(UInt(64.W)) val ptrs = Output(Vec(config.nPtrs, UInt(config.regBits.W)))
val host = new VTAHostDPIClient val host = new VTAHostDPIClient
}) })
val sIdle :: sRead :: Nil = Enum(2) val sIdle :: sRead :: Nil = Enum(2)
...@@ -70,23 +72,34 @@ class RegFile extends Module { ...@@ -70,23 +72,34 @@ class RegFile extends Module {
io.host.req.deq := state === sIdle & io.host.req.valid io.host.req.deq := state === sIdle & io.host.req.valid
val reg = Seq.fill(6)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
val addr = Seq.tabulate(6)(_ * 4) val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl
val vo = eo + config.nECnt
val po = vo + config.nVals
(reg zip addr).foreach { case(r, a) => when (io.finish) {
if (a == 0) { // control status register reg(0) := "b_10".U
when (io.finish) { } .elsewhen (state === sIdle && io.host.req.valid &&
r := "b_10".U io.host.req.opcode && addr(0).U === io.host.req.addr) {
} .elsewhen (state === sIdle && io.host.req.valid && reg(0) := io.host.req.value
io.host.req.opcode && a.U === io.host.req.addr) { }
r := io.host.req.value
} for (i <- 0 until config.nECnt) {
} else { when (io.ecnt(i).valid) {
when (state === sIdle && io.host.req.valid && reg(eo + i) := io.ecnt(i).bits
io.host.req.opcode && a.U === io.host.req.addr) { } .elsewhen (state === sIdle && io.host.req.valid &&
r := io.host.req.value io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
} reg(eo + i) := io.host.req.value
}
}
for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
when (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value
} }
} }
...@@ -99,7 +112,12 @@ class RegFile extends Module { ...@@ -99,7 +112,12 @@ class RegFile extends Module {
io.host.resp.bits := rdata io.host.resp.bits := rdata
io.launch := reg(0)(0) io.launch := reg(0)(0)
io.length := reg(1)
io.inp_baddr := Cat(reg(3), reg(2)) for (i <- 0 until config.nVals) {
io.out_baddr := Cat(reg(5), reg(4)) io.vals(i) := reg(vo + i)
}
for (i <- 0 until config.nPtrs) {
io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
}
} }
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
TOP = TestAccel
BUILD_NAME = build
USE_TRACE = 0
LIBNAME = libhw
vta_dir = $(abspath ../../../../)
tvm_dir = $(abspath ../../../../../)
build_dir = $(abspath .)/$(BUILD_NAME)
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP}
verilator_opt += -Mdir ${build_dir}
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP).h
cxx_flags += -I$(build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
cxx_files += $(wildcard $(build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
v_files = $(wildcard $(abspath .)/src/*.v $(vta_dir)/hardware/chisel/src/main/resources/verilog/*.v)
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(build_dir)/$(TOP).vcd
cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
default: lib
lib: $(build_dir)/$(LIBNAME).so
$(build_dir)/$(LIBNAME).so: $(build_dir)/V$(TOP).cpp
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(build_dir)/V$(TOP).cpp
$(build_dir)/V$(TOP).cpp: $(v_files) | $(build_dir)
verilator $(verilator_opt) $(v_files)
$(build_dir):
mkdir -p $@
clean:
-rm -rf $(build_dir)
...@@ -62,6 +62,11 @@ module Accel # ...@@ -62,6 +62,11 @@ module Accel #
logic launch; logic launch;
logic finish; logic finish;
logic event_counter_valid;
logic [HOST_DATA_BITS-1:0] event_counter_value;
logic [HOST_DATA_BITS-1:0] constant;
logic [HOST_DATA_BITS-1:0] length; logic [HOST_DATA_BITS-1:0] length;
logic [MEM_ADDR_BITS-1:0] inp_baddr; logic [MEM_ADDR_BITS-1:0] inp_baddr;
logic [MEM_ADDR_BITS-1:0] out_baddr; logic [MEM_ADDR_BITS-1:0] out_baddr;
...@@ -74,22 +79,27 @@ module Accel # ...@@ -74,22 +79,27 @@ module Accel #
) )
rf rf
( (
.clock (clock), .clock (clock),
.reset (reset), .reset (reset),
.host_req_valid (host_req_valid), .host_req_valid (host_req_valid),
.host_req_opcode (host_req_opcode), .host_req_opcode (host_req_opcode),
.host_req_addr (host_req_addr), .host_req_addr (host_req_addr),
.host_req_value (host_req_value), .host_req_value (host_req_value),
.host_req_deq (host_req_deq), .host_req_deq (host_req_deq),
.host_resp_valid (host_resp_valid), .host_resp_valid (host_resp_valid),
.host_resp_bits (host_resp_bits), .host_resp_bits (host_resp_bits),
.launch (launch), .launch (launch),
.finish (finish), .finish (finish),
.length (length),
.inp_baddr (inp_baddr), .event_counter_valid (event_counter_valid),
.out_baddr (out_baddr) .event_counter_value (event_counter_value),
.constant (constant),
.length (length),
.inp_baddr (inp_baddr),
.out_baddr (out_baddr)
); );
Compute # Compute #
...@@ -101,24 +111,29 @@ module Accel # ...@@ -101,24 +111,29 @@ module Accel #
) )
comp comp
( (
.clock (clock), .clock (clock),
.reset (reset), .reset (reset),
.mem_req_valid (mem_req_valid), .mem_req_valid (mem_req_valid),
.mem_req_opcode (mem_req_opcode), .mem_req_opcode (mem_req_opcode),
.mem_req_len (mem_req_len), .mem_req_len (mem_req_len),
.mem_req_addr (mem_req_addr), .mem_req_addr (mem_req_addr),
.mem_wr_valid (mem_wr_valid), .mem_wr_valid (mem_wr_valid),
.mem_wr_bits (mem_wr_bits), .mem_wr_bits (mem_wr_bits),
.mem_rd_valid (mem_rd_valid), .mem_rd_valid (mem_rd_valid),
.mem_rd_bits (mem_rd_bits), .mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready), .mem_rd_ready (mem_rd_ready),
.launch (launch), .launch (launch),
.finish (finish), .finish (finish),
.length (length),
.inp_baddr (inp_baddr), .event_counter_valid (event_counter_valid),
.out_baddr (out_baddr) .event_counter_value (event_counter_value),
.constant (constant),
.length (length),
.inp_baddr (inp_baddr),
.out_baddr (out_baddr)
); );
endmodule endmodule
...@@ -52,6 +52,11 @@ module Compute # ...@@ -52,6 +52,11 @@ module Compute #
input launch, input launch,
output finish, output finish,
output event_counter_valid,
output [HOST_DATA_BITS-1:0] event_counter_value,
input [HOST_DATA_BITS-1:0] constant,
input [HOST_DATA_BITS-1:0] length, input [HOST_DATA_BITS-1:0] length,
input [MEM_ADDR_BITS-1:0] inp_baddr, input [MEM_ADDR_BITS-1:0] inp_baddr,
input [MEM_ADDR_BITS-1:0] out_baddr input [MEM_ADDR_BITS-1:0] out_baddr
...@@ -84,7 +89,7 @@ module Compute # ...@@ -84,7 +89,7 @@ module Compute #
IDLE: begin IDLE: begin
if (launch) begin if (launch) begin
state_n = READ_REQ; state_n = READ_REQ;
end end
end end
READ_REQ: begin READ_REQ: begin
...@@ -94,9 +99,9 @@ module Compute # ...@@ -94,9 +99,9 @@ module Compute #
READ_DATA: begin READ_DATA: begin
if (mem_rd_valid) begin if (mem_rd_valid) begin
state_n = WRITE_REQ; state_n = WRITE_REQ;
end else begin end else begin
state_n = READ_DATA; state_n = READ_DATA;
end end
end end
WRITE_REQ: begin WRITE_REQ: begin
...@@ -106,9 +111,9 @@ module Compute # ...@@ -106,9 +111,9 @@ module Compute #
WRITE_DATA: begin WRITE_DATA: begin
if (cnt == (length - 1'b1)) begin if (cnt == (length - 1'b1)) begin
state_n = IDLE; state_n = IDLE;
end else begin end else begin
state_n = READ_REQ; state_n = READ_REQ;
end end
end end
default: begin default: begin
...@@ -116,6 +121,22 @@ module Compute # ...@@ -116,6 +121,22 @@ module Compute #
endcase endcase
end end
logic last;
assign last = (state_r == WRITE_DATA) & (cnt == (length - 1'b1));
// cycle counter
logic [HOST_DATA_BITS-1:0] cycle_counter;
always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin
cycle_counter <= '0;
end else begin
cycle_counter <= cycle_counter + 1'b1;
end
end
assign event_counter_valid = last;
assign event_counter_value = cycle_counter;
// calculate next address // calculate next address
always_ff @(posedge clock) begin always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin if (reset | state_r == IDLE) begin
...@@ -136,7 +157,7 @@ module Compute # ...@@ -136,7 +157,7 @@ module Compute #
// read // read
always_ff @(posedge clock) begin always_ff @(posedge clock) begin
if ((state_r == READ_DATA) & mem_rd_valid) begin if ((state_r == READ_DATA) & mem_rd_valid) begin
data <= mem_rd_bits + 1'b1; data <= mem_rd_bits + {32'd0, constant};
end end
end end
assign mem_rd_ready = state_r == READ_DATA; assign mem_rd_ready = state_r == READ_DATA;
...@@ -155,5 +176,5 @@ module Compute # ...@@ -155,5 +176,5 @@ module Compute #
end end
// done when read/write are equal to length // done when read/write are equal to length
assign finish = (state_r == WRITE_DATA) & (cnt == (length - 1'b1)); assign finish = last;
endmodule endmodule
...@@ -25,11 +25,13 @@ ...@@ -25,11 +25,13 @@
* Register description | addr * Register description | addr
* -------------------------|----- * -------------------------|-----
* Control status register | 0x00 * Control status register | 0x00
* Length value register | 0x04 * Cycle counter | 0x04
* Input pointer lsb | 0x08 * Constant value | 0x08
* Input pointer msb | 0x0c * Vector length | 0x0c
* Output pointer lsb | 0x10 * Input pointer lsb | 0x10
* Output pointer msb | 0x14 * Input pointer msb | 0x14
* Output pointer lsb | 0x18
* Output pointer msb | 0x1c
* ------------------------------- * -------------------------------
* ------------------------------ * ------------------------------
...@@ -58,11 +60,18 @@ module RegFile # ...@@ -58,11 +60,18 @@ module RegFile #
output launch, output launch,
input finish, input finish,
input event_counter_valid,
input [HOST_DATA_BITS-1:0] event_counter_value,
output [HOST_DATA_BITS-1:0] constant,
output [HOST_DATA_BITS-1:0] length, output [HOST_DATA_BITS-1:0] length,
output [MEM_ADDR_BITS-1:0] inp_baddr, output [MEM_ADDR_BITS-1:0] inp_baddr,
output [MEM_ADDR_BITS-1:0] out_baddr output [MEM_ADDR_BITS-1:0] out_baddr
); );
localparam NUM_REG = 8;
typedef enum logic {IDLE, READ} state_t; typedef enum logic {IDLE, READ} state_t;
state_t state_n, state_r; state_t state_n, state_r;
...@@ -80,7 +89,7 @@ module RegFile # ...@@ -80,7 +89,7 @@ module RegFile #
IDLE: begin IDLE: begin
if (host_req_valid & ~host_req_opcode) begin if (host_req_valid & ~host_req_opcode) begin
state_n = READ; state_n = READ;
end end
end end
READ: begin READ: begin
...@@ -91,28 +100,49 @@ module RegFile # ...@@ -91,28 +100,49 @@ module RegFile #
assign host_req_deq = (state_r == IDLE) ? host_req_valid : 1'b0; assign host_req_deq = (state_r == IDLE) ? host_req_valid : 1'b0;
logic [HOST_DATA_BITS-1:0] rf [5:0]; logic [HOST_DATA_BITS-1:0] rf [NUM_REG-1:0];
genvar i; genvar i;
for (i = 0; i < 6; i++) begin for (i = 0; i < NUM_REG; i++) begin
logic wen = (state_r == IDLE)? host_req_valid & host_req_opcode & i*4 == host_req_addr : 1'b0; logic wen = (state_r == IDLE)? host_req_valid & host_req_opcode & i*4 == host_req_addr : 1'b0;
if (i == 0) begin if (i == 0) begin
always_ff @(posedge clock) begin always_ff @(posedge clock) begin
if (reset) begin if (reset) begin
end else if (finish) begin rf[i] <= 'd0;
rf[i] <= 'd2; end else if (finish) begin
end else if (wen) begin rf[i] <= 'd2;
rf[i] <= host_req_value; end else if (wen) begin
end rf[i] <= host_req_value;
end
end end
end else if (i == 1) begin
always_ff @(posedge clock) begin
if (reset) begin
rf[i] <= 'd0;
end else if (event_counter_valid) begin
rf[i] <= event_counter_value;
end else if (wen) begin
rf[i] <= host_req_value;
end
end
end else begin end else begin
always_ff @(posedge clock) begin always_ff @(posedge clock) begin
if (reset) begin if (reset) begin
end else if (wen) begin rf[i] <= 'd0;
rf[i] <= host_req_value; end else if (wen) begin
end rf[i] <= host_req_value;
end
end end
end end
end end
logic [HOST_DATA_BITS-1:0] rdata; logic [HOST_DATA_BITS-1:0] rdata;
...@@ -132,6 +162,10 @@ module RegFile # ...@@ -132,6 +162,10 @@ module RegFile #
rdata <= rf[4]; rdata <= rf[4];
end else if (host_req_addr == 'h14) begin end else if (host_req_addr == 'h14) begin
rdata <= rf[5]; rdata <= rf[5];
end else if (host_req_addr == 'h18) begin
rdata <= rf[6];
end else if (host_req_addr == 'h1c) begin
rdata <= rf[7];
end else begin end else begin
rdata <= 'd0; rdata <= 'd0;
end end
...@@ -142,8 +176,9 @@ module RegFile # ...@@ -142,8 +176,9 @@ module RegFile #
assign host_resp_bits = rdata; assign host_resp_bits = rdata;
assign launch = rf[0][0]; assign launch = rf[0][0];
assign length = rf[1]; assign constant = rf[2];
assign inp_baddr = {rf[3], rf[2]}; assign length = rf[3];
assign out_baddr = {rf[5], rf[4]}; assign inp_baddr = {rf[5], rf[4]};
assign out_baddr = {rf[7], rf[6]};
endmodule endmodule
...@@ -17,31 +17,25 @@ ...@@ -17,31 +17,25 @@
import tvm import tvm
import ctypes import ctypes
import json
import os.path as osp import os.path as osp
from sys import platform from sys import platform
def driver(hw_lib, sw_lib): def driver(hw_backend):
"""Init hardware and software shared library for add-by-one accelerator """Init hardware and software shared library for accelerator
Parameters Parameters
------------ ------------
hw_lib : str hw_backend : str
Name of hardware shared library Hardware backend can be verilog or chisel
sw_lib : str
Name of software shared library
""" """
_ext = ".dylib" if platform == "darwin" else ".so"
_hw_libname = "libhw" + _ext
_sw_libname = "libsw" + _ext
_cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__))) _cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
_root_path = osp.join(_cur_path, "..", "..") if hw_backend in ("verilog", "chisel"):
_cfg_file = osp.join(_root_path, "config", "config.json") _hw_lib = osp.join(_cur_path, "..", "..", "hardware", hw_backend, "build", _hw_libname)
_cfg = json.load(open(_cfg_file)) _sw_lib = osp.join(_cur_path, "..", "..", "build", _sw_libname)
if not hw_lib.endswith(("dylib", "so")):
hw_lib += ".dylib" if platform == "darwin" else ".so"
if not sw_lib.endswith(("dylib", "so")):
sw_lib += ".dylib" if platform == "darwin" else ".so"
_hw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], hw_lib)
_sw_lib = osp.join(_root_path, _cfg['BUILD_NAME'], sw_lib)
def load_dll(dll): def load_dll(dll):
try: try:
...@@ -49,9 +43,9 @@ def driver(hw_lib, sw_lib): ...@@ -49,9 +43,9 @@ def driver(hw_lib, sw_lib):
except OSError: except OSError:
return [] return []
def run(a, b): def run(a, b, c):
load_dll(_sw_lib) load_dll(_sw_lib)
f = tvm.get_global_func("tvm.vta.driver") f = tvm.get_global_func("tvm.vta.driver")
m = tvm.module.load(_hw_lib, "vta-tsim") m = tvm.module.load(_hw_lib, "vta-tsim")
f(m, a, b) return f(m, a, b, c)
return run return run
...@@ -43,34 +43,40 @@ class Device { ...@@ -43,34 +43,40 @@ class Device {
module.operator->()); module.operator->());
} }
int Run(uint32_t length, void* inp, void* out) { uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) {
uint32_t wait_cycles = 100000000; uint32_t cycles;
this->Launch(wait_cycles, length, inp, out); this->Launch(c, length, inp, out);
this->WaitForCompletion(wait_cycles); cycles = this->WaitForCompletion();
dpi_->Finish(); dpi_->Finish();
return 0; return cycles;
} }
private: private:
void Launch(uint32_t wait_cycles, uint32_t length, void* inp, void* out) { void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles); dpi_->Launch(wait_cycles_);
// write registers // set counter to zero
dpi_->WriteReg(0x04, length); dpi_->WriteReg(0x04, 0);
dpi_->WriteReg(0x08, get_half_addr(inp, false)); dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, get_half_addr(inp, true)); dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(out, false)); dpi_->WriteReg(0x10, get_half_addr(inp, false));
dpi_->WriteReg(0x14, get_half_addr(out, true)); dpi_->WriteReg(0x14, get_half_addr(inp, true));
dpi_->WriteReg(0x00, 0x1); // launch dpi_->WriteReg(0x18, get_half_addr(out, false));
dpi_->WriteReg(0x1c, get_half_addr(out, true));
// launch
dpi_->WriteReg(0x00, 0x1);
} }
void WaitForCompletion(uint32_t wait_cycles) { uint32_t WaitForCompletion() {
uint32_t i, val; uint32_t i, val;
for (i = 0; i < wait_cycles; i++) { for (i = 0; i < wait_cycles_; i++) {
val = dpi_->ReadReg(0x00); val = dpi_->ReadReg(0x00);
if (val == 2) break; // finish if (val == 2) break; // finish
} }
val = dpi_->ReadReg(0x04);
return val;
} }
uint32_t wait_cycles_{100000000};
DPIModuleNode* dpi_; DPIModuleNode* dpi_;
Module module_; Module module_;
}; };
...@@ -84,7 +90,8 @@ TVM_REGISTER_GLOBAL("tvm.vta.driver") ...@@ -84,7 +90,8 @@ TVM_REGISTER_GLOBAL("tvm.vta.driver")
DLTensor* A = args[1]; DLTensor* A = args[1];
DLTensor* B = args[2]; DLTensor* B = args[2];
Device dev_(dev_mod); Device dev_(dev_mod);
dev_.Run(A->shape[0], A->data, B->data); uint32_t cycles = dev_.Run(static_cast<int>(args[3]), A->shape[0], A->data, B->data);
*rv = static_cast<int>(cycles);
}); });
} // namespace driver } // namespace driver
......
...@@ -18,22 +18,21 @@ ...@@ -18,22 +18,21 @@
import tvm import tvm
import numpy as np import numpy as np
from tsim.driver import driver from accel.driver import driver
def test_tsim(i): def test_accel():
rmin = 1 # min vector size of 1
rmax = 64 rmax = 64
n = np.random.randint(rmin, rmax) n = np.random.randint(1, rmax)
c = np.random.randint(0, rmax)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = driver("libhw", "libsw") f = driver("chisel")
f(a, b) cycles = f(a, b, c)
emsg = "[FAIL] test number:{} n:{}".format(i, n) msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1, err_msg=emsg) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] test number:{} n:{}".format(i, n)) print("[PASS] " + msg)
if __name__ == "__main__": if __name__ == "__main__":
times = 10 for i in range(10):
for i in range(times): test_accel()
test_tsim(i)
...@@ -15,10 +15,24 @@ ...@@ -15,10 +15,24 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
file(GLOB TSIM_SW_SRC src/driver.cc) import tvm
add_library(sw SHARED ${TSIM_SW_SRC}) import numpy as np
target_include_directories(sw PRIVATE ${VTA_DIR}/include)
if(APPLE) from accel.driver import driver
set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE) def test_accel():
rmax = 64
n = np.random.randint(1, rmax)
c = np.random.randint(0, rmax)
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = driver("verilog")
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
for i in range(10):
test_accel()
...@@ -112,7 +112,6 @@ module VTAHostDPI # ...@@ -112,7 +112,6 @@ module VTAHostDPI #
always_ff @(posedge clock) begin always_ff @(posedge clock) begin
if (__exit == 'd1) begin if (__exit == 'd1) begin
$display("[TSIM] Verilog $finish called at cycle:%016d", cycles);
$finish; $finish;
end end
end end
......
...@@ -75,7 +75,6 @@ void VTADPIInit(VTAContextHandle handle, ...@@ -75,7 +75,6 @@ void VTADPIInit(VTAContextHandle handle,
// VL_USER_FINISH needs to be defined when compiling Verilator code // VL_USER_FINISH needs to be defined when compiling Verilator code
void vl_finish(const char* filename, int linenum, const char* hier) { void vl_finish(const char* filename, int linenum, const char* hier) {
Verilated::gotFinish(true); Verilated::gotFinish(true);
VL_PRINTF("[TSIM] exiting simulation\n");
} }
int VTADPISim(uint64_t max_cycles) { int VTADPISim(uint64_t max_cycles) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment