Commit 47e50e1e by Benjamin Tu Committed by Thierry Moreau

[VTA][TSIM] Serial GEMM Application Added (#4082)

* app init push

* fix on readme

* change name, add bit serial explanantion

* rm serialLoadMM, change doc

* syntax change for readme

* add parallel test functionality

* fix readme

* add python doc

* syntax
parent aad48ff1
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
cmake_minimum_required(VERSION 3.2)
project(tsim C CXX)
set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
set(VTA_DIR ${TVM_DIR}/vta)
include_directories("${TVM_DIR}/include")
include_directories("${TVM_DIR}/3rdparty/dlpack/include")
include_directories("${TVM_DIR}/3rdparty/dmlc-core/include")
include_directories("${TVM_DIR}/vta/src/dpi")
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -fvisibility=hidden -std=c++11")
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "-faligned-new ${CMAKE_CXX_FLAGS}")
endif()
file(GLOB TSIM_SW_SRC src/driver.cc)
list(APPEND TSIM_SW_SRC ${VTA_DIR}/src/vmem/virtual_memory.cc)
list(APPEND TSIM_SW_SRC ${VTA_DIR}/src/dpi/module.cc)
add_library(sw SHARED ${TSIM_SW_SRC})
target_include_directories(sw PRIVATE ${VTA_DIR}/include ${VTA_DIR}/src)
if(APPLE)
set_target_properties(sw PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif(APPLE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
export PYTHONPATH:=$(PWD)/python:$(PYTHONPATH)
BUILD_NAME = build
build_dir = $(abspath .)/$(BUILD_NAME)
default: chisel driver
python3 tests/python/chisel_accel.py serial
serial:
python3 tests/python/chisel_accel.py serial
parallel:
python3 tests/python/chisel_accel.py parallel
driver: | $(build_dir)
cd $(build_dir) && cmake .. && make
$(build_dir):
mkdir -p $@
chisel:
make -C hardware/chisel
clean:
-rm -rf $(build_dir)
make -C hardware/chisel clean
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
VTA TSIM Application
======================
Prior to this application, please take a look at `<tvm-root>/vta/apps/tsim_example` for installation
This is an application that performs Bit Serial Multiplication for GEMM utilizing TSIM.
**Bit Serial Multiplication for GEMM:**
General Matrix Multiplications (GEMM), are mostly calculated by repeatly calculating the dot product for each pair of vectors.
The dot product is calculated by summing every product of the vector pair.
We approach this operation with slicing and shifting, like how basic multiplication works, each vector elements before we accumulate them.
We can sufficiently reduce the cycles required to perform a gemm given that the data bit width is small. This GEMM application uses TSIM for future accerlerator prototypes.
* Test Chisel3 backend with bit serial GEMM
* Go to `<tvm-root>/vta/apps/gemm`
* Run `make`
* If you have already compiled chisel backend (i.e. ran `make`)
* Bit Serial test with another input set, run `make serial`
* Bit parallel test with another input set, run `make parallel`
* Some steps for creating your own custom TSIM application
* Go to `<tvm-root>/vta/apps/gemm`
* Create custom circuit within `./hardware/chisel/src/scala.main/accel/Compute.scala`
* Map the according Registers in `./hardware/chisel/src/scala.main/accel/RegFile.scala`
* Create your test script
* Map the registers in `./src/driver.cc` and link it with both `RegFile.scala` and the test script
* Understanding of `<tvm-root>/vta/apps/tsim_example`, which performs add by one to a vector, is highly encouraged to create a more complex application
* Some pointers
* Chisel3 tests in `<tvm-root>/vta/apps/gemm/tests/python`
* Chisel3 accelerator backend `<tvm-root>/vta/apps/gemm/hardware/chisel`
* Software C++ driver (backend) that handles the accelerator `<tvm-root>/vta/apps/gemm/src/driver.cc`
* Software Python driver (frontend) that handles the accelerator `<tvm-root>/vta/apps/gemm/python/accel`
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
ifeq (, $(shell which verilator))
$(error "No Verilator in $(PATH), consider doing apt-get install verilator")
endif
# Change VERILATOR_INC_DIR if Verilator is installed on a different location
ifeq (, $(VERILATOR_INC_DIR))
ifeq (, $(wildcard /usr/local/share/verilator/include/*))
ifeq (, $(wildcard /usr/share/verilator/include/*))
$(error "Verilator include directory is not set properly")
else
VERILATOR_INC_DIR := /usr/share/verilator/include
endif
else
VERILATOR_INC_DIR := /usr/local/share/verilator/include
endif
endif
TOP = TestAccel
BUILD_NAME = build
USE_TRACE = 1
LIBNAME = libhw
vta_dir = $(abspath ../../../../)
tvm_dir = $(abspath ../../../../../)
build_dir = $(abspath .)/$(BUILD_NAME)
verilator_build_dir = $(build_dir)/verilator
chisel_build_dir = $(build_dir)/chisel
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(VERILATOR_INC_DIR)
cxx_flags += -I$(VERILATOR_INC_DIR)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(VERILATOR_INC_DIR)/verilated.cpp
cxx_files += $(VERILATOR_INC_DIR)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP).vcd
cxx_files += $(VERILATOR_INC_DIR)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
# The following is to be consistent with cmake
ifeq ($(shell uname), Darwin)
lib_path = $(build_dir)/$(LIBNAME).dylib
else
lib_path = $(build_dir)/$(LIBNAME).so
endif
default: lib
lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(verilator_build_dir)/V$(TOP).cpp
$(verilator_build_dir)/V$(TOP).cpp: $(chisel_build_dir)/$(TOP).v
verilator $(verilator_opt) $<
verilog: $(chisel_build_dir)/$(TOP).v
$(chisel_build_dir)/$(TOP).v: install_vta_package
sbt 'test:runMain test.Elaborate --target-dir $(chisel_build_dir) --top-name $(TOP)'
install_vta_package:
cd $(vta_dir)/hardware/chisel && sbt publishLocal
clean:
-rm -rf $(build_dir) target project/target project/project
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
name := "accel"
version := "0.1.0-SNAPSHOT"
organization := "edu.washington.cs"
def scalacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// If we're building with Scala > 2.11, enable the compile option
// switch to support our anonymous Bundle definitions:
// https://github.com/scala/bug/issues/10047
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq()
case _ => Seq(
"-Xsource:2.11",
"-language:reflectiveCalls",
"-language:implicitConversions",
"-deprecation",
"-Xlint",
"-Ywarn-unused",
)
}
}
}
def javacOptionsVersion(scalaVersion: String): Seq[String] = {
Seq() ++ {
// Scala 2.12 requires Java 8. We continue to generate
// Java 7 compatible code for Scala 2.11
// for compatibility with old clients.
CrossVersion.partialVersion(scalaVersion) match {
case Some((2, scalaMajor: Long)) if scalaMajor < 12 =>
Seq("-source", "1.7", "-target", "1.7")
case _ =>
Seq("-source", "1.8", "-target", "1.8")
}
}
}
scalaVersion := "2.11.12"
resolvers ++= Seq(
Resolver.sonatypeRepo("snapshots"),
Resolver.sonatypeRepo("releases"))
libraryDependencies ++= Seq(
"edu.berkeley.cs" %% "chisel3" % "3.1.7",
"edu.washington.cs" %% "vta" % "0.1.0-SNAPSHOT",
)
scalacOptions ++= scalacOptionsVersion(scalaVersion.value)
javacOptions ++= javacOptionsVersion(scalaVersion.value)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
sbt.version = 1.1.1
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
logLevel := Level.Warn
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import vta.dpi._
/** Add-by-one accelerator.
*
* ___________ ___________
* | | | |
* | HostDPI | <--> | RegFile | <->|
* |_________| |_________| |
* |
* ___________ ___________ |
* | | | | |
* | MemDPI | <--> | Compute | <->|
* |_________| |_________|
*
*/
case class AccelConfig() {
val nCtrl = 1
val nECnt = 1
val nVals = 4
val nPtrs = 3
val regBits = 32
val ptrBits = 2*regBits
}
class Accel extends Module {
val io = IO(new Bundle {
val host = new VTAHostDPIClient
val mem = new VTAMemDPIMaster
})
implicit val config = AccelConfig()
val rf = Module(new RegFile)
val ce = Module(new Compute)
rf.io.host <> io.host
io.mem <> ce.io.mem
ce.io.launch := rf.io.launch
rf.io.finish := ce.io.finish
rf.io.ecnt <> ce.io.ecnt
ce.io.vals <> rf.io.vals
ce.io.ptrs <> rf.io.ptrs
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Compute
*
* Bit Slice GEMM:
*
* 1. Wait for launch to be asserted
* 2. Issue 2 read request for 8-byte value at inp1_baddr address and inp2_baddr address
* 3. Wait for the value
* 4. Increment read-address for next value
* 5. Wait for sliced accumulator
* 6. Check if counter (cnt) is equal to length process,
otherwise goto step 2
* 7. Check if reset slice accumulator
* 8. Wait for overall accumulator
* 8. Issue a write request for 8-byte value at out_baddr address
*/
class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
val launch = Input(Bool())
val finish = Output(Bool())
val ecnt = Vec(config.nECnt, ValidIO(UInt(config.regBits.W)))
val vals = Input(Vec(config.nVals, UInt(config.regBits.W)))
val ptrs = Input(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val mem = new VTAMemDPIMaster
})
val sIdle :: sReadAReq :: sReadAData :: sReadBReq :: sReadBData :: sWriteReq :: sWriteData :: Nil = Enum(7)
val state = RegInit(sIdle)
val shift = io.vals(0)
val length = io.vals(1)
val rstAccum = io.vals(2)
val startDot = io.vals(3)
val cycles = RegInit(0.U(config.regBits.W))
val reg1 = Reg(chiselTypeOf(io.mem.rd.bits))
val reg2 = Reg(chiselTypeOf(io.mem.rd.bits))
val cnt = Reg(UInt(config.regBits.W))
val raddr1 = Reg(UInt(config.ptrBits.W))
val raddr2 = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
switch (state) {
is (sIdle) {
when (io.launch) {
state := sReadAReq
}
}
// Read
is (sReadAReq) {
state := sReadAData
}
is (sReadAData) {
when (io.mem.rd.valid) {
state := sReadBReq
}
}
is (sReadBReq) {
state := sReadBData
}
is (sReadBData) {
when (io.mem.rd.valid) {
state := sWriteReq
}
}
// Write
is (sWriteReq) {
state := sWriteData
}
is (sWriteData) {
when (cnt === (length - 1.U)) {
state := sIdle
} .otherwise {
state := sReadAReq
}
}
}
val last = state === sWriteData && cnt === (length - 1.U)
// cycle counter
when (state === sIdle) {
cycles := 0.U
} .otherwise {
cycles := cycles + 1.U
}
io.ecnt(0).valid := last
io.ecnt(0).bits := cycles
// calculate next address
when (state === sIdle) {
raddr1 := io.ptrs(0)
raddr2 := io.ptrs(1)
waddr := io.ptrs(2)
} .elsewhen (state === sWriteData) { // increment input array by 1-byte
raddr1 := raddr1 + 1.U
raddr2 := raddr2 + 1.U
waddr := waddr
}
// create request
io.mem.req.valid := state === sReadAReq | state === sReadBReq | state === sWriteReq
io.mem.req.opcode := state === sWriteReq
io.mem.req.len := 0.U // one-word-per-request
io.mem.req.addr := Mux(state === sReadAReq | state === sReadBReq, Mux(state === sReadAReq, raddr1, raddr2), waddr)
// read
when (state === sReadAData && io.mem.rd.valid) {
reg1 := io.mem.rd.bits(7, 0)
}
when (state === sReadBData && io.mem.rd.valid) {
reg2 := io.mem.rd.bits(7, 0)
}
io.mem.rd.ready := state === sReadAData | state === sReadBData
val sliceAccum = Module(new Accumulator(63))
val overallAccum = Module(new Accumulator(64))
sliceAccum.io.valid := state === sWriteReq // 2 inputs have been processed
sliceAccum.io.in := reg1 * reg2
sliceAccum.io.clear := startDot
overallAccum.io.clear := rstAccum
overallAccum.io.valid := last // last element has been processed
overallAccum.io.in := sliceAccum.io.sum << shift(7,0) // limit to 8 bits
// write
io.mem.wr.valid := overallAccum.io.ready
io.mem.wr.bits := overallAccum.io.sum
// count read/write
when (state === sIdle) {
cnt := 0.U
} .elsewhen (state === sWriteData) {
cnt := cnt + 1.U
}
io.finish := overallAccum.io.ready // data has been added
}
class Accumulator(dataBits: Int = 8) extends Module {
val io = IO(new Bundle {
val clear = Input(Bool())
val valid = Input(Bool())
val ready = Output(Bool())
val in = Input(UInt(dataBits.W))
val sum = Output(UInt((dataBits).W))
})
val reg = RegInit(0.U((dataBits).W))
val ready = RegNext(io.valid)
when (io.clear) {
reg := 0.U
} .elsewhen (io.valid) {
reg := reg + io.in
}
io.ready := ready
io.sum := reg
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package accel
import chisel3._
import chisel3.util._
import vta.dpi._
/** Register File.
*
* Six 32-bit register file.
*
* -------------------------------
* Register description | addr
* -------------------------|-----
* Control status register | 0x00
* Cycle counter | 0x04
* Shift value | 0x08
* Vector length | 0x0c
* Reset Accumulator | 0x10
* Reset Dot Module | 0x14
* Input1 pointer lsb | 0x18
* Input1 pointer msb | 0x1c
* Input2 pointer lsb | 0x20
* Input2 pointer msb | 0x24
* Output pointer lsb | 0x28
* Output pointer msb | 0x2c
* -------------------------------
* ------------------------------
* Control status register | bit
* ------------------------------
* Launch | 0
* Finish | 1
* ------------------------------
*/
class RegFile(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle {
val launch = Output(Bool())
val finish = Input(Bool())
val ecnt = Vec(config.nECnt, Flipped(ValidIO(UInt(config.regBits.W))))
val vals = Output(Vec(config.nVals, UInt(config.regBits.W)))
val ptrs = Output(Vec(config.nPtrs, UInt(config.ptrBits.W)))
val host = new VTAHostDPIClient
})
val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.host.req.valid && !io.host.req.opcode) {
state := sRead
}
}
is (sRead) {
state := sIdle
}
}
io.host.req.deq := state === sIdle & io.host.req.valid
val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl
val vo = eo + config.nECnt
val po = vo + config.nVals
when (io.finish) {
reg(0) := "b_10".U
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(0).U === io.host.req.addr) {
reg(0) := io.host.req.value
}
for (i <- 0 until config.nECnt) {
when (io.ecnt(i).valid) {
reg(eo + i) := io.ecnt(i).bits
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
reg(eo + i) := io.host.req.value
}
}
for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
when (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value
}
}
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
}
io.host.resp.valid := state === sRead
io.host.resp.bits := rdata
io.launch := reg(0)(0)
for (i <- 0 until config.nVals) {
io.vals(i) := reg(vo + i)
}
for (i <- 0 until config.nPtrs) {
io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package test
import chisel3._
import chisel3.experimental.MultiIOModule
import vta.dpi._
import accel._
/** VTA simulation shell.
*
* Instantiate Host and Memory DPI modules.
*
*/
class VTASimShell extends MultiIOModule {
val host = IO(new VTAHostDPIMaster)
val mem = IO(new VTAMemDPIClient)
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val mod_sim = Module(new VTASimDPI)
val mod_host = Module(new VTAHostDPI)
val mod_mem = Module(new VTAMemDPI)
mod_mem.io.clock := clock
mod_mem.io.reset := reset
mod_mem.io.dpi <> mem
mod_host.io.clock := clock
mod_host.io.reset := reset
host <> mod_host.io.dpi
mod_sim.io.clock := sim_clock
mod_sim.io.reset := reset
sim_wait := mod_sim.io.dpi_wait
}
/** Test accelerator.
*
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
class TestAccel extends MultiIOModule {
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel)
sim_shell.sim_clock := sim_clock
sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_accel.io.mem
vta_accel.io.host <> sim_shell.host
}
/** Generate TestAccel as top module */
object Elaborate extends App {
chisel3.Driver.execute(args, () => new TestAccel)
}
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import ctypes
import os.path as osp
from sys import platform
def get_ext():
"""Return shared library extension"""
return ".dylib" if platform == "darwin" else ".so"
def load_dll(dll):
"""Load shared library
Parameters
------------
dll : str
Path for shared library
Returns
------------
The shared library
"""
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_sw():
"""Load all software shared libraries"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)
def init(hw_backend):
"""Init hardware and software shared library for accelerator
Parameters
------------
hw_backend : str
Hardware backend can be verilog or chisel
"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
load_sw()
m = tvm.module.load(hw_lib, "vta-tsim")
f = tvm.get_global_func("tvm.vta.tsim.init")
f(m)
def load_module():
"""Return driver function"""
load_sw()
return tvm.get_global_func("tvm.vta.driver")
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
#include "vmem/virtual_memory.h"
namespace vta {
namespace driver {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}
DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}
static DPILoader* Global() {
static DPILoader inst;
return &inst;
}
// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
class Device {
public:
Device() {
loader_ = DPILoader::Global();
}
uint32_t Run(DLTensor* inp1, DLTensor* inp2, uint32_t shiftVal, DLTensor* out, uint32_t reset) {
uint32_t cycles;
uint32_t length = inp1->shape[0];
size_t size1 = (inp1->dtype.bits >> 3) * length;
size_t size2 = (inp2->dtype.bits >> 3) * length;
size_t size3 = (64 >> 3);
inp1_ = this->MemAlloc(size1);
inp2_ = this->MemAlloc(size2);
out_ = this->MemAlloc(size3);
this->MemCopyFromHost(inp1_, inp1->data, size1);
this->MemCopyFromHost(inp2_, inp2->data, size2);
this->Init();
this->Launch(length, shiftVal, reset);
cycles = this->WaitForCompletion();
this->MemCopyToHost(out->data, out_, size3);
this->MemFree(inp1_);
this->MemFree(inp2_);
this->MemFree(out_);
return cycles;
}
private:
void Init() {
dpi_ = loader_->Get();
dpi_->SimResume();
}
void* MemAlloc(size_t size) {
void * addr = vta::vmem::VirtualMemoryManager::Global()->Alloc(size);
return reinterpret_cast<void*>(vta::vmem::VirtualMemoryManager::Global()->GetPhyAddr(addr));
}
void MemFree(void* buf) {
void * addr = vta::vmem::VirtualMemoryManager::Global()->GetAddr(reinterpret_cast<uint64_t>(buf));
vta::vmem::VirtualMemoryManager::Global()->Free(addr);
}
vta_phy_addr_t MemGetPhyAddr(void* buf) {
return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
}
void MemCopyFromHost(void* dst, const void* src, size_t size) {
vta::vmem::VirtualMemoryManager::Global()->MemCopyFromHost(dst, src, size);
}
void MemCopyToHost(void* dst, const void* src, size_t size) {
vta::vmem::VirtualMemoryManager::Global()->MemCopyToHost(dst, src, size);
}
void Launch(uint32_t length, uint32_t shiftVal, uint32_t reset) {
dpi_->WriteReg(0x08, shiftVal);
dpi_->WriteReg(0x0c, length); // vector length
dpi_->WriteReg(0x18, this->MemGetPhyAddr(inp1_));
dpi_->WriteReg(0x20, this->MemGetPhyAddr(inp2_));
dpi_->WriteReg(0x28, this->MemGetPhyAddr(out_));
dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x0); // launch
if (reset == 1) {
dpi_->WriteReg(0x10, 0x1); // reset accum
dpi_->WriteReg(0x10, 0x0); // stop reset accum
}
dpi_->WriteReg(0x14, 0x1); // reset dot
dpi_->WriteReg(0x14, 0x0); // stop reset dot
}
uint32_t WaitForCompletion() {
uint32_t i, val;
for (i = 0; i < wait_cycles_; i++) {
val = dpi_->ReadReg(0x00);
if (val == 2) break; // finish
}
val = dpi_->ReadReg(0x04);
dpi_->SimWait();
return val;
}
// wait cycles
uint32_t wait_cycles_{100000000};
// DPI loader
DPILoader* loader_{nullptr};
// DPI Module
DPIModuleNode* dpi_{nullptr};
// input vm ptr
void* inp1_{nullptr};
void* inp2_{nullptr};
// output vm ptr
void* out_{nullptr};
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[3];
Device dev_;
uint32_t cycles = dev_.Run(A, B, static_cast<int>(args[2]), C, static_cast<int>(args[4]));
*rv = static_cast<int>(cycles);
});
} // namespace driver
} // namespace vta
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
import tsim
import sys
""" Vector Bit Slice and Pack Function
Parameters
----------
A : Vector to be sliced and packed
slice_width : slice width
Returnsi
---------
C: 2d matrix where each cloumn (because of bit packing) represents each bit slice of A
"""
def slice(A, slice_width):
assert np.log2(slice_width) % 1 == 0, "only power of 2 is supported"
dtype = type(A[0])
row = 0
# currently only supports uint
if dtype is np.uint8: row = 8 // slice_width
elif dtype is np.uint16: row = 16 // slice_width
elif dtype is np.uint32: row = 32 // slice_width
elif dtype is np.uint64: row = 64 // slice_width
else: raise ValueError("datatype " + str(dtype) + "currently not supported")
if (row >= 8):
dtype = 'uint' + str(row)
else:
dtype = 'uint8'
C = np.zeros((row, len(A))).astype(dtype) # sliced and transform
# create mask
slice_mask = 2**(slice_width)-1
# slice and pack
for x in range(len(A)):
for y in range(row):
C[y][x] = (np.uint64(A[x]) >> np.uint64(slice_width * y)) & np.uint64(slice_mask)
return C
""" Matrix Multiplication Function
Parameters
----------
A : Matrix A
B: Matrix B
w_width : weight slice width
a_width : activation slice width
Returns
---------
C: result of A * B
"""
# A is a n*m matrix, B is a m*p matrix(not transposed yet)
def matrix_multiply(A, B, w_width, a_width):
assert A.shape[1] == B.shape[0], "can't perform multiplication"
BT = B.transpose()
cycles = 0
C = np.zeros((A.shape[0], B.shape[1])).astype('uint64')
for i in range(A.shape[0]):
for j in range(B.shape[1]):
# C[i, j] = A[i].dot(BT[j])
A_sliced = slice(A[i], w_width)
B_sliced = slice(BT[j], a_width)
C[i, j] = compute(A_sliced, B_sliced, w_width, a_width)
test = test_accel(A_sliced, B_sliced, w_width, a_width)
cycles += test[1]
np.testing.assert_equal(C[i,j], A[i].astype('uint64').dot(BT[j]))
print("PASS SW serial & parallel")
np.testing.assert_equal(test[0], C[i, j])
print("PASS SW & HW bit serial")
np.testing.assert_equal(test[0], A[i].astype('uint64').dot(BT[j]))
print("PASS SW bit parallel & HW bit parallel")
print("result: ")
print(C)
print("ALL TESTS PASSED, cycles: " + str(cycles))
return C
""" Software Verification Function"""
# takes 2 matrix input (sliced and packed)
def compute(A, B, w_width, a_width):
assert A.shape[1] == B.shape[1], "sliced shape not match"
# reset hardware accumulator
accum = 0
for x in range(A.shape[0]):
for y in range(B.shape[0]):
# hardware implementation
accum += np.uint64(A[x]).dot(np.uint64(B[y])) << np.uint64(x*w_width + y*a_width)
# get value from accumulator
return accum
"""Testing Function for Dot Product"""
def test_accel(A, B, w_width, a_width):
assert A.shape[1] == B.shape[1], "sliced shape not match"
dtype = A.dtype
ctx = tvm.cpu(0)
f = tsim.load_module()
a_arr = []
b_arr = []
for i in range(A.shape[0]):
list_a = np.zeros(A.shape[1]).astype(dtype)
for j in range(A.shape[1]):
list_a[j] = A[i][j]
a_arr.append(tvm.nd.array(list_a.astype(dtype), ctx))
for i in range(B.shape[0]):
list_b = np.zeros(B.shape[1]).astype(dtype)
for j in range(B.shape[1]):
list_b[j] = B[i][j]
b_arr.append(tvm.nd.array(list_b.astype(dtype), ctx))
cycles = 0
accum = tvm.nd.array(np.array([0]).astype("uint64"), ctx)
for i in range(len(a_arr)):
for j in range(len(b_arr)):
shift = np.uint8(i*w_width + j*a_width)
if i == 0 and j == 0:
cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(1)) # reset accumulator
else:
cycles += f(a_arr[i], b_arr[j], shift, accum, np.uint32(0)) # no reset
return (accum.asnumpy()[0], cycles)
""" Matrix Generator
Parameters
----------
dtype : String, datatype generated (supports only uint)
w_width : weight bit slices(needs to be less than actual bit width)
a_width : activation bit slices(needs to be less than actual bit width)
"""
def top_test(dtype, w_width, a_width):
rmax = np.random.randint(256)
# random matrix generation (dimension up to 8)
rrow = np.random.randint(7) + 1
rclmn = np.random.randint(7) + 1
rrow2 = np.random.randint(7) + 1
A = np.random.randint(rmax, size=(rrow,rclmn)).astype(dtype)
B = np.random.randint(rmax, size=(rclmn,rrow2)).astype(dtype)
print("A: ")
print(A)
print("\n")
print("B: ")
print(B)
print("\n")
matrix_multiply(A, B, w_width, a_width)
if __name__ == "__main__":
tsim.init("chisel")
for i in range(1):
# reg1 and reg2 bits in Compute.scala must be modified for slices greater than 8 bits
if sys.argv[1] == 'serial':
# generates a random uint8 GEMM with 2-bit(8/4) weight and 4-bit(8/2) activation
top_test("uint8",4, 2)
elif sys.argv[1] == 'parallel':
# generates a random uint8 GEMM with 8-bit weight and 8-bit activation (bit parallel)
top_test('uint8', 1, 1)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment