Commit a6d04b8d by Luis Vega Committed by Thierry Moreau

[RFC] [VTA] [TSIM] Enabling Cycle-Accurate Hardware Simulation for VTA #3009 (#3010)

* merge files

* move verilator to the right place

* change name to tsim

* add default rule to be build and run

* add README for tsim

* Update README.md

* add some structural feedback

* change name of VTASim to VTADPISim

* more renaming

* update comment

* add license

* fix indentation

* add switch for vta-tsim

* add more licenses

* update readme

* address some of the new feedback

* add some feedback from cpplint

* add one more whitespace

* pass pointer so linter is happy

* pass pointer so linter is happy

* README moved to vta documentation

* create types for dpi functions, so they can be handle easily

* fix pointer style

* add feedback from docs

* parametrize width data and pointers

* fix comments

* fix comment

* add comment to class

* add missing parameters

* move README back to tsim example

* add feedback

* add more comments and remove un-necessary argument in finish

* update comments

* fix cpplint

* fix doc
parent 981db150
...@@ -131,3 +131,6 @@ set(USE_SORT ON) ...@@ -131,3 +131,6 @@ set(USE_SORT ON)
# Build ANTLR parser for Relay text format # Build ANTLR parser for Relay text format
set(USE_ANTLR OFF) set(USE_ANTLR OFF)
# Build TSIM for VTA
set(USE_VTA_TSIM OFF)
...@@ -60,6 +60,13 @@ elseif(PYTHON) ...@@ -60,6 +60,13 @@ elseif(PYTHON)
find_library(__cma_lib NAMES cma PATH /usr/lib) find_library(__cma_lib NAMES cma PATH /usr/lib)
target_link_libraries(vta ${__cma_lib}) target_link_libraries(vta ${__cma_lib})
endif() endif()
if(NOT USE_VTA_TSIM STREQUAL "OFF")
include_directories("vta/include")
file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
endif()
else() else()
message(STATUS "Cannot found python in env, VTA build is skipped..") message(STATUS "Cannot found python in env, VTA build is skipped..")
endif() endif()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tsim C CXX)
set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
set(VTA_DIR ${TVM_DIR}/vta)
include_directories("${TVM_DIR}/include")
include_directories("${TVM_DIR}/3rdparty/dlpack/include")
include_directories("${TVM_DIR}/3rdparty/dmlc-core/include")
include_directories("${TVM_DIR}/vta/src/dpi")
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11")
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
endif()
# Module rules
include(cmake/modules/tsim.cmake)
include(cmake/modules/driver.cmake)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH)
BUILD_DIR = $(shell python python/tsim/config.py --get-build-name)
TVM_DIR = $(abspath ../../../)
TSIM_TARGET = verilog
TSIM_TOP_NAME = TestAccel
TSIM_BUILD_NAME = build
# optional
TSIM_TRACE_NAME = trace.vcd
default: cmake run
.PHONY: cmake
cmake: | $(BUILD_DIR)
cd $(BUILD_DIR) && cmake .. && make
$(BUILD_DIR):
mkdir -p $@
run:
python3 tests/python/test_tsim.py | grep PASS
clean:
-rm -rf $(BUILD_DIR)
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
VTA TSIM Installation
======================
*TSIM* is a cycle-accurate hardware simulation environment that can be invoked and managed directly from TVM. It aims to enable cycle accurate simulation of deep learning accelerators including VTA.
This simulation environment can be used in both OSX and Linux.
There are two dependencies required to make *TSIM* works: [Verilator](https://www.veripool.org/wiki/verilator) and [sbt](https://www.scala-sbt.org/) for accelerators designed in [Chisel3](https://github.com/freechipsproject/chisel3).
## OSX Dependencies
Install `sbt` and `verilator` using [Homebrew](https://brew.sh/).
```bash
brew install verilator sbt
```
## Linux Dependencies
Add `sbt` to package manager (Ubuntu).
```bash
echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 2EE0EA64E40A89B84B2DF73499E82A75642AC823
sudo apt-get update
```
Install `sbt` and `verilator`.
```bash
sudo apt install verilator sbt
```
## Setup in TVM
1. Install `verilator` and `sbt` as described above
2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
3. Build tvm
## How to run VTA TSIM examples
There are two sample VTA accelerators (add-by-one) designed in Chisel3 and Verilog to show how *TSIM* works.
These examples are located at `<tvm-root>/vta/apps/tsim_example`.
* Instructions
* Open `<tvm-root>/vta/apps/tsim_example/python/tsim/config.json`
* Change `TARGET` from `verilog` to `chisel`, depending on what language backend you would like to test
* Go to `tvm/vta/apps/tsim`
* Run `make`
* Some pointers
* Build cmake script for driver `<tvm-root>/vta/apps/tsim_example/cmake/modules/driver.cmake`
* Build cmake script for tsim `<tvm-root>/vta/apps/tsim_example/cmake/modules/tsim.cmake`
* Software driver that handles the VTA accelerator `<tvm-root>/vta/apps/tsim_example/src/driver.cc`
* VTA add-by-one accelerator (Verilog) `<tvm-root>/vta/apps/tsim_example/hardware/verilog`
* VTA add-by-one accelerator (Chisel) `<tvm-root>/vta/apps/tsim_example/hardware/chisel`
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
file(GLOB TSIM_SW_SRC src/driver.cc)
add_library(driver SHARED ${TSIM_SW_SRC})
target_include_directories(driver PRIVATE ${VTA_DIR}/include)
if(APPLE)
set_target_properties(driver PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(MSVC)
message(STATUS "TSIM build is skipped in Windows..")
else()
find_program(PYTHON NAMES python python3 python3.6)
find_program(VERILATOR NAMES verilator)
if (VERILATOR AND PYTHON)
if (TSIM_TOP_NAME STREQUAL "")
message(FATAL_ERROR "TSIM_TOP_NAME should be defined")
endif()
if (TSIM_BUILD_NAME STREQUAL "")
message(FATAL_ERROR "TSIM_BUILD_NAME should be defined")
endif()
set(TSIM_CONFIG ${PYTHON} ${CMAKE_CURRENT_SOURCE_DIR}/python/tsim/config.py)
execute_process(COMMAND ${TSIM_CONFIG} --get-target OUTPUT_VARIABLE __TSIM_TARGET)
execute_process(COMMAND ${TSIM_CONFIG} --get-top-name OUTPUT_VARIABLE __TSIM_TOP_NAME)
execute_process(COMMAND ${TSIM_CONFIG} --get-build-name OUTPUT_VARIABLE __TSIM_BUILD_NAME)
execute_process(COMMAND ${TSIM_CONFIG} --get-use-trace OUTPUT_VARIABLE __TSIM_USE_TRACE)
execute_process(COMMAND ${TSIM_CONFIG} --get-trace-name OUTPUT_VARIABLE __TSIM_TRACE_NAME)
string(STRIP ${__TSIM_TARGET} TSIM_TARGET)
string(STRIP ${__TSIM_TOP_NAME} TSIM_TOP_NAME)
string(STRIP ${__TSIM_BUILD_NAME} TSIM_BUILD_NAME)
string(STRIP ${__TSIM_USE_TRACE} TSIM_USE_TRACE)
string(STRIP ${__TSIM_TRACE_NAME} TSIM_TRACE_NAME)
set(TSIM_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/${TSIM_BUILD_NAME})
if (TSIM_TARGET STREQUAL "chisel")
find_program(SBT NAMES sbt)
if (SBT)
# Install Chisel VTA package for DPI modules
set(VTA_CHISEL_DIR ${VTA_DIR}/hardware/chisel)
execute_process(WORKING_DIRECTORY ${VTA_CHISEL_DIR}
COMMAND ${SBT} publishLocal RESULT_VARIABLE RETCODE)
if (NOT RETCODE STREQUAL "0")
message(FATAL_ERROR "[TSIM] sbt failed to install VTA scala package")
endif()
# Chisel - Scala to Verilog compilation
set(TSIM_CHISEL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/chisel)
set(CHISEL_TARGET_DIR ${TSIM_BUILD_DIR}/chisel)
set(CHISEL_OPT "test:runMain test.Elaborate --target-dir ${CHISEL_TARGET_DIR} --top-name ${TSIM_TOP_NAME}")
execute_process(WORKING_DIRECTORY ${TSIM_CHISEL_DIR} COMMAND ${SBT} ${CHISEL_OPT} RESULT_VARIABLE RETCODE)
if (NOT RETCODE STREQUAL "0")
message(FATAL_ERROR "[TSIM] sbt failed to compile from Chisel to Verilog.")
endif()
file(GLOB VERILATOR_RTL_SRC ${CHISEL_TARGET_DIR}/*.v)
else()
message(FATAL_ERROR "[TSIM] sbt should be installed for Chisel")
endif() # sbt
elseif (TSIM_TARGET STREQUAL "verilog")
set(VTA_VERILOG_DIR ${VTA_DIR}/hardware/chisel/src/main/resources/verilog)
set(TSIM_VERILOG_DIR ${CMAKE_CURRENT_SOURCE_DIR}/hardware/verilog)
file(GLOB VERILATOR_RTL_SRC ${VTA_VERILOG_DIR}/*.v ${TSIM_VERILOG_DIR}/*.v)
else()
message(STATUS "[TSIM] target language can be only verilog or chisel...")
endif() # TSIM_TARGET
if (TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog")
# Check if tracing can be enabled
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
message(STATUS "[TSIM] Verilog enable tracing")
else()
message(STATUS "[TSIM] Verilator disable tracing")
endif()
# Verilator - Verilog to C++ compilation
set(VERILATOR_TARGET_DIR ${TSIM_BUILD_DIR}/verilator)
set(VERILATOR_OPT +define+RANDOMIZE_GARBAGE_ASSIGN +define+RANDOMIZE_REG_INIT)
list(APPEND VERILATOR_OPT +define+RANDOMIZE_MEM_INIT --x-assign unique)
list(APPEND VERILATOR_OPT --output-split 20000 --output-split-cfuncs 20000)
list(APPEND VERILATOR_OPT --top-module ${TSIM_TOP_NAME} -Mdir ${VERILATOR_TARGET_DIR})
list(APPEND VERILATOR_OPT --cc ${VERILATOR_RTL_SRC})
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
list(APPEND VERILATOR_OPT --trace)
endif()
execute_process(COMMAND ${VERILATOR} ${VERILATOR_OPT} RESULT_VARIABLE RETCODE)
if (NOT RETCODE STREQUAL "0")
message(FATAL_ERROR "[TSIM] Verilator failed to compile Verilog to C++...")
endif()
# Build shared library (.so)
set(VTA_HW_DPI_DIR ${VTA_DIR}/hardware/dpi)
set(VERILATOR_INC_DIR /usr/local/share/verilator/include)
set(VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated.cpp ${VERILATOR_INC_DIR}/verilated_dpi.cpp)
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
list(APPEND VERILATOR_LIB_SRC ${VERILATOR_INC_DIR}/verilated_vcd_c.cpp)
endif()
file(GLOB VERILATOR_GEN_SRC ${VERILATOR_TARGET_DIR}/*.cpp)
file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
add_library(tsim SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
else()
list(APPEND VERILATOR_DEF VM_TRACE=0)
endif()
target_compile_definitions(tsim PRIVATE ${VERILATOR_DEF})
target_compile_options(tsim PRIVATE -Wno-sign-compare -include V${TSIM_TOP_NAME}.h)
target_include_directories(tsim PRIVATE ${VERILATOR_TARGET_DIR} ${VERILATOR_INC_DIR} ${VERILATOR_INC_DIR}/vltstd ${VTA_DIR}/include)
if(APPLE)
set_target_properties(tsim PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
endif() # TSIM_TARGET STREQUAL "chisel" OR TSIM_TARGET STREQUAL "verilog"
else()
message(STATUS "[TSIM] could not find Python or Verilator, build is skipped...")
endif() # VERILATOR
endif() # MSVC
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
clean:
-rm -rf target project/target project/project
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
name := "accel"
version := "0.1.0-SNAPSHOT"
organization := "edu.washington.cs"
def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// If we're building with Scala > 2.11, enable the compile option
// switch to support our anonymous Bundle definitions:
// https://github.com/scala/bug/issues/10047
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
case _ => Seq(
"-Xsource:2.11",
"-language:reflectiveCalls",
"-language:implicitConversions",
"-deprecation",
"-Xlint",
"-Ywarn-unused",
)
}
}
}
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// Scala 2.12 requires Java 8. We continue to generate
// Java 7 compatible code for Scala 2.11
// for compatibility with old clients.
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
Seq("-source", "1.7", "-target", "1.7")
case _ =>
Seq("-source", "1.8", "-target", "1.8")
}
}
}
scalaVersion := "2.11.12"
resolvers ++= Seq(
Resolver.sonatypeRepo("snapshots"),
Resolver.sonatypeRepo("releases"))
libraryDependencies ++= Seq(
"edu.berkeley.cs" %% "chisel3" % "3.1.7",
"edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT",
)
scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
javacOptions ++= javacOptionsVersion(scalaVersion.value)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
sbt.version = 1.1.1
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
logLevel := Level.Warn
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import vta.dpi._
/** Add-by-one accelerator.
*
* ___________ ___________
* | | | |
* | HostDPI | <--> | RegFile | <->|
* |_________| |_________| |
* |
* ___________ ___________ |
* | | | | |
* | MemDPI | <--> | Compute | <->|
* |_________| |_________|
*
*/
class Accel extends Module {
val io = IO(new Bundle {
val host = new VTAHostDPIClient
val mem = new VTAMemDPIMaster
})
val rf = Module(new RegFile)
val ce = Module(new Compute)
rf.io.host <> io.host
io.mem <> ce.io.mem
ce.io.launch := rf.io.launch
rf.io.finish := ce.io.finish
ce.io.length := rf.io.length
ce.io.inp_baddr := rf.io.inp_baddr
ce.io.out_baddr := rf.io.out_baddr
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Compute
*
* Add-by-one procedure:
*
* 1. Wait for launch to be asserted
* 2. Issue a read request for 8-byte value at inp_baddr address
* 3. Wait for the value
* 4. Issue a write request for 8-byte value at out_baddr address
* 5. Increment read-address and write-address for next value
* 6. Check if counter (cnt) is equal to length to assert finish,
* otherwise go to step 2.
*/
class Compute extends Module {
val io = IO(new Bundle {
val launch = Input(Bool())
val finish = Output(Bool())
val length = Input(UInt(32.W))
val inp_baddr = Input(UInt(64.W))
val out_baddr = Input(UInt(64.W))
val mem = new VTAMemDPIMaster
})
val sIdle :: sReadReq :: sReadData :: sWriteReq :: sWriteData :: Nil = Enum(5)
val state = RegInit(sIdle)
val reg = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(chiselTypeOf(io.length))
val raddr = Reg(chiselTypeOf(io.inp_baddr))
val waddr = Reg(chiselTypeOf(io.out_baddr))
switch (state) {
is (sIdle) {
when (io.launch) {
state := sReadReq
}
}
is (sReadReq) {
state := sReadData
}
is (sReadData) {
when (io.mem.rd.valid) {
state := sWriteReq
}
}
is (sWriteReq) {
state := sWriteData
}
is (sWriteData) {
when (cnt === (io.length - 1.U)) {
state := sIdle
} .otherwise {
state := sReadReq
}
}
}
// calculate next address
when (state === sIdle) {
raddr := io.inp_baddr
waddr := io.out_baddr
} .elsewhen (state === sWriteData) { // increment by 8-bytes
raddr := raddr + 8.U
waddr := waddr + 8.U
}
// create request
io.mem.req.valid := state === sReadReq | state === sWriteReq
io.mem.req.opcode := state === sWriteReq
io.mem.req.len := 0.U // one-word-per-request
io.mem.req.addr := Mux(state === sReadReq, raddr, waddr)
// read
when (state === sReadData && io.mem.rd.valid) {
reg := io.mem.rd.bits + 1.U
}
io.mem.rd.ready := state === sReadData
// write
io.mem.wr.valid := state === sWriteData
io.mem.wr.bits := reg
// count read/write
when (state === sIdle) {
cnt := 0.U
} .elsewhen (state === sWriteData) {
cnt := cnt + 1.U
}
// done when read/write are equal to length
io.finish := state === sWriteData && cnt === (io.length - 1.U)
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Register File.
*
* Six 32-bit register file.
*
* -------------------------------
* Register description | addr
* -------------------------|-----
* Control status register | 0x00
* Length value register | 0x04
* Input pointer lsb | 0x08
* Input pointer msb | 0x0c
* Output pointer lsb | 0x10
* Output pointer msb | 0x14
* -------------------------------
* ------------------------------
* Control status register | bit
* ------------------------------
* Launch | 0
* Finish | 1
* ------------------------------
*/
class RegFile extends Module {
val io = IO(new Bundle {
val launch = Output(Bool())
val finish = Input(Bool())
val length = Output(UInt(32.W))
val inp_baddr = Output(UInt(64.W))
val out_baddr = Output(UInt(64.W))
val host = new VTAHostDPIClient
})
val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.host.req.valid && !io.host.req.opcode) {
state := sRead
}
}
is (sRead) {
state := sIdle
}
}
io.host.req.deq := state === sIdle & io.host.req.valid
val reg = Seq.fill(6)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(6)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
(reg zip addr).foreach { case(r, a) =>
if (a == 0) { // control status register
when (io.finish) {
r := "b_10".U
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && a.U === io.host.req.addr) {
r := io.host.req.value
}
} else {
when (state === sIdle && io.host.req.valid &&
io.host.req.opcode && a.U === io.host.req.addr) {
r := io.host.req.value
}
}
}
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
}
io.host.resp.valid := state === sRead
io.host.resp.bits := rdata
io.launch := reg(0)(0)
io.length := reg(1)
io.inp_baddr := Cat(reg(3), reg(2))
io.out_baddr := Cat(reg(5), reg(4))
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package test
import chisel3._
import chisel3.experimental.{RawModule, withClockAndReset}
import vta.dpi._
import accel._
/** VTA simulation shell.
*
* Instantiate Host and Memory DPI modules.
*
*/
class VTASimShell extends RawModule {
val io = IO(new Bundle {
val clock = Input(Clock())
val reset = Input(Bool())
val host = new VTAHostDPIMaster
val mem = new VTAMemDPIClient
})
val host = Module(new VTAHostDPI)
val mem = Module(new VTAMemDPI)
mem.io.reset := io.reset
mem.io.clock := io.clock
host.io.reset := io.reset
host.io.clock := io.clock
io.mem <> mem.io.dpi
io.host <> host.io.dpi
}
/** Test accelerator.
*
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
class TestAccel extends RawModule {
val clock = IO(Input(Clock()))
val reset = IO(Input(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = withClockAndReset(clock, reset) { Module(new Accel) }
sim_shell.io.clock := clock
sim_shell.io.reset := reset
vta_accel.io.host <> sim_shell.io.host
sim_shell.io.mem <> vta_accel.io.mem
}
/** Generate TestAccel as top module */
object Elaborate extends App {
chisel3.Driver.execute(args, () => new TestAccel)
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Add-by-one accelerator.
*
* ___________ ___________
* | | | |
* | HostDPI | <--> | RegFile | <->|
* |_________| |_________| |
* |
* ___________ ___________ |
* | | | | |
* | MemDPI | <--> | Compute | <->|
* |_________| |_________|
*
*/
module Accel #
( parameter HOST_ADDR_BITS = 8,
parameter HOST_DATA_BITS = 32,
parameter MEM_LEN_BITS = 8,
parameter MEM_ADDR_BITS = 64,
parameter MEM_DATA_BITS = 64
)
(
input clock,
input reset,
input host_req_valid,
input host_req_opcode,
input [HOST_ADDR_BITS-1:0] host_req_addr,
input [HOST_DATA_BITS-1:0] host_req_value,
output host_req_deq,
output host_resp_valid,
output [HOST_DATA_BITS-1:0] host_resp_bits,
output mem_req_valid,
output mem_req_opcode,
output [MEM_LEN_BITS-1:0] mem_req_len,
output [MEM_ADDR_BITS-1:0] mem_req_addr,
output mem_wr_valid,
output [MEM_DATA_BITS-1:0] mem_wr_bits,
input mem_rd_valid,
input [MEM_DATA_BITS-1:0] mem_rd_bits,
output mem_rd_ready
);
logic launch;
logic finish;
logic [HOST_DATA_BITS-1:0] length;
logic [MEM_ADDR_BITS-1:0] inp_baddr;
logic [MEM_ADDR_BITS-1:0] out_baddr;
RegFile #
(
.MEM_ADDR_BITS(MEM_ADDR_BITS),
.HOST_ADDR_BITS(HOST_ADDR_BITS),
.HOST_DATA_BITS(HOST_DATA_BITS)
)
rf
(
.clock (clock),
.reset (reset),
.host_req_valid (host_req_valid),
.host_req_opcode (host_req_opcode),
.host_req_addr (host_req_addr),
.host_req_value (host_req_value),
.host_req_deq (host_req_deq),
.host_resp_valid (host_resp_valid),
.host_resp_bits (host_resp_bits),
.launch (launch),
.finish (finish),
.length (length),
.inp_baddr (inp_baddr),
.out_baddr (out_baddr)
);
Compute #
(
.MEM_LEN_BITS(MEM_LEN_BITS),
.MEM_ADDR_BITS(MEM_ADDR_BITS),
.MEM_DATA_BITS(MEM_DATA_BITS),
.HOST_DATA_BITS(HOST_DATA_BITS)
)
comp
(
.clock (clock),
.reset (reset),
.mem_req_valid (mem_req_valid),
.mem_req_opcode (mem_req_opcode),
.mem_req_len (mem_req_len),
.mem_req_addr (mem_req_addr),
.mem_wr_valid (mem_wr_valid),
.mem_wr_bits (mem_wr_bits),
.mem_rd_valid (mem_rd_valid),
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready),
.launch (launch),
.finish (finish),
.length (length),
.inp_baddr (inp_baddr),
.out_baddr (out_baddr)
);
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Compute
*
* Add-by-one procedure:
*
* 1. Wait for launch to be asserted
* 2. Issue a read request for 8-byte value at inp_baddr address
* 3. Wait for the value
* 4. Issue a write request for 8-byte value at out_baddr address
* 5. Increment read-address and write-address for next value
* 6. Check if counter (cnt) is equal to length to assert finish,
* otherwise go to step 2.
*/
module Compute #
(
parameter MEM_LEN_BITS = 8,
parameter MEM_ADDR_BITS = 64,
parameter MEM_DATA_BITS = 64,
parameter HOST_DATA_BITS = 32
)
(
input clock,
input reset,
output mem_req_valid,
output mem_req_opcode,
output [MEM_LEN_BITS-1:0] mem_req_len,
output [MEM_ADDR_BITS-1:0] mem_req_addr,
output mem_wr_valid,
output [MEM_DATA_BITS-1:0] mem_wr_bits,
input mem_rd_valid,
input [MEM_DATA_BITS-1:0] mem_rd_bits,
output mem_rd_ready,
input launch,
output finish,
input [HOST_DATA_BITS-1:0] length,
input [MEM_ADDR_BITS-1:0] inp_baddr,
input [MEM_ADDR_BITS-1:0] out_baddr
);
typedef enum logic [2:0] {IDLE,
READ_REQ,
READ_DATA,
WRITE_REQ,
WRITE_DATA} state_t;
state_t state_n, state_r;
logic [31:0] cnt;
logic [MEM_DATA_BITS-1:0] data;
logic [MEM_ADDR_BITS-1:0] raddr;
logic [MEM_ADDR_BITS-1:0] waddr;
always_ff @(posedge clock) begin
if (reset) begin
state_r <= IDLE;
end else begin
state_r <= state_n;
end
end
always_comb begin
state_n = IDLE;
case (state_r)
IDLE: begin
if (launch) begin
state_n = READ_REQ;
end
end
READ_REQ: begin
state_n = READ_DATA;
end
READ_DATA: begin
if (mem_rd_valid) begin
state_n = WRITE_REQ;
end else begin
state_n = READ_DATA;
end
end
WRITE_REQ: begin
state_n = WRITE_DATA;
end
WRITE_DATA: begin
if (cnt == (length - 1'b1)) begin
state_n = IDLE;
end else begin
state_n = READ_REQ;
end
end
default: begin
end
endcase
end
// calculate next address
always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin
raddr <= inp_baddr;
waddr <= out_baddr;
end else if (state_r == WRITE_DATA) begin
raddr <= raddr + 'd8;
waddr <= waddr + 'd8;
end
end
// create request
assign mem_req_valid = (state_r == READ_REQ) | (state_r == WRITE_REQ);
assign mem_req_opcode = state_r == WRITE_REQ;
assign mem_req_len = 'd0; // one-word-per-request
assign mem_req_addr = (state_r == READ_REQ)? raddr : waddr;
// read
always_ff @(posedge clock) begin
if ((state_r == READ_DATA) & mem_rd_valid) begin
data <= mem_rd_bits + 1'b1;
end
end
assign mem_rd_ready = state_r == READ_DATA;
// write
assign mem_wr_valid = state_r == WRITE_DATA;
assign mem_wr_bits = data;
// count read/write
always_ff @(posedge clock) begin
if (reset | state_r == IDLE) begin
cnt <= 'd0;
end else if (state_r == WRITE_DATA) begin
cnt <= cnt + 1'b1;
end
end
// done when read/write are equal to length
assign finish = (state_r == WRITE_DATA) & (cnt == (length - 1'b1));
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Register File.
*
* Six 32-bit register file.
*
* -------------------------------
* Register description | addr
* -------------------------|-----
* Control status register | 0x00
* Length value register | 0x04
* Input pointer lsb | 0x08
* Input pointer msb | 0x0c
* Output pointer lsb | 0x10
* Output pointer msb | 0x14
* -------------------------------
* ------------------------------
* Control status register | bit
* ------------------------------
* Launch | 0
* Finish | 1
* ------------------------------
*/
module RegFile #
(parameter MEM_ADDR_BITS = 64,
parameter HOST_ADDR_BITS = 8,
parameter HOST_DATA_BITS = 32
)
(
input clock,
input reset,
input host_req_valid,
input host_req_opcode,
input [HOST_ADDR_BITS-1:0] host_req_addr,
input [HOST_DATA_BITS-1:0] host_req_value,
output host_req_deq,
output host_resp_valid,
output [HOST_DATA_BITS-1:0] host_resp_bits,
output launch,
input finish,
output [HOST_DATA_BITS-1:0] length,
output [MEM_ADDR_BITS-1:0] inp_baddr,
output [MEM_ADDR_BITS-1:0] out_baddr
);
typedef enum logic {IDLE, READ} state_t;
state_t state_n, state_r;
always_ff @(posedge clock) begin
if (reset) begin
state_r <= IDLE;
end else begin
state_r <= state_n;
end
end
always_comb begin
state_n = IDLE;
case (state_r)
IDLE: begin
if (host_req_valid & ~host_req_opcode) begin
state_n = READ;
end
end
READ: begin
state_n = IDLE;
end
endcase
end
assign host_req_deq = (state_r == IDLE) ? host_req_valid : 1'b0;
logic [HOST_DATA_BITS-1:0] rf [5:0];
genvar i;
for (i = 0; i < 6; i++) begin
logic wen = (state_r == IDLE)? host_req_valid & host_req_opcode & i*4 == host_req_addr : 1'b0;
if (i == 0) begin
always_ff @(posedge clock) begin
if (reset) begin
end else if (finish) begin
rf[i] <= 'd2;
end else if (wen) begin
rf[i] <= host_req_value;
end
end
end else begin
always_ff @(posedge clock) begin
if (reset) begin
end else if (wen) begin
rf[i] <= host_req_value;
end
end
end
end
logic [HOST_DATA_BITS-1:0] rdata;
always_ff @(posedge clock) begin
if (reset) begin
rdata <= 'd0;
end else if ((state_r == IDLE) & host_req_valid & ~host_req_opcode) begin
if (host_req_addr == 'h00) begin
rdata <= rf[0];
end else if (host_req_addr == 'h04) begin
rdata <= rf[1];
end else if (host_req_addr == 'h08) begin
rdata <= rf[2];
end else if (host_req_addr == 'h0c) begin
rdata <= rf[3];
end else if (host_req_addr == 'h10) begin
rdata <= rf[4];
end else if (host_req_addr == 'h14) begin
rdata <= rf[5];
end else begin
rdata <= 'd0;
end
end
end
assign host_resp_valid = (state_r == READ);
assign host_resp_bits = rdata;
assign launch = rf[0][0];
assign length = rf[1];
assign inp_baddr = {rf[3], rf[2]};
assign out_baddr = {rf[5], rf[4]};
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/** Test accelerator.
*
* Instantiate host/memory DPI modules and connect them to the accelerator.
*
*/
module TestAccel
(
input clock,
input reset
);
localparam HOST_ADDR_BITS = 8;
localparam HOST_DATA_BITS = 32;
logic host_req_valid;
logic host_req_opcode;
logic [HOST_ADDR_BITS-1:0] host_req_addr;
logic [HOST_DATA_BITS-1:0] host_req_value;
logic host_req_deq;
logic host_resp_valid;
logic [HOST_DATA_BITS-1:0] host_resp_bits;
localparam MEM_LEN_BITS = 8;
localparam MEM_ADDR_BITS = 64;
localparam MEM_DATA_BITS = 64;
logic mem_req_valid;
logic mem_req_opcode;
logic [MEM_LEN_BITS-1:0] mem_req_len;
logic [MEM_ADDR_BITS-1:0] mem_req_addr;
logic mem_wr_valid;
logic [MEM_DATA_BITS-1:0] mem_wr_bits;
logic mem_rd_valid;
logic [MEM_DATA_BITS-1:0] mem_rd_bits;
logic mem_rd_ready;
VTAHostDPI host
(
.clock (clock),
.reset (reset),
.dpi_req_valid (host_req_valid),
.dpi_req_opcode (host_req_opcode),
.dpi_req_addr (host_req_addr),
.dpi_req_value (host_req_value),
.dpi_req_deq (host_req_deq),
.dpi_resp_valid (host_resp_valid),
.dpi_resp_bits (host_resp_bits)
);
VTAMemDPI mem
(
.clock (clock),
.reset (reset),
.dpi_req_valid (mem_req_valid),
.dpi_req_opcode (mem_req_opcode),
.dpi_req_len (mem_req_len),
.dpi_req_addr (mem_req_addr),
.dpi_wr_valid (mem_wr_valid),
.dpi_wr_bits (mem_wr_bits),
.dpi_rd_valid (mem_rd_valid),
.dpi_rd_bits (mem_rd_bits),
.dpi_rd_ready (mem_rd_ready)
);
Accel #
(
.HOST_ADDR_BITS(HOST_ADDR_BITS),
.HOST_DATA_BITS(HOST_DATA_BITS),
.MEM_LEN_BITS(MEM_LEN_BITS),
.MEM_ADDR_BITS(MEM_ADDR_BITS),
.MEM_DATA_BITS(MEM_DATA_BITS)
)
accel
(
.clock (clock),
.reset (reset),
.host_req_valid (host_req_valid),
.host_req_opcode (host_req_opcode),
.host_req_addr (host_req_addr),
.host_req_value (host_req_value),
.host_req_deq (host_req_deq),
.host_resp_valid (host_resp_valid),
.host_resp_bits (host_resp_bits),
.mem_req_valid (mem_req_valid),
.mem_req_opcode (mem_req_opcode),
.mem_req_len (mem_req_len),
.mem_req_addr (mem_req_addr),
.mem_wr_valid (mem_wr_valid),
.mem_wr_bits (mem_wr_bits),
.mem_rd_valid (mem_rd_valid),
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready)
);
endmodule
{
"TARGET" : "verilog",
"TOP_NAME" : "TestAccel",
"BUILD_NAME" : "build",
"USE_TRACE" : "off",
"TRACE_NAME" : "trace"
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os.path as osp
import sys
import json
import argparse
cur = osp.abspath(osp.dirname(__file__))
cfg = json.load(open(osp.join(cur, 'config.json')))
def main():
"""Main function"""
parser = argparse.ArgumentParser()
parser.add_argument("--get-target", action="store_true",
help="Get target language, i.e. verilog or chisel")
parser.add_argument("--get-top-name", action="store_true",
help="Get hardware design top name")
parser.add_argument("--get-build-name", action="store_true",
help="Get build folder name")
parser.add_argument("--get-use-trace", action="store_true",
help="Get use trace")
parser.add_argument("--get-trace-name", action="store_true",
help="Get trace filename")
args = parser.parse_args()
if len(sys.argv) == 1:
parser.print_help()
return
if args.get_target:
print(cfg['TARGET'])
if args.get_top_name:
print(cfg['TOP_NAME'])
if args.get_build_name:
print(cfg['BUILD_NAME'])
if args.get_use_trace:
print(cfg['USE_TRACE'])
if args.get_trace_name:
print(cfg['TRACE_NAME'])
if __name__ == "__main__":
main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import ctypes
import json
import os.path as osp
from sys import platform
def get_build_path():
curr_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
cfg = json.load(open(osp.join(curr_path, 'config.json')))
return osp.join(curr_path, "..", "..", cfg['BUILD_NAME'])
def get_lib_ext():
if platform == "darwin":
ext = ".dylib"
else:
ext = ".so"
return ext
def get_lib_path(name):
build_path = get_build_path()
ext = get_lib_ext()
libname = name + ext
return osp.join(build_path, libname)
def _load_driver_lib():
lib = get_lib_path("libdriver")
try:
return [ctypes.CDLL(lib, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_driver():
return tvm.get_global_func("tvm.vta.driver")
def load_tsim():
lib = get_lib_path("libtsim")
return tvm.module.load(lib, "vta-tsim")
LIBS = _load_driver_lib()
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
namespace vta {
namespace driver {
uint32_t get_half_addr(void *p, bool upper) {
if (upper) {
return ((uint64_t) ((uint64_t*) p)) >> 32;
} else {
return ((uint64_t) ((uint64_t*) p));
}
}
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class TestDriver {
public:
TestDriver(Module module)
: module_(module) {
dpi_ = static_cast<DPIModuleNode*>(
module.operator->());
}
int Run(uint32_t length, void* inp, void* out) {
uint32_t wait_cycles = 100000000;
this->Launch(wait_cycles, length, inp, out);
this->WaitForCompletion(wait_cycles);
dpi_->Finish();
return 0;
}
private:
void Launch(uint32_t wait_cycles, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles);
// write registers
dpi_->WriteReg(0x04, length);
dpi_->WriteReg(0x08, get_half_addr(inp, false));
dpi_->WriteReg(0x0c, get_half_addr(inp, true));
dpi_->WriteReg(0x10, get_half_addr(out, false));
dpi_->WriteReg(0x14, get_half_addr(out, true));
dpi_->WriteReg(0x00, 0x1); // launch
}
void WaitForCompletion(uint32_t wait_cycles) {
uint32_t i, val;
for (i = 0; i < wait_cycles; i++) {
val = dpi_->ReadReg(0x00);
if (val == 2) break; // finish
}
}
private:
DPIModuleNode* dpi_;
Module module_;
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module dev_mod = args[0];
DLTensor* A = args[1];
DLTensor* B = args[2];
TestDriver dev_(dev_mod);
dev_.Run(A->shape[0], A->data, B->data);
});
} // namespace driver
} // namespace vta
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
from tsim.load import load_driver, load_tsim
def test_tsim(i):
rmin = 1 # min vector size of 1
rmax = 64
n = np.random.randint(rmin, rmax)
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
tsim = load_tsim()
f = load_driver()
f(tsim, a, b)
emsg = "[FAIL] test number:{} n:{}".format(i, n)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1, err_msg=emsg)
print("[PASS] test number:{} n:{}".format(i, n))
if __name__ == "__main__":
for i in range(10):
test_tsim(i)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
clean:
-rm -rf target project/target project/project
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
name := "vta"
version := "0.1.0-SNAPSHOT"
organization := "edu.washington.cs"
def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// If we're building with Scala > 2.11, enable the compile option
// switch to support our anonymous Bundle definitions:
// https://github.com/scala/bug/issues/10047
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
case _ => Seq(
"-Xsource:2.11",
"-language:reflectiveCalls",
"-language:implicitConversions",
"-deprecation",
"-Xlint",
"-Ywarn-unused",
)
}
}
}
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// Scala 2.12 requires Java 8. We continue to generate
// Java 7 compatible code for Scala 2.11
// for compatibility with old clients.
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
Seq("-source", "1.7", "-target", "1.7")
case _ =>
Seq("-source", "1.8", "-target", "1.8")
}
}
}
scalaVersion := "2.11.12"
resolvers ++= Seq(
Resolver.sonatypeRepo("snapshots"),
Resolver.sonatypeRepo("releases"))
libraryDependencies ++= Seq(
"edu.berkeley.cs" %% "chisel3" % "3.1.7",
)
scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
javacOptions ++= javacOptionsVersion(scalaVersion.value)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
sbt.version = 1.1.1
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
logLevel := Level.Warn
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTAHostDPI #
( parameter ADDR_BITS = 8,
parameter DATA_BITS = 32
)
(
input clock,
input reset,
output logic dpi_req_valid,
output logic dpi_req_opcode,
output logic [ADDR_BITS-1:0] dpi_req_addr,
output logic [DATA_BITS-1:0] dpi_req_value,
input dpi_req_deq,
input dpi_resp_valid,
input [DATA_BITS-1:0] dpi_resp_bits
);
import "DPI-C" function void VTAHostDPI
(
output byte unsigned exit,
output byte unsigned req_valid,
output byte unsigned req_opcode,
output byte unsigned req_addr,
output int unsigned req_value,
input byte unsigned req_deq,
input byte unsigned resp_valid,
input int unsigned resp_value
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
typedef logic [31:0] dpi32_t;
dpi1_t __reset;
dpi8_t __exit;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_addr;
dpi32_t __req_value;
dpi8_t __req_deq;
dpi8_t __resp_valid;
dpi32_t __resp_bits;
// reset
always_ff @(posedge clock) begin
__reset <= reset;
end
// delaying outputs by one-cycle
// since verilator does not support delays
always_ff @(posedge clock) begin
dpi_req_valid <= dpi1_t ' (__req_valid);
dpi_req_opcode <= dpi1_t ' (__req_opcode);
dpi_req_addr <= __req_addr;
dpi_req_value <= __req_value;
end
assign __req_deq = dpi8_t ' (dpi_req_deq);
assign __resp_valid = dpi8_t ' (dpi_resp_valid);
assign __resp_bits = dpi_resp_bits;
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__exit = 0;
__req_valid = 0;
__req_opcode = 0;
__req_addr = 0;
__req_value = 0;
end
else begin
VTAHostDPI(
__exit,
__req_valid,
__req_opcode,
__req_addr,
__req_value,
__req_deq,
__resp_valid,
__resp_bits);
end
end
logic [63:0] cycles;
always_ff @(posedge clock) begin
if (reset | __reset) begin
cycles <= 'd0;
end
else begin
cycles <= cycles + 1'b1;
end
end
always_ff @(posedge clock) begin
if (__exit == 'd1) begin
$display("[DONE] at cycle:%016d", cycles);
$finish;
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTAMemDPI #
( parameter LEN_BITS = 8,
parameter ADDR_BITS = 64,
parameter DATA_BITS = 64
)
(
input clock,
input reset,
input dpi_req_valid,
input dpi_req_opcode,
input [LEN_BITS-1:0] dpi_req_len,
input [ADDR_BITS-1:0] dpi_req_addr,
input dpi_wr_valid,
input [DATA_BITS-1:0] dpi_wr_bits,
output logic dpi_rd_valid,
output logic [DATA_BITS-1:0] dpi_rd_bits,
input dpi_rd_ready
);
import "DPI-C" function void VTAMemDPI
(
input byte unsigned req_valid,
input byte unsigned req_opcode,
input byte unsigned req_len,
input longint unsigned req_addr,
input byte unsigned wr_valid,
input longint unsigned wr_value,
output byte unsigned rd_valid,
output longint unsigned rd_value,
input byte unsigned rd_ready
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
typedef logic [31:0] dpi32_t;
typedef logic [63:0] dpi64_t;
dpi1_t __reset;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_len;
dpi64_t __req_addr;
dpi8_t __wr_valid;
dpi64_t __wr_value;
dpi8_t __rd_valid;
dpi64_t __rd_value;
dpi8_t __rd_ready;
always_ff @(posedge clock) begin
__reset <= reset;
end
// delaying outputs by one-cycle
// since verilator does not support delays
always_ff @(posedge clock) begin
dpi_rd_valid <= dpi1_t ' (__rd_valid);
dpi_rd_bits <= __rd_value;
end
assign __req_valid = dpi8_t ' (dpi_req_valid);
assign __req_opcode = dpi8_t ' (dpi_req_opcode);
assign __req_len = dpi_req_len;
assign __req_addr = dpi_req_addr;
assign __wr_valid = dpi8_t ' (dpi_wr_valid);
assign __wr_value = dpi_wr_bits;
assign __rd_ready = dpi8_t ' (dpi_rd_ready);
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__rd_valid = 0;
__rd_value = 0;
end
else begin
VTAMemDPI(
__req_valid,
__req_opcode,
__req_len,
__req_addr,
__wr_valid,
__wr_value,
__rd_valid,
__rd_value,
__rd_ready);
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.dpi
import chisel3._
import chisel3.util._
/** Host DPI parameters */
trait VTAHostDPIParams {
val dpiAddrBits = 8
val dpiDataBits = 32
}
/** Host master interface.
*
* This interface is tipically used by the Host
*/
class VTAHostDPIMaster extends Bundle with VTAHostDPIParams {
val req = new Bundle {
val valid = Output(Bool())
val opcode = Output(Bool())
val addr = Output(UInt(dpiAddrBits.W))
val value = Output(UInt(dpiDataBits.W))
val deq = Input(Bool())
}
val resp = Flipped(ValidIO(UInt(dpiDataBits.W)))
}
/** Host client interface.
*
* This interface is tipically used by the Accelerator
*/
class VTAHostDPIClient extends Bundle with VTAHostDPIParams {
val req = new Bundle {
val valid = Input(Bool())
val opcode = Input(Bool())
val addr = Input(UInt(dpiAddrBits.W))
val value = Input(UInt(dpiDataBits.W))
val deq = Output(Bool())
}
val resp = ValidIO(UInt(dpiDataBits.W))
}
/** Host DPI module.
*
* Wrapper for Host Verilog DPI module.
*/
class VTAHostDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle {
val clock = Input(Clock())
val reset = Input(Bool())
val dpi = new VTAHostDPIMaster
})
setResource("/verilog/VTAHostDPI.v")
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.dpi
import chisel3._
import chisel3.util._
/** Memory DPI parameters */
trait VTAMemDPIParams {
val dpiLenBits = 8
val dpiAddrBits = 64
val dpiDataBits = 64
}
/** Memory master interface.
*
* This interface is tipically used by the Accelerator
*/
class VTAMemDPIMaster extends Bundle with VTAMemDPIParams {
val req = new Bundle {
val valid = Output(Bool())
val opcode = Output(Bool())
val len = Output(UInt(dpiLenBits.W))
val addr = Output(UInt(dpiAddrBits.W))
}
val wr = ValidIO(UInt(dpiDataBits.W))
val rd = Flipped(Decoupled(UInt(dpiDataBits.W)))
}
/** Memory client interface.
*
* This interface is tipically used by the Host
*/
class VTAMemDPIClient extends Bundle with VTAMemDPIParams {
val req = new Bundle {
val valid = Input(Bool())
val opcode = Input(Bool())
val len = Input(UInt(dpiLenBits.W))
val addr = Input(UInt(dpiAddrBits.W))
}
val wr = Flipped(ValidIO(UInt(dpiDataBits.W)))
val rd = Decoupled(UInt(dpiDataBits.W))
}
/** Memory DPI module.
*
* Wrapper for Memory Verilog DPI module.
*/
class VTAMemDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle {
val clock = Input(Clock())
val reset = Input(Bool())
val dpi = new VTAMemDPIClient
})
setResource("/verilog/VTAMemDPI.v")
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <vta/dpi/tsim.h>
#if VM_TRACE
#include <verilated_vcd_c.h>
#endif
#if VM_TRACE
#define STRINGIZE(x) #x
#define STRINGIZE_VALUE_OF(x) STRINGIZE(x)
#endif
static VTAContextHandle _ctx = nullptr;
static VTAMemDPIFunc _mem_dpi = nullptr;
static VTAHostDPIFunc _host_dpi = nullptr;
void VTAHostDPI(dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
dpi8_t req_deq,
dpi8_t resp_valid,
dpi32_t resp_value) {
assert(_host_dpi != nullptr);
(*_host_dpi)(_ctx, exit, req_valid, req_opcode,
req_addr, req_value, req_deq,
resp_valid, resp_value);
}
void VTAMemDPI(dpi8_t req_valid,
dpi8_t req_opcode,
dpi8_t req_len,
dpi64_t req_addr,
dpi8_t wr_valid,
dpi64_t wr_value,
dpi8_t* rd_valid,
dpi64_t* rd_value,
dpi8_t rd_ready) {
assert(_mem_dpi != nullptr);
(*_mem_dpi)(_ctx, req_valid, req_opcode, req_len,
req_addr, wr_valid, wr_value,
rd_valid, rd_value, rd_ready);
}
void VTADPIInit(VTAContextHandle handle,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi) {
_ctx = handle;
_host_dpi = host_dpi;
_mem_dpi = mem_dpi;
}
int VTADPISim(uint64_t max_cycles) {
uint64_t trace_count = 0;
#if VM_TRACE
uint64_t start = 0;
#endif
VL_TSIM_NAME* top = new VL_TSIM_NAME;
#if VM_TRACE
Verilated::traceEverOn(true);
VerilatedVcdC* tfp = new VerilatedVcdC;
top->trace(tfp, 99);
tfp->open(STRINGIZE_VALUE_OF(TSIM_TRACE_FILE));
#endif
// reset
for (int i = 0; i < 10; i++) {
top->reset = 1;
top->clock = 0;
top->eval();
#if VM_TRACE
if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2));
#endif
top->clock = 1;
top->eval();
#if VM_TRACE
if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
#endif
trace_count++;
}
top->reset = 0;
// start simulation
while (!Verilated::gotFinish() && trace_count < max_cycles) {
top->clock = 0;
top->eval();
#if VM_TRACE
if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2));
#endif
top->clock = 1;
top->eval();
#if VM_TRACE
if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
#endif
trace_count++;
}
#if VM_TRACE
tfp->close();
#endif
delete top;
return 0;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef VTA_DPI_MODULE_H_
#define VTA_DPI_MODULE_H_
#include <tvm/runtime/module.h>
#include <mutex>
#include <queue>
#include <condition_variable>
#include <string>
namespace vta {
namespace dpi {
/*!
* \brief DPI driver module for managing the accelerator
*/
class DPIModuleNode : public tvm::runtime::ModuleNode {
public:
/*!
* \brief Launch accelerator until it finishes or reach max_cycles
* \param max_cycles The maximum of cycles to wait
*/
virtual void Launch(uint64_t max_cycles) = 0;
/*!
* \brief Write an accelerator register
* \param addr The register address
* \param value The register value
*/
virtual void WriteReg(int addr, uint32_t value) = 0;
/*!
* \brief Read an accelerator register
* \param addr The register address
*/
virtual uint32_t ReadReg(int addr) = 0;
/*! \brief Kill or Exit() the accelerator */
virtual void Finish() = 0;
static tvm::runtime::Module Load(std::string dll_name);
};
} // namespace dpi
} // namespace vta
#endif // VTA_DPI_MODULE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef VTA_DPI_TSIM_H_
#define VTA_DPI_TSIM_H_
#include <tvm/runtime/c_runtime_api.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef unsigned char dpi8_t;
typedef unsigned int dpi32_t;
typedef unsigned long long dpi64_t; // NOLINT(*)
/*! \brief the context handle */
typedef void* VTAContextHandle;
/*!
* \brief Host DPI callback function that is invoked in VTAHostDPI.v every clock cycle
* \param exit Host kill simulation
* \param req_valid Host has a valid request for read or write a register in Accel
* \param req_opcode Host request type, opcode=0 for read and opcode=1 for write
* \param req_addr Host request register address
* \param req_value Host request value to be written to a register
* \param req_deq Accel is ready to dequeue Host request
* \param resp_valid Accel has a valid response for Host
* \param resp_value Accel response value for Host
* \return 0 if success,
*/
typedef void (*VTAHostDPIFunc)(
VTAContextHandle self,
dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
dpi8_t req_deq,
dpi8_t resp_valid,
dpi32_t resp_value);
/*!
* \brief Memory DPI callback function that is invoked in VTAMemDPI.v every clock cycle
* \param req_valid Accel has a valid request for Host
* \param req_opcode Accel request type, opcode=0 (read) and opcode=1 (write)
* \param req_len Accel request length of size 8-byte and starts at 0
* \param req_addr Accel request base address
* \param wr_valid Accel has a valid value for Host
* \param wr_value Accel has a value to be written Host
* \param rd_valid Host has a valid value for Accel
* \param rd_value Host has a value to be read by Accel
*/
typedef void (*VTAMemDPIFunc)(
VTAContextHandle self,
dpi8_t req_valid,
dpi8_t req_opcode,
dpi8_t req_len,
dpi64_t req_addr,
dpi8_t wr_valid,
dpi64_t wr_value,
dpi8_t* rd_valid,
dpi64_t* rd_value,
dpi8_t rd_ready);
/*! \brief The type of VTADPIInit function pointer */
typedef void (*VTADPIInitFunc)(VTAContextHandle handle,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi);
/*! \brief The type of VTADPISim function pointer */
typedef int (*VTADPISimFunc)(uint64_t max_cycles);
/*!
* \brief Set Host and Memory DPI functions
* \param handle DPI Context handle
* \param host_dpi Host DPI function
* \param mem_dpi Memory DPI function
*/
TVM_DLL void VTADPIInit(VTAContextHandle handle,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi);
/*!
* \brief Instantiate VTA design and generate clock/reset
* \param max_cycles The maximum number of simulation cycles
*/
TVM_DLL int VTADPISim(uint64_t max_cycles);
#ifdef __cplusplus
}
#endif
#endif // VTA_DPI_TSIM_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
#include <vta/dpi/tsim.h>
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <mutex>
#include <queue>
#include <thread>
#include <condition_variable>
namespace vta {
namespace dpi {
using namespace tvm::runtime;
typedef void* DeviceHandle;
struct HostRequest {
uint8_t opcode;
uint8_t addr;
uint32_t value;
};
struct HostResponse {
uint32_t value;
};
struct MemResponse {
uint8_t valid;
uint64_t value;
};
template <typename T>
class ThreadSafeQueue {
public:
void Push(const T item) {
std::lock_guard<std::mutex> lock(mutex_);
queue_.push(std::move(item));
cond_.notify_one();
}
void WaitPop(T* item) {
std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this]{return !queue_.empty();});
*item = std::move(queue_.front());
queue_.pop();
}
bool TryPop(T* item, bool pop) {
std::lock_guard<std::mutex> lock(mutex_);
if (queue_.empty()) return false;
*item = std::move(queue_.front());
if (pop) queue_.pop();
return true;
}
private:
mutable std::mutex mutex_;
std::queue<T> queue_;
std::condition_variable cond_;
};
class HostDevice {
public:
void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value);
bool TryPopRequest(HostRequest* r, bool pop);
void PushResponse(uint32_t value);
void WaitPopResponse(HostResponse* r);
void Exit();
uint8_t GetExitStatus();
private:
uint8_t exit_{0};
mutable std::mutex mutex_;
ThreadSafeQueue<HostRequest> req_;
ThreadSafeQueue<HostResponse> resp_;
};
class MemDevice {
public:
void SetRequest(uint8_t opcode, uint64_t addr, uint32_t len);
MemResponse ReadData(uint8_t ready);
void WriteData(uint64_t value);
private:
uint64_t* raddr_{0};
uint64_t* waddr_{0};
uint32_t rlen_{0};
uint32_t wlen_{0};
std::mutex mutex_;
};
void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) {
HostRequest r;
r.opcode = opcode;
r.addr = addr;
r.value = value;
req_.Push(r);
}
bool HostDevice::TryPopRequest(HostRequest* r, bool pop) {
r->opcode = 0xad;
r->addr = 0xad;
r->value = 0xbad;
return req_.TryPop(r, pop);
}
void HostDevice::PushResponse(uint32_t value) {
HostResponse r;
r.value = value;
resp_.Push(r);
}
void HostDevice::WaitPopResponse(HostResponse* r) {
resp_.WaitPop(r);
}
void HostDevice::Exit() {
std::unique_lock<std::mutex> lock(mutex_);
exit_ = 1;
}
uint8_t HostDevice::GetExitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return exit_;
}
void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) {
std::lock_guard<std::mutex> lock(mutex_);
if (opcode == 1) {
wlen_ = len + 1;
waddr_ = reinterpret_cast<uint64_t*>(addr);
} else {
rlen_ = len + 1;
raddr_ = reinterpret_cast<uint64_t*>(addr);
}
}
MemResponse MemDevice::ReadData(uint8_t ready) {
std::lock_guard<std::mutex> lock(mutex_);
MemResponse r;
r.valid = rlen_ > 0;
r.value = rlen_ > 0 ? *raddr_ : 0xdeadbeefdeadbeef;
if (ready == 1 && rlen_ > 0) {
raddr_++;
rlen_ -= 1;
}
return r;
}
void MemDevice::WriteData(uint64_t value) {
std::lock_guard<std::mutex> lock(mutex_);
if (wlen_ > 0) {
*waddr_ = value;
waddr_++;
wlen_ -= 1;
}
}
class DPIModule final : public DPIModuleNode {
public:
~DPIModule() {
if (lib_handle_) Unload();
}
const char* type_key() const final {
return "vta-tsim";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (name == "WriteReg") {
return TypedPackedFunc<void(int, int)>(
[this](int addr, int value){
this->WriteReg(addr, value);
});
} else {
LOG(FATAL) << "Member " << name << "does not exists";
return nullptr;
}
}
void Init(const std::string& name) {
Load(name);
VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>(
GetSymbol("VTADPIInit"));
CHECK(finit != nullptr);
finit(this, VTAHostDPI, VTAMemDPI);
fvsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
CHECK(fvsim_ != nullptr);
}
void Launch(uint64_t max_cycles) {
auto frun = [this, max_cycles]() {
(*fvsim_)(max_cycles);
};
vsim_thread_ = std::thread(frun);
}
void WriteReg(int addr, uint32_t value) {
host_device_.PushRequest(1, addr, value);
}
uint32_t ReadReg(int addr) {
uint32_t value;
HostResponse* r = new HostResponse;
host_device_.PushRequest(0, addr, 0);
host_device_.WaitPopResponse(r);
value = r->value;
delete r;
return value;
}
void Finish() {
host_device_.Exit();
vsim_thread_.join();
}
protected:
VTADPISimFunc fvsim_;
HostDevice host_device_;
MemDevice mem_device_;
std::thread vsim_thread_;
void HostDPI(dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
dpi8_t req_deq,
dpi8_t resp_valid,
dpi32_t resp_value) {
HostRequest* r = new HostRequest;
*exit = host_device_.GetExitStatus();
*req_valid = host_device_.TryPopRequest(r, req_deq);
*req_opcode = r->opcode;
*req_addr = r->addr;
*req_value = r->value;
if (resp_valid) {
host_device_.PushResponse(resp_value);
}
delete r;
}
void MemDPI(
dpi8_t req_valid,
dpi8_t req_opcode,
dpi8_t req_len,
dpi64_t req_addr,
dpi8_t wr_valid,
dpi64_t wr_value,
dpi8_t* rd_valid,
dpi64_t* rd_value,
dpi8_t rd_ready) {
MemResponse r = mem_device_.ReadData(rd_ready);
*rd_valid = r.valid;
*rd_value = r.value;
if (wr_valid) {
mem_device_.WriteData(wr_value);
}
if (req_valid) {
mem_device_.SetRequest(req_opcode, req_addr, req_len);
}
}
static void VTAHostDPI(
VTAContextHandle self,
dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
dpi8_t req_deq,
dpi8_t resp_valid,
dpi32_t resp_value) {
static_cast<DPIModule*>(self)->HostDPI(
exit, req_valid, req_opcode, req_addr,
req_value, req_deq, resp_valid, resp_value);
}
static void VTAMemDPI(
VTAContextHandle self,
dpi8_t req_valid,
dpi8_t req_opcode,
dpi8_t req_len,
dpi64_t req_addr,
dpi8_t wr_valid,
dpi64_t wr_value,
dpi8_t* rd_valid,
dpi64_t* rd_value,
dpi8_t rd_ready) {
static_cast<DPIModule*>(self)->MemDPI(
req_valid, req_opcode, req_len,
req_addr, wr_valid, wr_value,
rd_valid, rd_value, rd_ready);
}
private:
// Platform dependent handling.
#if defined(_WIN32)
// library handle
HMODULE lib_handle_{nullptr};
// Load the library
void Load(const std::string& name) {
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
}
void* GetSymbol(const char* name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
}
#else
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
}
#endif
};
Module DPIModuleNode::Load(std::string dll_name) {
std::shared_ptr<DPIModule> n =
std::make_shared<DPIModule>();
n->Init(dll_name);
return Module(n);
}
TVM_REGISTER_GLOBAL("module.loadfile_vta-tsim")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = DPIModuleNode::Load(args[0]);
});
} // namespace dpi
} // namespace vta
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment