Commit 32f74f31 by Luis Vega Committed by Tianqi Chen

[VTA] [Hardware] Chisel implementation (#3258)

parent 1f62d956
......@@ -132,9 +132,6 @@ set(USE_SORT ON)
# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)
# Build TSIM for VTA
set(USE_VTA_TSIM OFF)
# Whether use Relay debug mode
set(USE_RELAY_DEBUG OFF)
......@@ -29,8 +29,7 @@ elseif(PYTHON)
--use-cfg=${CMAKE_CURRENT_BINARY_DIR}/vta_config.json)
endif()
execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE __vta_target)
string(STRIP ${__vta_target} VTA_TARGET)
execute_process(COMMAND ${VTA_CONFIG} --target OUTPUT_VARIABLE VTA_TARGET OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "Build VTA runtime with target: " ${VTA_TARGET})
......@@ -44,6 +43,13 @@ elseif(PYTHON)
add_library(vta SHARED ${VTA_RUNTIME_SRCS})
if(${VTA_TARGET} STREQUAL "tsim")
target_compile_definitions(vta PUBLIC USE_TSIM)
include_directories("vta/include")
file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
endif()
target_include_directories(vta PUBLIC vta/include)
foreach(__def ${VTA_DEFINITIONS})
......@@ -61,12 +67,6 @@ elseif(PYTHON)
target_link_libraries(vta ${__cma_lib})
endif()
if(NOT USE_VTA_TSIM STREQUAL "OFF")
include_directories("vta/include")
file(GLOB RUNTIME_DPI_SRCS vta/src/dpi/module.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_DPI_SRCS})
endif()
else()
message(STATUS "Cannot found python in env, VTA build is skipped..")
endif()
......@@ -49,7 +49,7 @@ sudo apt install verilator sbt
## Setup in TVM
1. Install `verilator` and `sbt` as described above
2. Enable VTA TSIM by turning on the switch `USE_VTA_TSIM` in config.cmake
2. Set the VTA TARGET to `tsim` on `<tvm-root>/vta/config/vta_config.json`
3. Build tvm
## How to run VTA TSIM examples
......
......@@ -124,7 +124,7 @@ else()
file(GLOB VERILATOR_SRC ${VTA_HW_DPI_DIR}/tsim_device.cc)
add_library(hw SHARED ${VERILATOR_LIB_SRC} ${VERILATOR_GEN_SRC} ${VERILATOR_SRC})
set(VERILATOR_DEF VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
set(VERILATOR_DEF VL_USER_FINISH VL_TSIM_NAME=V${TSIM_TOP_NAME} VL_PRINTF=printf VM_COVERAGE=0 VM_SC=0)
if (NOT TSIM_USE_TRACE STREQUAL "OFF")
list(APPEND VERILATOR_DEF VM_TRACE=1 TSIM_TRACE_FILE=${TSIM_BUILD_DIR}/${TSIM_TRACE_NAME}.vcd)
else()
......
......@@ -15,5 +15,81 @@
# specific language governing permissions and limitations
# under the License.
CONFIG = DefaultF1Config
TOP = VTA
TOP_TEST = Test
BUILD_NAME = build
USE_TRACE = 0
VTA_LIBNAME = libvta_hw
config_test = $(TOP_TEST)$(CONFIG)
vta_dir = $(abspath ../../)
tvm_dir = $(abspath ../../../)
verilator_inc_dir = /usr/local/share/verilator/include
verilator_build_dir = $(vta_dir)/$(BUILD_NAME)/verilator
chisel_build_dir = $(vta_dir)/$(BUILD_NAME)/chisel
verilator_opt = --cc
verilator_opt += +define+RANDOMIZE_GARBAGE_ASSIGN
verilator_opt += +define+RANDOMIZE_REG_INIT
verilator_opt += +define+RANDOMIZE_MEM_INIT
verilator_opt += --x-assign unique
verilator_opt += --output-split 20000
verilator_opt += --output-split-cfuncs 20000
verilator_opt += --top-module ${TOP_TEST}
verilator_opt += -Mdir ${verilator_build_dir}
verilator_opt += -I$(chisel_build_dir)
cxx_flags = -O2 -Wall -fPIC -shared
cxx_flags += -fvisibility=hidden -std=c++11
cxx_flags += -DVL_TSIM_NAME=V$(TOP_TEST)
cxx_flags += -DVL_PRINTF=printf
cxx_flags += -DVL_USER_FINISH
cxx_flags += -DVM_COVERAGE=0
cxx_flags += -DVM_SC=0
cxx_flags += -Wno-sign-compare
cxx_flags += -include V$(TOP_TEST).h
cxx_flags += -I$(verilator_build_dir)
cxx_flags += -I$(verilator_inc_dir)
cxx_flags += -I$(verilator_inc_dir)/vltstd
cxx_flags += -I$(vta_dir)/include
cxx_flags += -I$(tvm_dir)/include
cxx_flags += -I$(tvm_dir)/3rdparty/dlpack/include
cxx_files = $(verilator_inc_dir)/verilated.cpp
cxx_files += $(verilator_inc_dir)/verilated_dpi.cpp
cxx_files += $(wildcard $(verilator_build_dir)/*.cpp)
cxx_files += $(vta_dir)/hardware/dpi/tsim_device.cc
ifneq ($(USE_TRACE), 0)
verilator_opt += --trace
cxx_flags += -DVM_TRACE=1
cxx_flags += -DTSIM_TRACE_FILE=$(verilator_build_dir)/$(TOP_TEST).vcd
cxx_files += $(verilator_inc_dir)/verilated_vcd_c.cpp
else
cxx_flags += -DVM_TRACE=0
endif
default: lib
lib: $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
$(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so: $(verilator_build_dir)/V$(TOP_TEST).cpp
g++ $(cxx_flags) $(cxx_files) -o $@
verilator: $(verilator_build_dir)/V$(TOP_TEST).cpp
$(verilator_build_dir)/V$(TOP_TEST).cpp: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
verilator $(verilator_opt) $<
verilog: $(chisel_build_dir)/$(TOP).$(CONFIG).v
$(chisel_build_dir)/$(TOP).$(CONFIG).v:
sbt 'runMain vta.$(CONFIG) --target-dir $(chisel_build_dir) --top-name $(TOP).$(CONFIG)'
verilog_test: $(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v
$(chisel_build_dir)/$(TOP_TEST).$(CONFIG).v:
sbt 'runMain vta.$(config_test) --target-dir $(chisel_build_dir) --top-name $(TOP_TEST).$(CONFIG)'
clean:
-rm -rf target project/target project/project
cleanall:
-rm -rf $(vta_dir)/$(BUILD_NAME)
......@@ -112,7 +112,7 @@ module VTAHostDPI #
always_ff @(posedge clock) begin
if (__exit == 'd1) begin
$display("[DONE] at cycle:%016d", cycles);
$display("[TSIM] Verilog $finish called at cycle:%016d", cycles);
$finish;
end
end
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** Compute.
*
* The compute unit is in charge of the following:
* - Loading micro-ops from memory (loadUop module)
* - Loading biases (acc) from memory (tensorAcc module)
* - Compute ALU instructions (tensorAlu module)
* - Compute GEMM instructions (tensorGemm module)
*/
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val i_post = Vec(2, Input(Bool()))
val o_post = Vec(2, Output(Bool()))
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
val uop_baddr = Input(UInt(mp.addrBits.W))
val acc_baddr = Input(UInt(mp.addrBits.W))
val vme_rd = Vec(2, new VMEReadMaster)
val inp = new TensorMaster(tensorType = "inp")
val wgt = new TensorMaster(tensorType = "wgt")
val out = new TensorMaster(tensorType = "out")
val finish = Output(Bool())
})
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)
val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
val loadUop = Module(new LoadUop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
val tensorGemm = Module(new TensorGemm)
val tensorAlu = Module(new TensorAlu)
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
// decode
val dec = Module(new ComputeDecode)
dec.io.inst := inst_q.io.deq.bits
val inst_type = Cat(dec.io.isFinish,
dec.io.isAlu,
dec.io.isGemm,
dec.io.isLoadAcc,
dec.io.isLoadUop).asUInt
val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
val start = snext & sprev
val done =
MuxLookup(inst_type,
false.B, // default
Array(
"h_01".U -> loadUop.io.done,
"h_02".U -> tensorAcc.io.done,
"h_04".U -> tensorGemm.io.done,
"h_08".U -> tensorAlu.io.done,
"h_10".U -> true.B // Finish
)
)
// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
state := sSync
} .elsewhen (inst_type.orR) {
state := sExe
}
}
}
is (sSync) {
state := sIdle
}
is (sExe) {
when (done) {
state := sIdle
}
}
}
// instructions
inst_q.io.enq <> io.inst
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
// uop
loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd
// gemm
tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
tensorGemm.io.inst := inst_q.io.deq.bits
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorGemm.io.inp <> io.inp
tensorGemm.io.wgt <> io.wgt
tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
// alu
tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
tensorAlu.io.inst := inst_q.io.deq.bits
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
// out
io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
// semaphore
s(0).io.spost := io.i_post(0)
s(1).io.spost := io.i_post(1)
s(0).io.swait := dec.io.pop_prev & (state === sIdle & start)
s(1).io.swait := dec.io.pop_next & (state === sIdle & start)
io.o_post(0) := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
io.o_post(1) := dec.io.push_next & ((state === sExe & done) | (state === sSync))
// finish
io.finish := state === sExe & done & dec.io.isFinish
// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
printf("[Compute] start sync\n")
} .elsewhen (dec.io.isLoadUop) {
printf("[Compute] start load uop\n")
} .elsewhen (dec.io.isLoadAcc) {
printf("[Compute] start load acc\n")
} .elsewhen (dec.io.isGemm) {
printf("[Compute] start gemm\n")
} .elsewhen (dec.io.isAlu) {
printf("[Compute] start alu\n")
} .elsewhen (dec.io.isFinish) {
printf("[Compute] start finish\n")
}
}
// done
when (state === sSync) {
printf("[Compute] done sync\n")
}
when (state === sExe) {
when (done) {
when (dec.io.isLoadUop) {
printf("[Compute] done load uop\n")
} .elsewhen (dec.io.isLoadAcc) {
printf("[Compute] done load acc\n")
} .elsewhen (dec.io.isGemm) {
printf("[Compute] done gemm\n")
} .elsewhen (dec.io.isAlu) {
printf("[Compute] done alu\n")
} .elsewhen (dec.io.isFinish) {
printf("[Compute] done finish\n")
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import vta.util.config._
/** CoreConfig.
*
* This is one supported configuration for VTA. This file will
* be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends.
*/
class CoreConfig extends Config((site, here, up) => {
case CoreKey => CoreParams(
batch = 1,
blockOut = 16,
blockIn = 16,
inpBits = 8,
wgtBits = 8,
uopBits = 32,
accBits = 32,
outBits = 8,
uopMemDepth = 2048,
inpMemDepth = 2048,
wgtMemDepth = 1024,
accMemDepth = 2048,
outMemDepth = 2048,
instQueueEntries = 512)
})
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import vta.util.config._
import vta.shell._
/** Core parameters */
case class CoreParams (
batch: Int = 1,
blockOut: Int = 16,
blockIn: Int = 16,
inpBits: Int = 8,
wgtBits: Int = 8,
uopBits: Int = 32,
accBits: Int = 32,
outBits: Int = 8,
uopMemDepth: Int = 512,
inpMemDepth: Int = 512,
wgtMemDepth: Int = 512,
accMemDepth: Int = 512,
outMemDepth: Int = 512,
instQueueEntries: Int = 32
)
case object CoreKey extends Field[CoreParams]
/** Core.
*
* The core defines the current VTA architecture by connecting memory and
* compute modules together such as load/store and compute. Most of the
* connections in the core are bulk (<>), and we should try to keep it this
* way, because it is easier to understand what is going on.
*
* Also, the core must be instantiated by a shell using the
* VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces.
* More info about these interfaces and modules can be found in the shell
* directory.
*/
class Core(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val vcr = new VCRClient
val vme = new VMEMaster
})
val fetch = Module(new Fetch)
val load = Module(new Load)
val compute = Module(new Compute)
val store = Module(new Store)
// Read(rd) and write(wr) from/to memory (i.e. DRAM)
io.vme.rd(0) <> fetch.io.vme_rd
io.vme.rd(1) <> compute.io.vme_rd(0)
io.vme.rd(2) <> load.io.vme_rd(0)
io.vme.rd(3) <> load.io.vme_rd(1)
io.vme.rd(4) <> compute.io.vme_rd(1)
io.vme.wr(0) <> store.io.vme_wr
// Fetch instructions (tasks) from memory (DRAM) into queues (SRAMs)
fetch.io.launch := io.vcr.launch
fetch.io.ins_baddr := io.vcr.ptrs(0)
fetch.io.ins_count := io.vcr.vals(0)
// Load inputs and weights from memory (DRAM) into scratchpads (SRAMs)
load.io.i_post := compute.io.o_post(0)
load.io.inst <> fetch.io.inst.ld
load.io.inp_baddr := io.vcr.ptrs(2)
load.io.wgt_baddr := io.vcr.ptrs(3)
// The compute module performs the following:
// - Load micro-ops (uops) and accumulations (acc)
// - Compute dense and ALU instructions (tasks)
compute.io.i_post(0) := load.io.o_post
compute.io.i_post(1) := store.io.o_post
compute.io.inst <> fetch.io.inst.co
compute.io.uop_baddr := io.vcr.ptrs(1)
compute.io.acc_baddr := io.vcr.ptrs(4)
compute.io.inp <> load.io.inp
compute.io.wgt <> load.io.wgt
// The store module performs the following:
// - Writes results from compute into scratchpads (SRAMs)
// - Store results from scratchpads (SRAMs) to memory (DRAM)
store.io.i_post := compute.io.o_post(1)
store.io.inst <> fetch.io.inst.st
store.io.out_baddr := io.vcr.ptrs(5)
store.io.out <> compute.io.out
// Finish instruction is executed and asserts the VCR finish flag
val finish = RegNext(compute.io.finish)
io.vcr.finish := finish
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import ISA._
/** MemDecode.
*
* Decode memory instructions with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. These are the instructions
* decoded with this bundle:
* - LUOP
* - LWGT
* - LINP
* - LACC
* - SOUT
*/
class MemDecode extends Bundle {
val xpad_1 = UInt(M_PAD_BITS.W)
val xpad_0 = UInt(M_PAD_BITS.W)
val ypad_1 = UInt(M_PAD_BITS.W)
val ypad_0 = UInt(M_PAD_BITS.W)
val xstride = UInt(M_STRIDE_BITS.W)
val xsize = UInt(M_SIZE_BITS.W)
val ysize = UInt(M_SIZE_BITS.W)
val empty_0 = UInt(7.W) // derive this
val dram_offset = UInt(M_DRAM_OFFSET_BITS.W)
val sram_offset = UInt(M_SRAM_OFFSET_BITS.W)
val id = UInt(M_ID_BITS.W)
val push_next = Bool()
val push_prev = Bool()
val pop_next = Bool()
val pop_prev = Bool()
val op = UInt(OP_BITS.W)
}
/** GemmDecode.
*
* Decode GEMM instruction with a Bundle. This is similar to an union,
* therefore order matters when declaring fields.
*/
class GemmDecode extends Bundle {
val wgt_1 = UInt(C_WIDX_BITS.W)
val wgt_0 = UInt(C_WIDX_BITS.W)
val inp_1 = UInt(C_IIDX_BITS.W)
val inp_0 = UInt(C_IIDX_BITS.W)
val acc_1 = UInt(C_AIDX_BITS.W)
val acc_0 = UInt(C_AIDX_BITS.W)
val empty_0 = Bool()
val lp_1 = UInt(C_ITER_BITS.W)
val lp_0 = UInt(C_ITER_BITS.W)
val uop_end = UInt(C_UOP_END_BITS.W)
val uop_begin = UInt(C_UOP_BGN_BITS.W)
val reset = Bool()
val push_next = Bool()
val push_prev = Bool()
val pop_next = Bool()
val pop_prev = Bool()
val op = UInt(OP_BITS.W)
}
/** AluDecode.
*
* Decode ALU instructions with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. These are the instructions
* decoded with this bundle:
* - VMIN
* - VMAX
* - VADD
* - VSHX
*/
class AluDecode extends Bundle {
val empty_1 = Bool()
val alu_imm = UInt(C_ALU_IMM_BITS.W)
val alu_use_imm = Bool()
val alu_op = UInt(C_ALU_DEC_BITS.W)
val src_1 = UInt(C_IIDX_BITS.W)
val src_0 = UInt(C_IIDX_BITS.W)
val dst_1 = UInt(C_AIDX_BITS.W)
val dst_0 = UInt(C_AIDX_BITS.W)
val empty_0 = Bool()
val lp_1 = UInt(C_ITER_BITS.W)
val lp_0 = UInt(C_ITER_BITS.W)
val uop_end = UInt(C_UOP_END_BITS.W)
val uop_begin = UInt(C_UOP_BGN_BITS.W)
val reset = Bool()
val push_next = Bool()
val push_prev = Bool()
val pop_next = Bool()
val pop_prev = Bool()
val op = UInt(OP_BITS.W)
}
/** UopDecode.
*
* Decode micro-ops (uops).
*/
class UopDecode extends Bundle {
val u2 = UInt(10.W)
val u1 = UInt(11.W)
val u0 = UInt(11.W)
}
/** FetchDecode.
*
* Partial decoding for dispatching instructions to Load, Compute, and Store.
*/
class FetchDecode extends Module {
val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W))
val isLoad = Output(Bool())
val isCompute = Output(Bool())
val isStore = Output(Bool())
})
val csignals =
ListLookup(io.inst,
List(N, OP_X),
Array(
LUOP -> List(Y, OP_G),
LWGT -> List(Y, OP_L),
LINP -> List(Y, OP_L),
LACC -> List(Y, OP_G),
SOUT -> List(Y, OP_S),
GEMM -> List(Y, OP_G),
FNSH -> List(Y, OP_G),
VMIN -> List(Y, OP_G),
VMAX -> List(Y, OP_G),
VADD -> List(Y, OP_G),
VSHX -> List(Y, OP_G)
)
)
val (cs_val_inst: Bool) :: cs_op_type :: Nil = csignals
io.isLoad := cs_val_inst & cs_op_type === OP_L
io.isCompute := cs_val_inst & cs_op_type === OP_G
io.isStore := cs_val_inst & cs_op_type === OP_S
}
/** LoadDecode.
*
* Decode dependencies, type and sync for Load module.
*/
class LoadDecode extends Module {
val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W))
val push_next = Output(Bool())
val pop_next = Output(Bool())
val isInput = Output(Bool())
val isWeight = Output(Bool())
val isSync = Output(Bool())
})
val dec = io.inst.asTypeOf(new MemDecode)
io.push_next := dec.push_next
io.pop_next := dec.pop_next
io.isInput := io.inst === LINP & dec.xsize =/= 0.U
io.isWeight := io.inst === LWGT & dec.xsize =/= 0.U
io.isSync := (io.inst === LINP | io.inst === LWGT) & dec.xsize === 0.U
}
/** ComputeDecode.
*
* Decode dependencies, type and sync for Compute module.
*/
class ComputeDecode extends Module {
val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W))
val push_next = Output(Bool())
val push_prev = Output(Bool())
val pop_next = Output(Bool())
val pop_prev = Output(Bool())
val isLoadAcc = Output(Bool())
val isLoadUop = Output(Bool())
val isSync = Output(Bool())
val isAlu = Output(Bool())
val isGemm = Output(Bool())
val isFinish = Output(Bool())
})
val dec = io.inst.asTypeOf(new MemDecode)
io.push_next := dec.push_next
io.push_prev := dec.push_prev
io.pop_next := dec.pop_next
io.pop_prev := dec.pop_prev
io.isLoadAcc := io.inst === LACC & dec.xsize =/= 0.U
io.isLoadUop := io.inst === LUOP & dec.xsize =/= 0.U
io.isSync := (io.inst === LACC | io.inst === LUOP) & dec.xsize === 0.U
io.isAlu := io.inst === VMIN | io.inst === VMAX | io.inst === VADD | io.inst === VSHX
io.isGemm := io.inst === GEMM
io.isFinish := io.inst === FNSH
}
/** StoreDecode.
*
* Decode dependencies, type and sync for Store module.
*/
class StoreDecode extends Module {
val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W))
val push_prev = Output(Bool())
val pop_prev = Output(Bool())
val isStore = Output(Bool())
val isSync = Output(Bool())
})
val dec = io.inst.asTypeOf(new MemDecode)
io.push_prev := dec.push_prev
io.pop_prev := dec.pop_prev
io.isStore := io.inst === SOUT & dec.xsize =/= 0.U
io.isSync := io.inst === SOUT & dec.xsize === 0.U
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** Fetch.
*
* The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
* VTA Memory Engine (VME), and push them into an instruction queue called
* inst_q. Once the instruction queue is full, instructions are dispatched to
* the Load, Compute and Store module queues based on the instruction opcode.
* After draining the queue, the fetch unit checks if there are more instructions
* via the ins_count register which is written by the host.
*
* Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB)
* because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
* This should be configurable for larger payloads, i.e. 64-bytes, which can load
* more than one instruction at the time. Finally, the instruction queue is
* sized (entries_q), depending on the maximum burst allowed in the memory.
*/
class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val launch = Input(Bool())
val ins_baddr = Input(UInt(mp.addrBits.W))
val ins_count = Input(UInt(vp.regBits.W))
val vme_rd = new VMEReadMaster
val inst = new Bundle {
val ld = Decoupled(UInt(INST_BITS.W))
val co = Decoupled(UInt(INST_BITS.W))
val st = Decoupled(UInt(INST_BITS.W))
}
})
val entries_q = 1 << (mp.lenBits - 1) // one-instr-every-two-vme-word
val inst_q = Module(new Queue(UInt(INST_BITS.W), entries_q))
val dec = Module(new FetchDecode)
val s1_launch = RegNext(io.launch)
val pulse = io.launch & ~s1_launch
val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
val rlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val ilen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xrem = Reg(chiselTypeOf(io.ins_count))
val xsize = (io.ins_count << 1.U) - 1.U
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (pulse) {
state := sReadCmd
when (xsize < xmax) {
rlen := xsize
ilen := xsize >> 1.U
xrem := 0.U
} .otherwise {
rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U
xrem := xsize - xmax
}
}
}
is (sReadCmd) {
when (io.vme_rd.cmd.ready) {
state := sReadLSB
}
}
is (sReadLSB) {
when (io.vme_rd.data.valid) {
state := sReadMSB
}
}
is (sReadMSB) {
when (io.vme_rd.data.valid) {
when (inst_q.io.count === ilen) {
state := sDrain
} .otherwise {
state := sReadLSB
}
}
}
is (sDrain) {
when (inst_q.io.count === 0.U) {
when (xrem === 0.U) {
state := sIdle
} .elsewhen (xrem < xmax) {
state := sReadCmd
rlen := xrem
ilen := xrem >> 1.U
xrem := 0.U
} .otherwise {
state := sReadCmd
rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U
xrem := xrem - xmax
}
}
}
}
// read instructions from dram
when (state === sIdle) {
raddr := io.ins_baddr
} .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
}
io.vme_rd.cmd.valid := state === sReadCmd
io.vme_rd.cmd.bits.addr := raddr
io.vme_rd.cmd.bits.len := rlen
io.vme_rd.data.ready := inst_q.io.enq.ready
val lsb = Reg(chiselTypeOf(io.vme_rd.data.bits))
val msb = io.vme_rd.data.bits
val inst = Cat(msb, lsb)
when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
inst_q.io.enq.bits := inst
// decode
dec.io.inst := inst_q.io.deq.bits
// instruction queues
io.inst.ld.valid := dec.io.isLoad & inst_q.io.deq.valid & state === sDrain
io.inst.co.valid := dec.io.isCompute & inst_q.io.deq.valid & state === sDrain
io.inst.st.valid := dec.io.isStore & inst_q.io.deq.valid & state === sDrain
io.inst.ld.bits := inst_q.io.deq.bits
io.inst.co.bits := inst_q.io.deq.bits
io.inst.st.bits := inst_q.io.deq.bits
// check if selected queue is ready
val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
val deq_ready =
MuxLookup(deq_sel,
false.B, // default
Array(
"h_01".U -> io.inst.ld.ready,
"h_02".U -> io.inst.st.ready,
"h_04".U -> io.inst.co.ready
)
)
// dequeue instruction
inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
// debug
if (debug) {
when (state === sIdle && pulse) {
printf("[Fetch] Launch\n")
}
// instruction
when (inst_q.io.deq.fire()) {
when (dec.io.isLoad) {
printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
}
when (dec.io.isCompute) {
printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
}
when (dec.io.isStore) {
printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
/** ISAConstants.
*
* These constants are used for decoding (parsing) fields on instructions.
*/
trait ISAConstants
{
val INST_BITS = 128
val OP_BITS = 3
val M_DEP_BITS = 4
val M_ID_BITS = 2
val M_SRAM_OFFSET_BITS = 16
val M_DRAM_OFFSET_BITS = 32
val M_SIZE_BITS = 16
val M_STRIDE_BITS = 16
val M_PAD_BITS = 4
val C_UOP_BGN_BITS = 13
val C_UOP_END_BITS = 14
val C_ITER_BITS = 14
val C_AIDX_BITS = 11
val C_IIDX_BITS = 11
val C_WIDX_BITS = 10
val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
val C_ALU_OP_BITS = 3
val C_ALU_IMM_BITS = 16
val Y = true.B
val N = false.B
val OP_L = 0.asUInt(OP_BITS.W)
val OP_S = 1.asUInt(OP_BITS.W)
val OP_G = 2.asUInt(OP_BITS.W)
val OP_F = 3.asUInt(OP_BITS.W)
val OP_A = 4.asUInt(OP_BITS.W)
val OP_X = 5.asUInt(OP_BITS.W)
val ALU_OP_NUM = 5
val ALU_OP = Enum(ALU_OP_NUM)
val M_ID_U = 0.asUInt(M_ID_BITS.W)
val M_ID_W = 1.asUInt(M_ID_BITS.W)
val M_ID_I = 2.asUInt(M_ID_BITS.W)
val M_ID_A = 3.asUInt(M_ID_BITS.W)
}
/** ISA.
*
* This is the VTA ISA, here we specify the cares and dont-cares that makes
* decoding easier. Since instructions are quite long 128-bit, we could generate
* these based on ISAConstants.
*
* FIXME: VSHX should be replaced by VSHR and VSHL once we modify the compiler
* TODO: Add VXOR to clear accumulator
*/
object ISA {
def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** Load.
*
* Load inputs and weights from memory (DRAM) into scratchpads (SRAMs).
* This module instantiate the TensorLoad unit which is in charge of
* loading 1D and 2D tensors to scratchpads, so it can be used by
* other modules such as Compute.
*/
class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val i_post = Input(Bool())
val o_post = Output(Bool())
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
val inp_baddr = Input(UInt(mp.addrBits.W))
val wgt_baddr = Input(UInt(mp.addrBits.W))
val vme_rd = Vec(2, new VMEReadMaster)
val inp = new TensorClient(tensorType = "inp")
val wgt = new TensorClient(tensorType = "wgt")
})
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)
val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
val dec = Module(new LoadDecode)
dec.io.inst := inst_q.io.deq.bits
val tensorType = Seq("inp", "wgt")
val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
state := sSync
} .elsewhen (dec.io.isInput || dec.io.isWeight) {
state := sExe
}
}
}
is (sSync) {
state := sIdle
}
is (sExe) {
when (done) {
state := sIdle
}
}
}
// instructions
inst_q.io.enq <> io.inst
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
// load tensor
// [0] input (inp)
// [1] weight (wgt)
val ptr = Seq(io.inp_baddr, io.wgt_baddr)
val tsor = Seq(io.inp, io.wgt)
for (i <- 0 until 2) {
tensorLoad(i).io.start := state === sIdle & start & tensorDec(i)
tensorLoad(i).io.inst := inst_q.io.deq.bits
tensorLoad(i).io.baddr := ptr(i)
tensorLoad(i).io.tensor <> tsor(i)
io.vme_rd(i) <> tensorLoad(i).io.vme_rd
}
// semaphore
s.io.spost := io.i_post
s.io.swait := dec.io.pop_next & (state === sIdle & start)
io.o_post := dec.io.push_next & ((state === sExe & done) | (state === sSync))
// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
printf("[Load] start sync\n")
} .elsewhen (dec.io.isInput) {
printf("[Load] start input\n")
} .elsewhen (dec.io.isWeight) {
printf("[Load] start weight\n")
}
}
// done
when (state === sSync) {
printf("[Load] done sync\n")
}
when (state === sExe) {
when (done) {
when (dec.io.isInput) {
printf("[Load] done input\n")
} .elsewhen (dec.io.isWeight) {
printf("[Load] done weight\n")
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** UopMaster.
*
* Uop interface used by a master module, i.e. TensorAlu or TensorGemm,
* to request a micro-op (uop) from the uop-scratchpad. The index (idx) is
* used as an address to find the uop in the uop-scratchpad.
*/
class UopMaster(implicit p: Parameters) extends Bundle {
val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
val idx = ValidIO(UInt(addrBits.W))
val data = Flipped(ValidIO(new UopDecode))
override def cloneType = new UopMaster().asInstanceOf[this.type]
}
/** UopClient.
*
* Uop interface used by a client module, i.e. LoadUop, to receive
* a request from a master module, i.e. TensorAlu or TensorGemm.
* The index (idx) is used as an address to find the uop in the uop-scratchpad.
*/
class UopClient(implicit p: Parameters) extends Bundle {
val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
val idx = Flipped(ValidIO(UInt(addrBits.W)))
val data = ValidIO(new UopDecode)
override def cloneType = new UopClient().asInstanceOf[this.type]
}
/** LoadUop.
*
* Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
* uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
* group of 2 given the fact that the DRAM payload is 8-bytes. This module
* should be modified later on to support different DRAM sizes efficiently.
*/
class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val baddr = Input(UInt(mp.addrBits.W))
val vme_rd = new VMEReadMaster
val uop = new UopClient
})
val numUop = 2 // store two uops per sram word
val uopBits = p(CoreKey).uopBits
val uopDepth = p(CoreKey).uopMemDepth / numUop
val dec = io.inst.asTypeOf(new MemDecode)
val raddr = Reg(chiselTypeOf(io.vme_rd.cmd.bits.addr))
val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = dec.xsize(0) + (dec.xsize >> log2Ceil(numUop)) - 1.U
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
val sizeIsEven = (dec.xsize % 2.U) === 0.U
val sIdle :: sReadCmd :: sReadData :: Nil = Enum(3)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (io.start) {
state := sReadCmd
when (xsize < xmax) {
xlen := xsize
xrem := 0.U
} .otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
}
is (sReadCmd) {
when (io.vme_rd.cmd.ready) {
state := sReadData
}
}
is (sReadData) {
when (io.vme_rd.data.valid) {
when(xcnt === xlen) {
when (xrem === 0.U) {
state := sIdle
} .elsewhen (xrem < xmax) {
state := sReadCmd
xlen := xrem
xrem := 0.U
} .otherwise {
state := sReadCmd
xlen := xmax - 1.U
xrem := xrem - xmax
}
}
}
}
}
// read-from-dram
when (state === sIdle) {
when (offsetIsEven) {
raddr := io.baddr + dec.dram_offset
} .otherwise {
raddr := io.baddr + dec.dram_offset - 4.U
}
} .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
}
io.vme_rd.cmd.valid := state === sReadCmd
io.vme_rd.cmd.bits.addr := raddr
io.vme_rd.cmd.bits.len := xlen
io.vme_rd.data.ready := state === sReadData
when (state =/= sReadData) {
xcnt := 0.U
} .elsewhen (io.vme_rd.data.fire()) {
xcnt := xcnt + 1.U
}
val waddr = Reg(UInt(log2Ceil(uopDepth).W))
when (state === sIdle) {
waddr := dec.sram_offset >> log2Ceil(numUop)
} .elsewhen (io.vme_rd.data.fire()) {
waddr := waddr + 1.U
}
val wdata = Wire(Vec(numUop, UInt(uopBits.W)))
val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
val wmask = Reg(Vec(numUop, Bool()))
when (offsetIsEven) {
when (sizeIsEven) {
wmask := "b_11".U.asTypeOf(wmask)
} .elsewhen (io.vme_rd.cmd.fire()) {
when (dec.xsize === 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
} .otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
} .elsewhen (io.vme_rd.data.fire()) {
when (xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
} .otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
}
} .otherwise {
when (io.vme_rd.cmd.fire()) {
wmask := "b_10".U.asTypeOf(wmask)
} .elsewhen (io.vme_rd.data.fire()) {
when (sizeIsEven && xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
} .otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
}
}
wdata := io.vme_rd.data.bits.asTypeOf(wdata)
when (io.vme_rd.data.fire()) {
mem.write(waddr, wdata, wmask)
}
// read-from-sram
io.uop.data.valid := RegNext(io.uop.idx.valid)
val sIdx = io.uop.idx.bits % numUop.U
val rIdx = io.uop.idx.bits >> log2Ceil(numUop)
val memRead = mem.read(rIdx, io.uop.idx.valid)
val sWord = memRead.asUInt.asTypeOf(wdata)
val sUop = sWord(sIdx).asTypeOf(io.uop.data.bits)
io.uop.data.bits <> sUop
// done
io.done := state === sReadData & io.vme_rd.data.valid & xcnt === xlen & xrem === 0.U
// debug
if (debug) {
when (io.vme_rd.cmd.fire()) {
printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
/** Semaphore.
*
* This semaphore is used instead of push/pop fifo, used in the initial
* version of VTA. This semaphore is incremented (spost) or decremented (swait)
* depending on the push and pop fields on instructions to prevent RAW and WAR
* hazards.
*/
class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
val io = IO(new Bundle {
val spost = Input(Bool())
val swait = Input(Bool())
val sready = Output(Bool())
})
val cnt = RegInit(counterInitValue.U(counterBits.W))
when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U }
when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
io.sready := cnt =/= 0.U
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** Store.
*
* Store results back to memory (DRAM) from scratchpads (SRAMs).
* This module instantiate the TensorStore unit which is in charge
* of storing 1D and 2D tensors to main memory.
*/
class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val i_post = Input(Bool())
val o_post = Output(Bool())
val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
val out_baddr = Input(UInt(mp.addrBits.W))
val vme_wr = new VMEWriteMaster
val out = new TensorClient(tensorType = "out")
})
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)
val s = Module(new Semaphore(counterBits = 8, counterInitValue = 0))
val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
val dec = Module(new StoreDecode)
dec.io.inst := inst_q.io.deq.bits
val tensorStore = Module(new TensorStore(tensorType = "out"))
val start = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s.io.sready, true.B)
val done = tensorStore.io.done
// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
state := sSync
} .elsewhen (dec.io.isStore) {
state := sExe
}
}
}
is (sSync) {
state := sIdle
}
is (sExe) {
when (done) {
state := sIdle
}
}
}
// instructions
inst_q.io.enq <> io.inst
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
// store
tensorStore.io.start := state === sIdle & start & dec.io.isStore
tensorStore.io.inst := inst_q.io.deq.bits
tensorStore.io.baddr := io.out_baddr
io.vme_wr <> tensorStore.io.vme_wr
tensorStore.io.tensor <> io.out
// semaphore
s.io.spost := io.i_post
s.io.swait := dec.io.pop_prev & (state === sIdle & start)
io.o_post := dec.io.push_prev & ((state === sExe & done) | (state === sSync))
// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
printf("[Store] start sync\n")
} .elsewhen (dec.io.isStore) {
printf("[Store] start\n")
}
}
// done
when (state === sSync) {
printf("[Store] done sync\n")
}
when (state === sExe) {
when (done) {
printf("[Store] done\n")
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
/** ALU datapath */
class Alu(implicit p: Parameters) extends Module {
val aluBits = p(CoreKey).accBits
val io = IO(new Bundle {
val opcode = Input(UInt(C_ALU_OP_BITS.W))
val a = Input(SInt(aluBits.W))
val b = Input(SInt(aluBits.W))
val y = Output(SInt(aluBits.W))
})
// FIXME: the following three will change once we support properly SHR and SHL
val ub = io.b.asUInt
val width = log2Ceil(aluBits)
val m = ~ub(width - 1, 0) + 1.U
val n = ub(width - 1, 0)
val fop = Seq(Mux(io.a < io.b, io.a, io.b),
Mux(io.a < io.b, io.b, io.a),
io.a + io.b,
io.a >> n,
io.a << m)
val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i))
io.y := MuxLookup(io.opcode, io.a, opmux)
}
/** Pipelined ALU */
class AluReg(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val opcode = Input(UInt(C_ALU_OP_BITS.W))
val a = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
val b = Flipped(ValidIO(UInt(p(CoreKey).accBits.W)))
val y = ValidIO(UInt(p(CoreKey).accBits.W))
})
val alu = Module(new Alu)
val rA = RegEnable(io.a.bits, io.a.valid)
val rB = RegEnable(io.b.bits, io.b.valid)
val valid = RegNext(io.b.valid)
alu.io.opcode := io.opcode
// register input
alu.io.a := rA.asSInt
alu.io.b := rB.asSInt
// output
io.y.valid := valid
io.y.bits := alu.io.y.asUInt
}
/** Vector of pipeline ALUs */
class AluVector(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val opcode = Input(UInt(C_ALU_OP_BITS.W))
val acc_a = new TensorMasterData(tensorType = "acc")
val acc_b = new TensorMasterData(tensorType = "acc")
val acc_y = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out")
})
val blockOut = p(CoreKey).blockOut
val f = Seq.fill(blockOut)(Module(new AluReg))
val valid = Wire(Vec(blockOut, Bool()))
for (i <- 0 until blockOut) {
f(i).io.opcode := io.opcode
f(i).io.a.valid := io.acc_a.data.valid
f(i).io.a.bits := io.acc_a.data.bits(0)(i)
f(i).io.b.valid := io.acc_b.data.valid
f(i).io.b.bits := io.acc_b.data.bits(0)(i)
valid(i) := f(i).io.y.valid
io.acc_y.data.bits(0)(i) := f(i).io.y.bits
io.out.data.bits(0)(i) := f(i).io.y.bits
}
io.acc_y.data.valid := valid.asUInt.andR
io.out.data.valid := valid.asUInt.andR
}
/** TensorAlu.
*
* This unit instantiate the ALU vector unit (AluVector) and go over the
* micro-ops (uops) which are used to read the source operands (vectors)
* from the acc-scratchpad and then they are written back the same
* acc-scratchpad.
*/
class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val uop = new UopMaster
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
val state = RegInit(sIdle)
val alu = Module(new AluVector)
val dec = io.inst.asTypeOf(new AluDecode)
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
val uop_dst = Reg(chiselTypeOf(dec.uop_end))
val uop_src = Reg(chiselTypeOf(dec.uop_end))
val cnt_o = Reg(chiselTypeOf(dec.lp_0))
val dst_o = Reg(chiselTypeOf(dec.uop_end))
val src_o = Reg(chiselTypeOf(dec.uop_end))
val cnt_i = Reg(chiselTypeOf(dec.lp_1))
val dst_i = Reg(chiselTypeOf(dec.uop_end))
val src_i = Reg(chiselTypeOf(dec.uop_end))
val done =
state === sExe &
alu.io.out.data.valid &
(cnt_o === dec.lp_0 - 1.U) &
(cnt_i === dec.lp_1 - 1.U) &
(uop_idx === uop_end - 1.U)
switch (state) {
is (sIdle) {
when (io.start) {
state := sReadUop
}
}
is (sReadUop) {
state := sComputeIdx
}
is (sComputeIdx) {
state := sReadTensorA
}
is (sReadTensorA) {
state := sReadTensorB
}
is (sReadTensorB) {
state := sExe
}
is (sExe) {
when (alu.io.out.data.valid) {
when ((cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
state := sIdle
} .otherwise {
state := sReadUop
}
}
}
}
when (state === sIdle ||
(state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
} .elsewhen (state === sExe && alu.io.out.data.valid) {
uop_idx := uop_idx + 1.U
}
when (state === sIdle) {
cnt_o := 0.U
dst_o := 0.U
src_o := 0.U
} .elsewhen (state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
dst_o := dst_o + dec.dst_0
src_o := src_o + dec.src_0
}
when (state === sIdle) {
cnt_i := 0.U
dst_i := 0.U
src_i := 0.U
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
dst_i := dst_o
src_i := src_o
} .elsewhen (state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
dst_i := dst_i + dec.dst_1
src_i := src_i + dec.src_1
}
when (state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i
uop_src := io.uop.data.bits.u1 + src_i
}
// uop
io.uop.idx.valid := state === sReadUop
io.uop.idx.bits := uop_idx
// acc_i
io.acc.rd.idx.valid := state === sReadTensorA | (state === sReadTensorB & ~dec.alu_use_imm)
io.acc.rd.idx.bits := Mux(state === sReadTensorA, uop_dst, uop_src)
// imm
val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
tensorImm.data.valid := state === sReadTensorB
tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
// alu
val isSHR = dec.alu_op === ALU_OP(3)
val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
alu.io.opcode := fixme_alu_op
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
// acc_o
io.acc.wr.valid := alu.io.acc_y.data.valid
io.acc.wr.bits.idx := uop_dst
io.acc.wr.bits.data <> alu.io.acc_y.data.bits
// out
io.out.wr.valid := alu.io.out.data.valid
io.out.wr.bits.idx := uop_dst
io.out.wr.bits.data <> alu.io.out.data.bits
io.out.tieoffRead() // write-only
io.done := done
if (debug) {
when (state === sReadUop) {
printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
}
when (state === sReadTensorA) {
printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
}
when (state === sIdle && io.start) {
printf(p"[TensorAlu] decode:$dec\n")
}
alu.io.acc_a.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.acc_a.data.valid) {
printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.acc_b.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.acc_b.data.valid) {
printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.acc_y.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.acc_y.data.valid) {
printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.out.data.valid) {
printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import chisel3.experimental._
import vta.util.config._
import scala.math.pow
/** Pipelined multiply and accumulate */
class MAC(dataBits: Int = 8, cBits: Int = 16, outBits: Int = 17) extends Module {
require (cBits >= dataBits * 2)
require (outBits >= dataBits * 2)
val io = IO(new Bundle {
val a = Input(SInt(dataBits.W))
val b = Input(SInt(dataBits.W))
val c = Input(SInt(cBits.W))
val y = Output(SInt(outBits.W))
})
val mult = Wire(SInt(cBits.W))
val add = Wire(SInt(outBits.W))
val rA = RegNext(io.a)
val rB = RegNext(io.b)
val rC = RegNext(io.c)
mult := rA * rB
add := rC + mult
io.y := add
}
/** Pipelined adder */
class Adder(dataBits: Int = 8, outBits: Int = 17) extends Module {
require (outBits >= dataBits)
val io = IO(new Bundle {
val a = Input(SInt(dataBits.W))
val b = Input(SInt(dataBits.W))
val y = Output(SInt(outBits.W))
})
val add = Wire(SInt(outBits.W))
val rA = RegNext(io.a)
val rB = RegNext(io.b)
add := rA + rB
io.y := add
}
/** Pipelined DotProduct based on MAC and Adder */
class DotProduct(dataBits: Int = 8, size: Int = 16) extends Module {
val errMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
require(size >= 4 && isPow2(size), errMsg)
val b = dataBits * 2
val outBits = b + log2Ceil(size) + 1
val io = IO(new Bundle {
val a = Input(Vec(size, SInt(dataBits.W)))
val b = Input(Vec(size, SInt(dataBits.W)))
val y = Output(SInt(outBits.W))
})
val p = log2Ceil(size/2)
val s = Seq.tabulate(log2Ceil(size))(i => pow(2, p - i).toInt)
val da = Seq.tabulate(s(0))(i => RegNext(io.a(s(0) + i)))
val db = Seq.tabulate(s(0))(i => RegNext(io.b(s(0) + i)))
val m = Seq.tabulate(2)(i =>
Seq.fill(s(0))(Module(new MAC(dataBits = dataBits, cBits = b + i, outBits = b + i + 1)))
)
val a = Seq.tabulate(p)(i =>
Seq.fill(s(i + 1))(Module(new Adder(dataBits = b + i + 2, outBits = b + i + 3)))
)
for (i <- 0 until log2Ceil(size)) {
for (j <- 0 until s(i)) {
if (i == 0) {
m(i)(j).io.a := io.a(j)
m(i)(j).io.b := io.b(j)
m(i)(j).io.c := 0.S
m(i + 1)(j).io.a := da(j)
m(i + 1)(j).io.b := db(j)
m(i + 1)(j).io.c := m(i)(j).io.y
} else if (i == 1) {
a(i - 1)(j).io.a := m(i)(2*j).io.y
a(i - 1)(j).io.b := m(i)(2*j + 1).io.y
} else {
a(i - 1)(j).io.a := a(i - 2)(2*j).io.y
a(i - 1)(j).io.b := a(i - 2)(2*j + 1).io.y
}
}
}
io.y := a(p-1)(0).io.y
}
/** Perform matric-vector-multiplication based on DotProduct */
class MatrixVectorCore(implicit p: Parameters) extends Module {
val accBits = p(CoreKey).accBits
val size = p(CoreKey).blockOut
val dataBits = p(CoreKey).inpBits
val io = IO(new Bundle{
val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
val inp = new TensorMasterData(tensorType = "inp")
val wgt = new TensorMasterData(tensorType = "wgt")
val acc_i = new TensorMasterData(tensorType = "acc")
val acc_o = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out")
})
val dot = Seq.fill(size)(Module(new DotProduct(dataBits, size)))
val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
val add = Seq.fill(size)(Wire(SInt(accBits.W)))
val vld = Wire(Vec(size, Bool()))
for (i <- 0 until size) {
acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset
acc(i).io.enq.bits := io.acc_i.data.bits(0)(i)
for (j <- 0 until size) {
dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt
dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt
}
add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y
io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt)
io.out.data.bits(0)(i) := add(i).asUInt
vld(i) := acc(i).io.deq.valid
}
io.acc_o.data.valid := vld.asUInt.andR | io.reset
io.out.data.valid := vld.asUInt.andR
}
/** TensorGemm.
*
* This unit instantiate the MatrixVectorCore and go over the
* micro-ops (uops) which are used to read inputs, weights and biases,
* and writes results back to the acc and out scratchpads.
*
* Also, the TensorGemm uses the reset field in the Gemm instruction to
* clear or zero-out the acc-scratchpad locations based on the micro-ops.
*/
class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val uop = new UopMaster
val inp = new TensorMaster(tensorType = "inp")
val wgt = new TensorMaster(tensorType = "wgt")
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
val state = RegInit(sIdle)
val mvc = Module(new MatrixVectorCore)
val dec = io.inst.asTypeOf(new GemmDecode)
val uop_idx = Reg(chiselTypeOf(dec.uop_end))
val uop_end = dec.uop_end
val uop_acc = Reg(chiselTypeOf(dec.uop_end))
val uop_inp = Reg(chiselTypeOf(dec.uop_end))
val uop_wgt = Reg(chiselTypeOf(dec.uop_end))
val cnt_o = Reg(chiselTypeOf(dec.lp_0))
val acc_o = Reg(chiselTypeOf(dec.uop_end))
val inp_o = Reg(chiselTypeOf(dec.uop_end))
val wgt_o = Reg(chiselTypeOf(dec.uop_end))
val cnt_i = Reg(chiselTypeOf(dec.lp_1))
val acc_i = Reg(chiselTypeOf(dec.uop_end))
val inp_i = Reg(chiselTypeOf(dec.uop_end))
val wgt_i = Reg(chiselTypeOf(dec.uop_end))
val pBits = log2Ceil(p(CoreKey).blockOut) + 1
val inflight = Reg(UInt(pBits.W))
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
val done = inflight === 0.U &
((state === sExe &
cnt_o === dec.lp_0 - 1.U &
cnt_i === dec.lp_1 - 1.U &
uop_idx === uop_end - 1.U &
inflight === 0.U) |
state === sWait)
switch (state) {
is (sIdle) {
when (io.start) {
state := sReadUop
}
}
is (sReadUop) {
state := sComputeIdx
}
is (sComputeIdx) {
state := sReadTensor
}
is (sReadTensor) {
state := sExe
}
is (sExe) {
when ((cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
when (inflight =/= 0.U) {
state := sWait
} .otherwise {
state := sIdle
}
} .otherwise {
state := sReadUop
}
}
is (sWait) {
when (inflight === 0.U) {
state := sIdle
}
}
}
when (state === sIdle) {
inflight := 0.U
} .elsewhen (!dec.reset) {
when (state === sExe && inflight =/= ((1 << pBits) - 1).asUInt) { // overflow check
inflight := inflight + 1.U
} .elsewhen (mvc.io.acc_o.data.valid && inflight =/= 0.U) { // underflow check
inflight := inflight - 1.U
}
}
when (state === sIdle ||
(state === sExe &&
uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
} .elsewhen (state === sExe) {
uop_idx := uop_idx + 1.U
}
when (state === sIdle) {
cnt_o := 0.U
acc_o := 0.U
inp_o := 0.U
wgt_o := 0.U
} .elsewhen (state === sExe &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
acc_o := acc_o + dec.acc_0
inp_o := inp_o + dec.inp_0
wgt_o := wgt_o + dec.wgt_0
}
when (state === sIdle) {
cnt_i := 0.U
acc_i := 0.U
inp_i := 0.U
wgt_i := 0.U
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
acc_i := acc_o
inp_i := inp_o
wgt_i := wgt_o
} .elsewhen (state === sExe &&
uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
acc_i := acc_i + dec.acc_1
inp_i := inp_i + dec.inp_1
wgt_i := wgt_i + dec.wgt_1
}
when (state === sComputeIdx && io.uop.data.valid) {
uop_acc := io.uop.data.bits.u0 + acc_i
uop_inp := io.uop.data.bits.u1 + inp_i
uop_wgt := io.uop.data.bits.u2 + wgt_i
}
wrpipe.io.enq.valid := state === sExe & ~dec.reset
wrpipe.io.enq.bits := uop_acc
// uop
io.uop.idx.valid := state === sReadUop
io.uop.idx.bits := uop_idx
// inp
io.inp.rd.idx.valid := state === sReadTensor
io.inp.rd.idx.bits := uop_inp
io.inp.tieoffWrite() // read-only
// wgt
io.wgt.rd.idx.valid := state === sReadTensor
io.wgt.rd.idx.bits := uop_wgt
io.wgt.tieoffWrite() // read-only
// acc_i
io.acc.rd.idx.valid := state === sReadTensor
io.acc.rd.idx.bits := uop_acc
// mvc
mvc.io.reset := dec.reset & state === sExe
mvc.io.inp.data <> io.inp.rd.data
mvc.io.wgt.data <> io.wgt.rd.data
mvc.io.acc_i.data <> io.acc.rd.data
// acc_o
io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid)
io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
// out
io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
io.out.wr.bits.idx := wrpipe.io.deq.bits
io.out.wr.bits.data <> mvc.io.out.data.bits
io.out.tieoffRead() // write-only
io.done := done
if (debug) {
when (state === sReadUop && ~dec.reset) {
printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
}
when (state === sReadTensor && ~dec.reset) {
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
}
io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
when (io.inp.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
}
}
io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
when (io.wgt.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
}
}
io.acc.rd.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (io.acc.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
}
}
}
mvc.io.acc_o.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (mvc.io.acc_o.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
}
}
}
mvc.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (mvc.io.out.data.valid && ~dec.reset) {
printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** TensorStore.
*
* Load 1D and 2D tensors from main memory (DRAM) to input/weight
* scratchpads (SRAM). Also, there is support for zero padding, while
* doing the load. Zero-padding works on the y and x axis, and it is
* managed by TensorPadCtrl. The TensorDataCtrl is in charge of
* handling the way tensors are stored on the scratchpads.
*/
class TensorLoad(tensorType: String = "none", debug: Boolean = false)
(implicit p: Parameters) extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val baddr = Input(UInt(mp.addrBits.W))
val vme_rd = new VMEReadMaster
val tensor = new TensorClient(tensorType)
})
val sizeFactor = tp.tensorLength * tp.numMemBlock
val strideFactor = tp.tensorLength * tp.tensorWidth
val dec = io.inst.asTypeOf(new MemDecode)
val dataCtrl = Module(new TensorDataCtrl(sizeFactor, strideFactor))
val dataCtrlDone = RegInit(false.B)
val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
val xPadCtrl0 = Module(new TensorPadCtrl(padType = "XPad0", sizeFactor))
val xPadCtrl1 = Module(new TensorPadCtrl(padType = "XPad1", sizeFactor))
val tag = Reg(UInt(8.W))
val set = Reg(UInt(8.W))
val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (io.start) {
when (dec.ypad_0 =/= 0.U) {
state := sYPad0
} .elsewhen (dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
state := sReadCmd
}
}
}
is (sYPad0) {
when (yPadCtrl0.io.done) {
when (dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
state := sReadCmd
}
}
}
is (sXPad0) {
when (xPadCtrl0.io.done) {
state := sReadCmd
}
}
is (sReadCmd) {
when (io.vme_rd.cmd.ready) {
state := sReadData
}
}
is (sReadData) {
when (io.vme_rd.data.valid) {
when (dataCtrl.io.done) {
when (dec.xpad_1 =/= 0.U) {
state := sXPad1
} .elsewhen (dec.ypad_1 =/= 0.U) {
state := sYPad1
} .otherwise {
state := sIdle
}
} .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) {
when (dec.xpad_1 =/= 0.U) {
state := sXPad1
} .elsewhen (dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
state := sReadCmd
}
}
}
}
is (sXPad1) {
when (xPadCtrl1.io.done) {
when (dataCtrlDone) {
when (dec.ypad_1 =/= 0.U) {
state := sYPad1
} .otherwise {
state := sIdle
}
} .otherwise {
when (dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
state := sReadCmd
}
}
}
}
is (sYPad1) {
when (yPadCtrl1.io.done && dataCtrlDone) {
state := sIdle
}
}
}
// data controller
dataCtrl.io.start := state === sIdle & io.start
dataCtrl.io.inst := io.inst
dataCtrl.io.baddr := io.baddr
dataCtrl.io.xinit := io.vme_rd.cmd.fire()
dataCtrl.io.xupdate := io.vme_rd.data.fire()
dataCtrl.io.yupdate := io.vme_rd.data.fire()
when (state === sIdle) {
dataCtrlDone := false.B
} .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) {
dataCtrlDone := true.B
}
// pad
yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
((state === sIdle & io.start) |
(state === sYPad0 & yPadCtrl0.io.done) |
(io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
((dataCtrl.io.done) |
(~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
yPadCtrl0.io.inst := io.inst
yPadCtrl1.io.inst := io.inst
xPadCtrl0.io.inst := io.inst
xPadCtrl1.io.inst := io.inst
// read-from-dram
io.vme_rd.cmd.valid := state === sReadCmd
io.vme_rd.cmd.bits.addr := dataCtrl.io.addr
io.vme_rd.cmd.bits.len := dataCtrl.io.len
io.vme_rd.data.ready := state === sReadData
// write-to-sram
val isZeroPad = state === sYPad0 |
state === sXPad0 |
state === sXPad1 |
state === sYPad1
when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
tag := 0.U
} .elsewhen (io.vme_rd.data.fire() || isZeroPad) {
tag := tag + 1.U
}
when (state === sIdle || state === sReadCmd || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
set := 0.U
} .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
set := set + 1.U
}
val waddr_cur = Reg(UInt(tp.memAddrBits.W))
val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
when (state === sIdle) {
waddr_cur := dec.sram_offset
waddr_nxt := dec.sram_offset
} .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
waddr_cur := waddr_cur + 1.U
} .elsewhen (dataCtrl.io.stride) {
waddr_cur := waddr_nxt + dec.xsize
waddr_nxt := waddr_nxt + dec.xsize
}
val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
no_mask.foreach { m => m := true.B }
for (i <- 0 until tp.tensorLength) {
for (j <- 0 until tp.numMemBlock) {
wmask(i)(j) := tag === j.U
wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
}
val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
val muxWdata = Mux(state === sIdle, tdata, wdata(i))
val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
when (muxWen) {
tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
}
}
// read-from-sram
val rvalid = RegNext(io.tensor.rd.idx.valid)
io.tensor.rd.data.valid := rvalid
val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
rdata.zipWithIndex.foreach { case(r, i) =>
io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
}
// done
val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
io.done := done_no_pad | done_x_pad | done_y_pad
// debug
if (debug) {
if (tensorType == "inp") {
when (io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
}
when (state === sYPad0) {
printf("[TensorLoad] [inp] sYPad0\n")
}
when (state === sYPad1) {
printf("[TensorLoad] [inp] sYPad1\n")
}
when (state === sXPad0) {
printf("[TensorLoad] [inp] sXPad0\n")
}
when (state === sXPad1) {
printf("[TensorLoad] [inp] sXPad1\n")
}
} else if (tensorType == "wgt") {
when (io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
}
} else if (tensorType == "acc") {
when (io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** TensorStore.
*
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/
class TensorStore(tensorType: String = "true", debug: Boolean = false)
(implicit p: Parameters) extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val baddr = Input(UInt(mp.addrBits.W))
val vme_wr = new VMEWriteMaster
val tensor = new TensorClient(tensorType)
})
val tensorLength = tp.tensorLength
val tensorWidth = tp.tensorWidth
val tensorElemBits = tp.tensorElemBits
val memBlockBits = tp.memBlockBits
val memDepth = tp.memDepth
val numMemBlock = tp.numMemBlock
val dec = io.inst.asTypeOf(new MemDecode)
val waddr_cur = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
val waddr_nxt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.addr))
val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val ycnt = Reg(chiselTypeOf(dec.ysize))
val ysize = dec.ysize
val tag = Reg(UInt(8.W))
val set = Reg(UInt(8.W))
val sIdle :: sWriteCmd :: sWriteData :: sReadMem :: sWriteAck :: Nil = Enum(5)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (io.start) {
state := sWriteCmd
when (xsize < xmax) {
xlen := xsize
xrem := 0.U
} .otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
}
is (sWriteCmd) {
when (io.vme_wr.cmd.ready) {
state := sWriteData
}
}
is (sWriteData) {
when (io.vme_wr.data.ready) {
when (xcnt === xlen) {
state := sWriteAck
} .elsewhen (tag === (numMemBlock - 1).U) {
state := sReadMem
}
}
}
is (sReadMem) {
state := sWriteData
}
is (sWriteAck) {
when (io.vme_wr.ack) {
when (xrem === 0.U) {
when (ycnt === ysize - 1.U) {
state := sIdle
} .otherwise {
state := sWriteCmd
when (xsize < xmax) {
xlen := xsize
xrem := 0.U
} .otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
} .elsewhen (xrem < xmax) {
state := sWriteCmd
xlen := xrem
xrem := 0.U
} .otherwise {
state := sWriteCmd
xlen := xmax - 1.U
xrem := xrem - xmax
}
}
}
}
// write-to-sram
val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
val no_mask = Wire(Vec(numMemBlock, Bool()))
wdata_t := DontCare
no_mask.foreach { m => m := true.B }
for (i <- 0 until tensorLength) {
val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
when (io.tensor.wr.valid) {
tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
}
}
// read-from-sram
val stride = state === sWriteAck &
io.vme_wr.ack &
xcnt === xlen + 1.U &
xrem === 0.U &
ycnt =/= ysize - 1.U
when (state === sIdle) {
ycnt := 0.U
} .elsewhen (stride) {
ycnt := ycnt + 1.U
}
when (state === sWriteCmd || tag === (numMemBlock - 1).U) {
tag := 0.U
} .elsewhen (io.vme_wr.data.fire()) {
tag := tag + 1.U
}
when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
set := 0.U
} .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
set := set + 1.U
}
val raddr_cur = Reg(UInt(tp.memAddrBits.W))
val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
when (state === sIdle) {
raddr_cur := dec.sram_offset
raddr_nxt := dec.sram_offset
} .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
raddr_cur := raddr_cur + 1.U
} .elsewhen (stride) {
raddr_cur := raddr_nxt + dec.xsize
raddr_nxt := raddr_nxt + dec.xsize
}
val tread = Seq.tabulate(tensorLength) { i => i.U ->
tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
// write-to-dram
when (state === sIdle) {
waddr_cur := io.baddr + dec.dram_offset
waddr_nxt := io.baddr + dec.dram_offset
} .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
waddr_cur := waddr_cur + xmax_bytes
} .elsewhen (stride) {
waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
}
io.vme_wr.cmd.valid := state === sWriteCmd
io.vme_wr.cmd.bits.addr := waddr_cur
io.vme_wr.cmd.bits.len := xlen
io.vme_wr.data.valid := state === sWriteData
io.vme_wr.data.bits := mdata(tag)
when (state === sWriteCmd) {
xcnt := 0.U
} .elsewhen (io.vme_wr.data.fire()) {
xcnt := xcnt + 1.U
}
// disable external read-from-sram requests
io.tensor.tieoffRead()
// done
io.done := state === sWriteAck & io.vme_wr.ack & xrem === 0.U & ycnt === ysize - 1.U
// debug
if (debug) {
when (io.vme_wr.cmd.fire()) {
printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
}
when (io.vme_wr.data.fire()) {
printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
}
when (io.vme_wr.ack) {
printf("[TensorStore] ack\n")
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** TensorParams.
*
* This Bundle derives parameters for each tensorType, including inputs (inp),
* weights (wgt), biases (acc), and outputs (out). This is used to avoid
* doing the same boring calculations over and over again.
*/
class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
require (tensorType == "inp" || tensorType == "wgt"
|| tensorType == "acc" || tensorType == "out", errorMsg)
val (tensorLength, tensorWidth, tensorElemBits) =
if (tensorType == "inp")
(p(CoreKey).batch, p(CoreKey).blockIn, p(CoreKey).inpBits)
else if (tensorType == "wgt")
(p(CoreKey).blockOut, p(CoreKey).blockIn, p(CoreKey).wgtBits)
else if (tensorType == "acc")
(p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).accBits)
else
(p(CoreKey).batch, p(CoreKey).blockOut, p(CoreKey).outBits)
val memBlockBits = p(ShellKey).memParams.dataBits
val numMemBlock = (tensorWidth * tensorElemBits) / memBlockBits
val memDepth =
if (tensorType == "inp")
p(CoreKey).inpMemDepth
else if (tensorType == "wgt")
p(CoreKey).wgtMemDepth
else if (tensorType == "acc")
p(CoreKey).accMemDepth
else
p(CoreKey).outMemDepth
val memAddrBits = log2Ceil(memDepth)
}
/** TensorMaster.
*
* This interface issue read and write tensor-requests to scratchpads. For example,
* The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt),
* biases (acc), and outputs (out).
*
*/
class TensorMaster(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val rd = new Bundle {
val idx = ValidIO(UInt(memAddrBits.W))
val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
}
val wr = ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
})
def tieoffRead() {
rd.idx.valid := false.B
rd.idx.bits := 0.U
}
def tieoffWrite() {
wr.valid := false.B
wr.bits.idx := 0.U
wr.bits.data.foreach { b => b.foreach { c => c := 0.U } }
}
override def cloneType =
new TensorMaster(tensorType).asInstanceOf[this.type]
}
/** TensorClient.
*
* This interface receives read and write tensor-requests to scratchpads. For example,
* The TensorLoad unit uses this interface for receiving read and write requests from
* the TensorGemm unit.
*/
class TensorClient(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val rd = new Bundle {
val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
}
val wr = Flipped(ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
}))
def tieoffRead() {
rd.data.valid := false.B
rd.data.bits.foreach { b => b.foreach { c => c := 0.U } }
}
override def cloneType =
new TensorClient(tensorType).asInstanceOf[this.type]
}
/** TensorMasterData.
*
* This interface is only used for datapath only purposes and the direction convention
* is based on the TensorMaster interface, which means this is an input. This interface
* is used on datapath only module such MatrixVectorCore or AluVector.
*/
class TensorMasterData(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
override def cloneType =
new TensorMasterData(tensorType).asInstanceOf[this.type]
}
/** TensorClientData.
*
* This interface is only used for datapath only purposes and the direction convention
* is based on the TensorClient interface, which means this is an output. This interface
* is used on datapath only module such MatrixVectorCore or AluVector.
*/
class TensorClientData(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
override def cloneType =
new TensorClientData(tensorType).asInstanceOf[this.type]
}
/** TensorPadCtrl. Zero-padding controller for TensorLoad. */
class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
require (padType == "YPad0" || padType == "YPad1"
|| padType == "XPad0" || padType == "XPad1", errorMsg)
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
})
val dec = io.inst.asTypeOf(new MemDecode)
val xmax = Reg(chiselTypeOf(dec.xsize))
val ymax = Reg(chiselTypeOf(dec.ypad_0))
val xcnt = Reg(chiselTypeOf(dec.xsize))
val ycnt = Reg(chiselTypeOf(dec.ypad_0))
val xval =
if (padType == "YPad0" || padType == "YPad1")
((dec.xpad_0 + dec.xsize + dec.xpad_1) << log2Ceil(sizeFactor)) - 1.U
else if (padType == "XPad0")
(dec.xpad_0 << log2Ceil(sizeFactor)) - 1.U
else
(dec.xpad_1 << log2Ceil(sizeFactor)) - 1.U
val yval =
if (padType == "YPad0")
Mux(dec.ypad_0 =/= 0.U, dec.ypad_0 - 1.U, 0.U)
else if (padType == "YPad1")
Mux(dec.ypad_1 =/= 0.U, dec.ypad_1 - 1.U, 0.U)
else
0.U
val sIdle :: sActive :: Nil = Enum(2)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.start) {
state := sActive
}
}
is (sActive) {
when (ycnt === ymax && xcnt === xmax) {
state := sIdle
}
}
}
when (state === sIdle) {
xmax := xval
ymax := yval
}
when (state === sIdle || xcnt === xmax) {
xcnt := 0.U
} .elsewhen (state === sActive) {
xcnt := xcnt + 1.U
}
when (state === sIdle || ymax === 0.U) {
ycnt := 0.U
} .elsewhen (state === sActive && xcnt === xmax) {
ycnt := ycnt + 1.U
}
io.done := state === sActive & ycnt === ymax & xcnt === xmax
}
/** TensorDataCtrl. Data controller for TensorLoad. */
class TensorDataCtrl(sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
val inst = Input(UInt(INST_BITS.W))
val baddr = Input(UInt(mp.addrBits.W))
val xinit = Input(Bool())
val xupdate = Input(Bool())
val yupdate = Input(Bool())
val stride = Output(Bool())
val split = Output(Bool())
val commit = Output(Bool())
val addr = Output(UInt(mp.addrBits.W))
val len = Output(UInt(mp.lenBits.W))
})
val dec = io.inst.asTypeOf(new MemDecode)
val caddr = Reg(UInt(mp.addrBits.W))
val baddr = Reg(UInt(mp.addrBits.W))
val len = Reg(UInt(mp.lenBits.W))
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val xcnt = Reg(UInt(mp.lenBits.W))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
val xmax = (1 << mp.lenBits).U
val ycnt = Reg(chiselTypeOf(dec.ysize))
val stride = xcnt === len &
xrem === 0.U &
ycnt =/= dec.ysize - 1.U
val split = xcnt === len & xrem =/= 0.U
when (io.start || (io.xupdate && stride)) {
when (xsize < xmax) {
len := xsize
xrem := 0.U
} .otherwise {
len := xmax - 1.U
xrem := xsize - xmax
}
} .elsewhen (io.xupdate && split) {
when (xrem < xmax) {
len := xrem
xrem := 0.U
} .otherwise {
len := xmax - 1.U
xrem := xrem - xmax
}
}
when (io.xinit) {
xcnt := 0.U
} .elsewhen (io.xupdate) {
xcnt := xcnt + 1.U
}
when (io.start) {
ycnt := 0.U
} .elsewhen (io.yupdate && stride) {
ycnt := ycnt + 1.U
}
when (io.start) {
caddr := io.baddr + dec.dram_offset
baddr := io.baddr + dec.dram_offset
} .elsewhen (io.yupdate) {
when (split) {
caddr := caddr + xmax_bytes
} .elsewhen (stride) {
caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
}
}
io.stride := stride
io.split := split
io.commit := xcnt === len
io.addr := caddr
io.len := len
io.done := xcnt === len &
xrem === 0.U &
ycnt === dec.ysize - 1.U
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta
/** This trick makes ISAConstants globally available */
package object core extends vta.core.ISAConstants
......@@ -21,6 +21,9 @@ package vta.dpi
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.interface.axi._
import vta.shell._
/** Host DPI parameters */
trait VTAHostDPIParams {
......@@ -70,3 +73,83 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
})
setResource("/verilog/VTAHostDPI.v")
}
/** Host DPI to AXI Converter.
*
* Convert Host DPI to AXI for VTAShell
*/
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val dpi = new VTAHostDPIClient
val axi = new AXILiteMaster(p(ShellKey).hostParams)
})
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.dpi.req.valid) {
when (io.dpi.req.opcode) {
state := sWriteAddress
} .otherwise {
state := sReadAddress
}
}
}
is (sReadAddress) {
when (io.axi.ar.ready) {
state := sReadData
}
}
is (sReadData) {
when (io.axi.r.valid) {
state := sIdle
}
}
is (sWriteAddress) {
when (io.axi.aw.ready) {
state := sWriteData
}
}
is (sWriteData) {
when (io.axi.w.ready) {
state := sWriteResponse
}
}
is (sWriteResponse) {
when (io.axi.b.valid) {
state := sIdle
}
}
}
when (state === sIdle && io.dpi.req.valid) {
addr := io.dpi.req.addr
data := io.dpi.req.value
}
io.axi.aw.valid := state === sWriteAddress
io.axi.aw.bits.addr := addr
io.axi.w.valid := state === sWriteData
io.axi.w.bits.data := data
io.axi.w.bits.strb := "h_f".U
io.axi.b.ready := state === sWriteResponse
io.axi.ar.valid := state === sReadAddress
io.axi.ar.bits.addr := addr
io.axi.r.ready := state === sReadData
io.dpi.req.deq := (state === sReadAddress & io.axi.ar.ready) | (state === sWriteAddress & io.axi.aw.ready)
io.dpi.resp.valid := io.axi.r.valid
io.dpi.resp.bits := io.axi.r.bits.data
if (debug) {
when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) }
when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) }
when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) }
when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) }
}
}
......@@ -21,6 +21,9 @@ package vta.dpi
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.interface.axi._
import vta.shell._
/** Memory DPI parameters */
trait VTAMemDPIParams {
......@@ -71,3 +74,98 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
})
setResource("/verilog/VTAMemDPI.v")
}
class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val dpi = new VTAMemDPIMaster
val axi = new AXIClient(p(ShellKey).memParams)
})
val opcode = RegInit(false.B)
val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.axi.ar.valid) {
state := sReadAddress
} .elsewhen (io.axi.aw.valid) {
state := sWriteAddress
}
}
is (sReadAddress) {
when (io.axi.ar.valid) {
state := sReadData
}
}
is (sReadData) {
when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
state := sIdle
}
}
is (sWriteAddress) {
when (io.axi.aw.valid) {
state := sWriteData
}
}
is (sWriteData) {
when (io.axi.w.valid && io.axi.w.bits.last) {
state := sWriteResponse
}
}
is (sWriteResponse) {
when (io.axi.b.ready) {
state := sIdle
}
}
}
when (state === sIdle) {
when (io.axi.ar.valid) {
opcode := false.B
len := io.axi.ar.bits.len
addr := io.axi.ar.bits.addr
} .elsewhen (io.axi.aw.valid) {
opcode := true.B
len := io.axi.aw.bits.len
addr := io.axi.aw.bits.addr
}
} .elsewhen (state === sReadData) {
when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
len := len - 1.U
}
}
io.dpi.req.valid := (state === sReadAddress & io.axi.ar.valid) | (state === sWriteAddress & io.axi.aw.valid)
io.dpi.req.opcode := opcode
io.dpi.req.len := len
io.dpi.req.addr := addr
io.axi.ar.ready := state === sReadAddress
io.axi.aw.ready := state === sWriteAddress
io.axi.r.valid := state === sReadData & io.dpi.rd.valid
io.axi.r.bits.data := io.dpi.rd.bits
io.axi.r.bits.last := len === 0.U
io.axi.r.bits.resp := 0.U
io.axi.r.bits.user := 0.U
io.axi.r.bits.id := 0.U
io.dpi.rd.ready := state === sReadData & io.axi.r.ready
io.dpi.wr.valid := state === sWriteData & io.axi.w.valid
io.dpi.wr.bits := io.axi.w.bits.data
io.axi.w.ready := state === sWriteData
io.axi.b.valid := state === sWriteResponse
io.axi.b.bits.resp := 0.U
io.axi.b.bits.user := 0.U
io.axi.b.bits.id := 0.U
if (debug) {
when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) }
when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) }
when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) }
when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) }
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.interface.axi
import chisel3._
import chisel3.util._
import vta.util.genericbundle._
case class AXIParams(
addrBits: Int = 32,
dataBits: Int = 64
)
{
require (addrBits > 0)
require (dataBits >= 8 && dataBits % 2 == 0)
val idBits = 1
val userBits = 1
val strbBits = dataBits/8
val lenBits = 8
val sizeBits = 3
val burstBits = 2
val lockBits = 2
val cacheBits = 4
val protBits = 3
val qosBits = 4
val regionBits = 4
val respBits = 2
val sizeConst = log2Ceil(dataBits/8)
val idConst = 0
val userConst = 0
val burstConst = 1
val lockConst = 0
val cacheConst = 3
val protConst = 0
val qosConst = 0
val regionConst = 0
}
abstract class AXIBase(params: AXIParams)
extends GenericParameterizedBundle(params)
// AXILite
class AXILiteAddress(params: AXIParams) extends AXIBase(params) {
val addr = UInt(params.addrBits.W)
}
class AXILiteWriteData(params: AXIParams) extends AXIBase(params) {
val data = UInt(params.dataBits.W)
val strb = UInt(params.strbBits.W)
}
class AXILiteWriteResponse(params: AXIParams) extends AXIBase(params) {
val resp = UInt(params.respBits.W)
}
class AXILiteReadData(params: AXIParams) extends AXIBase(params) {
val data = UInt(params.dataBits.W)
val resp = UInt(params.respBits.W)
}
class AXILiteMaster(params: AXIParams) extends AXIBase(params) {
val aw = Decoupled(new AXILiteAddress(params))
val w = Decoupled(new AXILiteWriteData(params))
val b = Flipped(Decoupled(new AXILiteWriteResponse(params)))
val ar = Decoupled(new AXILiteAddress(params))
val r = Flipped(Decoupled(new AXILiteReadData(params)))
def tieoff() {
aw.valid := false.B
aw.bits.addr := 0.U
w.valid := false.B
w.bits.data := 0.U
w.bits.strb := 0.U
b.ready := false.B
ar.valid := false.B
ar.bits.addr := 0.U
r.ready := false.B
}
}
class AXILiteClient(params: AXIParams) extends AXIBase(params) {
val aw = Flipped(Decoupled(new AXILiteAddress(params)))
val w = Flipped(Decoupled(new AXILiteWriteData(params)))
val b = Decoupled(new AXILiteWriteResponse(params))
val ar = Flipped(Decoupled(new AXILiteAddress(params)))
val r = Decoupled(new AXILiteReadData(params))
def tieoff() {
aw.ready := false.B
w.ready := false.B
b.valid := false.B
b.bits.resp := 0.U
ar.ready := false.B
r.valid := false.B
r.bits.resp := 0.U
r.bits.data := 0.U
}
}
// AXI extends AXILite
class AXIAddress(params: AXIParams) extends AXILiteAddress(params) {
val id = UInt(params.idBits.W)
val user = UInt(params.userBits.W)
val len = UInt(params.lenBits.W)
val size = UInt(params.sizeBits.W)
val burst = UInt(params.burstBits.W)
val lock = UInt(params.lockBits.W)
val cache = UInt(params.cacheBits.W)
val prot = UInt(params.protBits.W)
val qos = UInt(params.qosBits.W)
val region = UInt(params.regionBits.W)
}
class AXIWriteData(params: AXIParams) extends AXILiteWriteData(params) {
val last = Bool()
val id = UInt(params.idBits.W)
val user = UInt(params.userBits.W)
}
class AXIWriteResponse(params: AXIParams) extends AXILiteWriteResponse(params) {
val id = UInt(params.idBits.W)
val user = UInt(params.userBits.W)
}
class AXIReadData(params: AXIParams) extends AXILiteReadData(params) {
val last = Bool()
val id = UInt(params.idBits.W)
val user = UInt(params.userBits.W)
}
class AXIMaster(params: AXIParams) extends AXIBase(params) {
val aw = Decoupled(new AXIAddress(params))
val w = Decoupled(new AXIWriteData(params))
val b = Flipped(Decoupled(new AXIWriteResponse(params)))
val ar = Decoupled(new AXIAddress(params))
val r = Flipped(Decoupled(new AXIReadData(params)))
def tieoff() {
aw.valid := false.B
aw.bits.addr := 0.U
aw.bits.id := 0.U
aw.bits.user := 0.U
aw.bits.len := 0.U
aw.bits.size := 0.U
aw.bits.burst := 0.U
aw.bits.lock := 0.U
aw.bits.cache := 0.U
aw.bits.prot := 0.U
aw.bits.qos := 0.U
aw.bits.region := 0.U
w.valid := false.B
w.bits.data := 0.U
w.bits.strb := 0.U
w.bits.last := false.B
w.bits.id := 0.U
w.bits.user := 0.U
b.ready := false.B
ar.valid := false.B
ar.bits.addr := 0.U
ar.bits.id := 0.U
ar.bits.user := 0.U
ar.bits.len := 0.U
ar.bits.size := 0.U
ar.bits.burst := 0.U
ar.bits.lock := 0.U
ar.bits.cache := 0.U
ar.bits.prot := 0.U
ar.bits.qos := 0.U
ar.bits.region := 0.U
r.ready := false.B
}
def setConst() {
aw.bits.user := params.userConst.U
aw.bits.burst := params.burstConst.U
aw.bits.lock := params.lockConst.U
aw.bits.cache := params.cacheConst.U
aw.bits.prot := params.protConst.U
aw.bits.qos := params.qosConst.U
aw.bits.region := params.regionConst.U
aw.bits.size := params.sizeConst.U
aw.bits.id := params.idConst.U
w.bits.id := params.idConst.U
w.bits.user := params.userConst.U
w.bits.strb := Fill(params.strbBits, true.B)
ar.bits.user := params.userConst.U
ar.bits.burst := params.burstConst.U
ar.bits.lock := params.lockConst.U
ar.bits.cache := params.cacheConst.U
ar.bits.prot := params.protConst.U
ar.bits.qos := params.qosConst.U
ar.bits.region := params.regionConst.U
ar.bits.size := params.sizeConst.U
ar.bits.id := params.idConst.U
}
}
class AXIClient(params: AXIParams) extends AXIBase(params) {
val aw = Flipped(Decoupled(new AXIAddress(params)))
val w = Flipped(Decoupled(new AXIWriteData(params)))
val b = Decoupled(new AXIWriteResponse(params))
val ar = Flipped(Decoupled(new AXIAddress(params)))
val r = Decoupled(new AXIReadData(params))
def tieoff() {
aw.ready := false.B
w.ready := false.B
b.valid := false.B
b.bits.resp := 0.U
b.bits.user := 0.U
b.bits.id := 0.U
ar.ready := false.B
r.valid := false.B
r.bits.resp := 0.U
r.bits.data := 0.U
r.bits.user := 0.U
r.bits.last := false.B
r.bits.id := 0.U
}
}
// XilinxAXILiteClient and XilinxAXIMaster bundles are needed
// for wrapper purposes, because the package RTL tool in Xilinx Vivado
// only allows certain name formats
class XilinxAXILiteClient(params: AXIParams) extends AXIBase(params) {
val AWVALID = Input(Bool())
val AWREADY = Output(Bool())
val AWADDR = Input(UInt(params.addrBits.W))
val WVALID = Input(Bool())
val WREADY = Output(Bool())
val WDATA = Input(UInt(params.dataBits.W))
val WSTRB = Input(UInt(params.strbBits.W))
val BVALID = Output(Bool())
val BREADY = Input(Bool())
val BRESP = Output(UInt(params.respBits.W))
val ARVALID = Input(Bool())
val ARREADY = Output(Bool())
val ARADDR = Input(UInt(params.addrBits.W))
val RVALID = Output(Bool())
val RREADY = Input(Bool())
val RDATA = Output(UInt(params.dataBits.W))
val RRESP = Output(UInt(params.respBits.W))
}
class XilinxAXIMaster(params: AXIParams) extends AXIBase(params) {
val AWVALID = Output(Bool())
val AWREADY = Input(Bool())
val AWADDR = Output(UInt(params.addrBits.W))
val AWID = Output(UInt(params.idBits.W))
val AWUSER = Output(UInt(params.userBits.W))
val AWLEN = Output(UInt(params.lenBits.W))
val AWSIZE = Output(UInt(params.sizeBits.W))
val AWBURST = Output(UInt(params.burstBits.W))
val AWLOCK = Output(UInt(params.lockBits.W))
val AWCACHE = Output(UInt(params.cacheBits.W))
val AWPROT = Output(UInt(params.protBits.W))
val AWQOS = Output(UInt(params.qosBits.W))
val AWREGION = Output(UInt(params.regionBits.W))
val WVALID = Output(Bool())
val WREADY = Input(Bool())
val WDATA = Output(UInt(params.dataBits.W))
val WSTRB = Output(UInt(params.strbBits.W))
val WLAST = Output(Bool())
val WID = Output(UInt(params.idBits.W))
val WUSER = Output(UInt(params.userBits.W))
val BVALID = Input(Bool())
val BREADY = Output(Bool())
val BRESP = Input(UInt(params.respBits.W))
val BID = Input(UInt(params.idBits.W))
val BUSER = Input(UInt(params.userBits.W))
val ARVALID = Output(Bool())
val ARREADY = Input(Bool())
val ARADDR = Output(UInt(params.addrBits.W))
val ARID = Output(UInt(params.idBits.W))
val ARUSER = Output(UInt(params.userBits.W))
val ARLEN = Output(UInt(params.lenBits.W))
val ARSIZE = Output(UInt(params.sizeBits.W))
val ARBURST = Output(UInt(params.burstBits.W))
val ARLOCK = Output(UInt(params.lockBits.W))
val ARCACHE = Output(UInt(params.cacheBits.W))
val ARPROT = Output(UInt(params.protBits.W))
val ARQOS = Output(UInt(params.qosBits.W))
val ARREGION = Output(UInt(params.regionBits.W))
val RVALID = Input(Bool())
val RREADY = Output(Bool())
val RDATA = Input(UInt(params.dataBits.W))
val RRESP = Input(UInt(params.respBits.W))
val RLAST = Input(Bool())
val RID = Input(UInt(params.idBits.W))
val RUSER = Input(UInt(params.userBits.W))
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.shell
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.interface.axi._
/** PynqConfig. Shell configuration for Pynq */
class PynqConfig extends Config((site, here, up) => {
case ShellKey => ShellParams(
hostParams = AXIParams(
addrBits = 16,
dataBits = 32),
memParams = AXIParams(
addrBits = 32,
dataBits = 64),
vcrParams = VCRParams(),
vmeParams = VMEParams())
})
/** F1Config. Shell configuration for F1 */
class F1Config extends Config((site, here, up) => {
case ShellKey => ShellParams(
hostParams = AXIParams(
addrBits = 16,
dataBits = 32),
memParams = AXIParams(
addrBits = 64,
dataBits = 64),
vcrParams = VCRParams(),
vmeParams = VMEParams())
})
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.shell
import chisel3._
import vta.util.config._
import vta.interface.axi._
import vta.shell._
import vta.dpi._
/** VTAHost.
*
* This module translate the DPI protocol into AXI. This is a simulation only
* module and used to test host-to-VTA communication. This module should be updated
* for testing hosts using a different bus protocol, other than AXI.
*/
class VTAHost(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val axi = new AXILiteMaster(p(ShellKey).hostParams)
})
val host_dpi = Module(new VTAHostDPI)
val host_axi = Module(new VTAHostDPIToAXI)
host_dpi.io.reset := reset
host_dpi.io.clock := clock
host_axi.io.dpi <> host_dpi.io.dpi
io.axi <> host_axi.io.axi
}
/** VTAMem.
*
* This module translate the DPI protocol into AXI. This is a simulation only
* module and used to test VTA-to-memory communication. This module should be updated
* for testing memories using a different bus protocol, other than AXI.
*/
class VTAMem(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val axi = new AXIClient(p(ShellKey).memParams)
})
val mem_dpi = Module(new VTAMemDPI)
val mem_axi = Module(new VTAMemDPIToAXI)
mem_dpi.io.reset := reset
mem_dpi.io.clock := clock
mem_dpi.io.dpi <> mem_axi.io.dpi
mem_axi.io.axi <> io.axi
}
/** SimShell.
*
* The simulation shell instantiate a host and memory simulation modules and it is
* intended to be connected to the VTAShell.
*/
class SimShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val mem = new AXIClient(p(ShellKey).memParams)
val host = new AXILiteMaster(p(ShellKey).hostParams)
})
val host = Module(new VTAHost)
val mem = Module(new VTAMem)
io.mem <> mem.io.axi
io.host <> host.io.axi
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.shell
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.util.genericbundle._
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.LinkedHashMap
import vta.interface.axi._
/** VCR parameters.
*
* These parameters are used on VCR interfaces and modules.
*/
case class VCRParams()
{
val nValsReg: Int = 1
val nPtrsReg: Int = 6
val regBits: Int = 32
val nCtrlReg: Int = 4
val ctrlBaseAddr: Int = 0
require (nValsReg > 0)
require (nPtrsReg > 0)
}
/** VCRBase. Parametrize base class. */
abstract class VCRBase(implicit p: Parameters)
extends GenericParameterizedBundle(p)
/** VCRMaster.
*
* This is the master interface used by VCR in the VTAShell to control
* the Core unit.
*/
class VCRMaster(implicit p: Parameters) extends VCRBase {
val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams
val launch = Output(Bool())
val finish = Input(Bool())
val irq = Output(Bool())
val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W)))
}
/** VCRClient.
*
* This is the client interface used by the Core module to communicate
* to the VCR in the VTAShell.
*/
class VCRClient(implicit p: Parameters) extends VCRBase {
val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams
val launch = Input(Bool())
val finish = Output(Bool())
val irq = Input(Bool())
val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W)))
}
/** VTA Control Registers (VCR).
*
* This unit provides control registers (32 and 64 bits) to be used by a control'
* unit, typically a host processor. These registers are read-only by the core
* at the moment but this will likely change once we add support to general purpose
* registers that could be used as event counters by the Core unit.
*/
class VCR(implicit p: Parameters) extends Module {
val io = IO(new Bundle{
val host = new AXILiteClient(p(ShellKey).hostParams)
val vcr = new VCRMaster
})
val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams
val hp = p(ShellKey).hostParams
// Write control (AW, W, B)
val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address
val wdata = io.host.w.bits.data
val wstrb = io.host.w.bits.strb
val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0)))
val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
val wstate = RegInit(sWriteAddress)
switch (wstate) {
is (sWriteAddress) {
when (io.host.aw.valid) {
wstate := sWriteData
}
}
is (sWriteData) {
when (io.host.w.valid) {
wstate := sWriteResponse
}
}
is (sWriteResponse) {
when (io.host.b.ready) {
wstate := sWriteAddress
}
}
}
when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
io.host.aw.ready := wstate === sWriteAddress
io.host.w.ready := wstate === sWriteData
io.host.b.valid := wstate === sWriteResponse
io.host.b.bits.resp := "h_0".U
// read control (AR, R)
val sReadAddress :: sReadData :: Nil = Enum(2)
val rstate = RegInit(sReadAddress)
switch (rstate) {
is (sReadAddress) {
when (io.host.ar.valid) {
rstate := sReadData
}
}
is (sReadData) {
when (io.host.r.ready) {
rstate := sReadAddress
}
}
}
io.host.ar.ready := rstate === sReadAddress
io.host.r.valid := rstate === sReadData
val nPtrsReg = vp.nPtrsReg
val nValsReg = vp.nValsReg
val regBits = vp.regBits
val ptrsBits = mp.addrBits
val nCtrlReg = vp.nCtrlReg
val rStride = regBits/8
val pStride = ptrsBits/8
val ctrlBaseAddr = vp.ctrlBaseAddr
val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride
val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride
val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr)
val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr)
val ptrsAddr = new ListBuffer[Int]()
for (i <- 0 until nPtrsReg) {
ptrsAddr += i*pStride + ptrsBaseAddr
if (ptrsBits == 64) {
ptrsAddr += i*pStride + rStride + ptrsBaseAddr
}
}
// AP register
val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B)))
// ap start
when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) {
c0(0) := true.B
} .elsewhen (io.vcr.finish) {
c0(0) := false.B
}
// ap done = finish
when (io.vcr.finish) {
c0(1) := true.B
} .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) {
c0(1) := false.B
}
val c1 = 0.U
val c2 = 0.U
val c3 = 0.U
val ctrlRegList = List(c0, c1, c2, c3)
io.vcr.launch := c0(0)
// interrupts not supported atm
io.vcr.irq := false.B
// Write pointer and value registers
val pvAddr = valsAddr ++ ptrsAddr
val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg
val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W))))
val pvRegList = new ListBuffer[UInt]()
for (i <- 0 until pvNumReg) {
when (io.host.w.fire() && (waddr === pvAddr(i).U)) {
pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask)
}
pvRegList += pvReg(i)
}
for (i <- 0 until nValsReg) {
io.vcr.vals(i) := pvReg(i)
}
for (i <- 0 until nPtrsReg) {
if (ptrsBits == 64) {
io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2))
} else {
io.vcr.ptrs(i) := pvReg(nValsReg + i)
}
}
// Read pointer and value registers
val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr
val mapRegList = ctrlRegList ++ pvRegList
val rdata = RegInit(0.U(regBits.W))
val rmap = LinkedHashMap[Int,UInt]()
val totalReg = mapRegList.length
for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt }
val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) }
when (io.host.ar.fire()) {
rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v)
}
io.host.r.bits.resp := 0.U
io.host.r.bits.data := rdata
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.shell
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.util.genericbundle._
import vta.interface.axi._
/** VME parameters.
*
* These parameters are used on VME interfaces and modules.
*/
case class VMEParams() {
val nReadClients: Int = 5
val nWriteClients: Int = 1
require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
}
/** VMEBase. Parametrize base class. */
abstract class VMEBase(implicit p: Parameters)
extends GenericParameterizedBundle(p)
/** VMECmd.
*
* This interface is used for creating write and read requests to memory.
*/
class VMECmd(implicit p: Parameters) extends VMEBase {
val addrBits = p(ShellKey).memParams.addrBits
val lenBits = p(ShellKey).memParams.lenBits
val addr = UInt(addrBits.W)
val len = UInt(lenBits.W)
}
/** VMEReadMaster.
*
* This interface is used by modules inside the core to generate read requests
* and receive responses from VME.
*/
class VMEReadMaster(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Decoupled(new VMECmd)
val data = Flipped(Decoupled(UInt(dataBits.W)))
override def cloneType =
new VMEReadMaster().asInstanceOf[this.type]
}
/** VMEReadClient.
*
* This interface is used by the VME to receive read requests and generate
* responses to modules inside the core.
*/
class VMEReadClient(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Flipped(Decoupled(new VMECmd))
val data = Decoupled(UInt(dataBits.W))
override def cloneType =
new VMEReadClient().asInstanceOf[this.type]
}
/** VMEWriteMaster.
*
* This interface is used by modules inside the core to generate write requests
* to the VME.
*/
class VMEWriteMaster(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Decoupled(new VMECmd)
val data = Decoupled(UInt(dataBits.W))
val ack = Input(Bool())
override def cloneType =
new VMEWriteMaster().asInstanceOf[this.type]
}
/** VMEWriteClient.
*
* This interface is used by the VME to handle write requests from modules inside
* the core.
*/
class VMEWriteClient(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits
val cmd = Flipped(Decoupled(new VMECmd))
val data = Flipped(Decoupled(UInt(dataBits.W)))
val ack = Output(Bool())
override def cloneType =
new VMEWriteClient().asInstanceOf[this.type]
}
/** VMEMaster.
*
* Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster
* interfaces.
*/
class VMEMaster(implicit p: Parameters) extends Bundle {
val nRd = p(ShellKey).vmeParams.nReadClients
val nWr = p(ShellKey).vmeParams.nWriteClients
val rd = Vec(nRd, new VMEReadMaster)
val wr = Vec(nWr, new VMEWriteMaster)
}
/** VMEClient.
*
* Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient
* interfaces.
*/
class VMEClient(implicit p: Parameters) extends Bundle {
val nRd = p(ShellKey).vmeParams.nReadClients
val nWr = p(ShellKey).vmeParams.nWriteClients
val rd = Vec(nRd, new VMEReadClient)
val wr = Vec(nWr, new VMEWriteClient)
}
/** VTA Memory Engine (VME).
*
* This unit multiplexes the memory controller interface for the Core. Currently,
* it supports single-writer and multiple-reader mode and it is also based on AXI.
*/
class VME(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val mem = new AXIMaster(p(ShellKey).memParams)
val vme = new VMEClient
})
val nReadClients = p(ShellKey).vmeParams.nReadClients
val rd_arb = Module(new Arbiter(new VMECmd, nReadClients))
val rd_arb_chosen = RegEnable(rd_arb.io.chosen, rd_arb.io.out.fire())
for (i <- 0 until nReadClients) { rd_arb.io.in(i) <> io.vme.rd(i).cmd }
val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
val rstate = RegInit(sReadIdle)
switch (rstate) {
is (sReadIdle) {
when (rd_arb.io.out.valid) {
rstate := sReadAddr
}
}
is (sReadAddr) {
when (io.mem.ar.ready) {
rstate := sReadData
}
}
is (sReadData) {
when (io.mem.r.fire() && io.mem.r.bits.last) {
rstate := sReadIdle
}
}
}
val sWriteIdle :: sWriteAddr :: sWriteData :: sWriteResp :: Nil = Enum(4)
val wstate = RegInit(sWriteIdle)
val addrBits = p(ShellKey).memParams.addrBits
val lenBits = p(ShellKey).memParams.lenBits
val wr_cnt = RegInit(0.U(lenBits.W))
when (wstate === sWriteIdle) {
wr_cnt := 0.U
} .elsewhen (io.mem.w.fire()) {
wr_cnt := wr_cnt + 1.U
}
switch (wstate) {
is (sWriteIdle) {
when (io.vme.wr(0).cmd.valid) {
wstate := sWriteAddr
}
}
is (sWriteAddr) {
when (io.mem.aw.ready) {
wstate := sWriteData
}
}
is (sWriteData) {
when (io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
wstate := sWriteResp
}
}
is (sWriteResp) {
when (io.mem.b.valid) {
wstate := sWriteIdle
}
}
}
// registers storing read/write cmds
val rd_len = RegInit(0.U(lenBits.W))
val wr_len = RegInit(0.U(lenBits.W))
val rd_addr = RegInit(0.U(addrBits.W))
val wr_addr = RegInit(0.U(addrBits.W))
when (rd_arb.io.out.fire()) {
rd_len := rd_arb.io.out.bits.len
rd_addr := rd_arb.io.out.bits.addr
}
when (io.vme.wr(0).cmd.fire()) {
wr_len := io.vme.wr(0).cmd.bits.len
wr_addr := io.vme.wr(0).cmd.bits.addr
}
// rd arb
rd_arb.io.out.ready := rstate === sReadIdle
// vme
for (i <- 0 until nReadClients) {
io.vme.rd(i).data.valid := rd_arb_chosen === i.asUInt & io.mem.r.valid
io.vme.rd(i).data.bits := io.mem.r.bits.data
}
io.vme.wr(0).cmd.ready := wstate === sWriteIdle
io.vme.wr(0).ack := io.mem.b.fire()
io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
// mem
io.mem.aw.valid := wstate === sWriteAddr
io.mem.aw.bits.addr := wr_addr
io.mem.aw.bits.len := wr_len
io.mem.w.valid := wstate === sWriteData & io.vme.wr(0).data.valid
io.mem.w.bits.data := io.vme.wr(0).data.bits
io.mem.w.bits.last := wr_cnt === io.vme.wr(0).cmd.bits.len
io.mem.b.ready := wstate === sWriteResp
io.mem.ar.valid := rstate === sReadAddr
io.mem.ar.bits.addr := rd_addr
io.mem.ar.bits.len := rd_len
io.mem.r.ready := rstate === sReadData & io.vme.rd(rd_arb_chosen).data.ready
// AXI constants - statically defined
io.mem.setConst()
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.shell
import chisel3._
import vta.util.config._
import vta.interface.axi._
import vta.core._
/** Shell parameters. */
case class ShellParams(
hostParams: AXIParams,
memParams: AXIParams,
vcrParams: VCRParams,
vmeParams: VMEParams
)
case object ShellKey extends Field[ShellParams]
/** VTAShell.
*
* The VTAShell is based on a VME, VCR and core. This creates a complete VTA
* system that can be used for simulation or real hardware.
*/
class VTAShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle{
val host = new AXILiteClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams)
})
val vcr = Module(new VCR)
val vme = Module(new VME)
val core = Module(new Core)
core.io.vcr <> vcr.io.vcr
vme.io.vme <> core.io.vme
vcr.io.host <> io.host
io.mem <> vme.io.mem
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.shell
import chisel3._
import chisel3.experimental.{RawModule, withClockAndReset}
import vta.util.config._
import vta.interface.axi._
/** XilinxShell.
*
* This is a wrapper shell mostly used to match Xilinx convention naming,
* therefore we can pack VTA as an IP for IPI based flows.
*/
class XilinxShell(implicit p: Parameters) extends RawModule {
val hp = p(ShellKey).hostParams
val mp = p(ShellKey).memParams
val ap_clk = IO(Input(Clock()))
val ap_rst_n = IO(Input(Bool()))
val m_axi_gmem = IO(new XilinxAXIMaster(mp))
val s_axi_control = IO(new XilinxAXILiteClient(hp))
val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }
// memory
m_axi_gmem.AWVALID := shell.io.mem.aw.valid
shell.io.mem.aw.ready := m_axi_gmem.AWREADY
m_axi_gmem.AWADDR := shell.io.mem.aw.bits.addr
m_axi_gmem.AWID := shell.io.mem.aw.bits.id
m_axi_gmem.AWUSER := shell.io.mem.aw.bits.user
m_axi_gmem.AWLEN := shell.io.mem.aw.bits.len
m_axi_gmem.AWSIZE := shell.io.mem.aw.bits.size
m_axi_gmem.AWBURST := shell.io.mem.aw.bits.burst
m_axi_gmem.AWLOCK := shell.io.mem.aw.bits.lock
m_axi_gmem.AWCACHE := shell.io.mem.aw.bits.cache
m_axi_gmem.AWPROT := shell.io.mem.aw.bits.prot
m_axi_gmem.AWQOS := shell.io.mem.aw.bits.qos
m_axi_gmem.AWREGION := shell.io.mem.aw.bits.region
m_axi_gmem.WVALID := shell.io.mem.w.valid
shell.io.mem.w.ready := m_axi_gmem.WREADY
m_axi_gmem.WDATA := shell.io.mem.w.bits.data
m_axi_gmem.WSTRB := shell.io.mem.w.bits.strb
m_axi_gmem.WLAST := shell.io.mem.w.bits.last
m_axi_gmem.WID := shell.io.mem.w.bits.id
m_axi_gmem.WUSER := shell.io.mem.w.bits.user
shell.io.mem.b.valid := m_axi_gmem.BVALID
m_axi_gmem.BREADY := shell.io.mem.b.valid
shell.io.mem.b.bits.resp := m_axi_gmem.BRESP
shell.io.mem.b.bits.id := m_axi_gmem.BID
shell.io.mem.b.bits.user := m_axi_gmem.BUSER
m_axi_gmem.ARVALID := shell.io.mem.ar.valid
shell.io.mem.ar.ready := m_axi_gmem.ARREADY
m_axi_gmem.ARADDR := shell.io.mem.ar.bits.addr
m_axi_gmem.ARID := shell.io.mem.ar.bits.id
m_axi_gmem.ARUSER := shell.io.mem.ar.bits.user
m_axi_gmem.ARLEN := shell.io.mem.ar.bits.len
m_axi_gmem.ARSIZE := shell.io.mem.ar.bits.size
m_axi_gmem.ARBURST := shell.io.mem.ar.bits.burst
m_axi_gmem.ARLOCK := shell.io.mem.ar.bits.lock
m_axi_gmem.ARCACHE := shell.io.mem.ar.bits.cache
m_axi_gmem.ARPROT := shell.io.mem.ar.bits.prot
m_axi_gmem.ARQOS := shell.io.mem.ar.bits.qos
m_axi_gmem.ARREGION := shell.io.mem.ar.bits.region
shell.io.mem.r.valid := m_axi_gmem.RVALID
m_axi_gmem.RREADY := shell.io.mem.r.ready
shell.io.mem.r.bits.data := m_axi_gmem.RDATA
shell.io.mem.r.bits.resp := m_axi_gmem.RRESP
shell.io.mem.r.bits.last := m_axi_gmem.RLAST
shell.io.mem.r.bits.id := m_axi_gmem.RID
shell.io.mem.r.bits.user := m_axi_gmem.RUSER
// host
shell.io.host.aw.valid := s_axi_control.AWVALID
s_axi_control.AWREADY := shell.io.host.aw.ready
shell.io.host.aw.bits.addr := s_axi_control.AWADDR
shell.io.host.w.valid := s_axi_control.WVALID
s_axi_control.WREADY := shell.io.host.w.ready
shell.io.host.w.bits.data := s_axi_control.WDATA
shell.io.host.w.bits.strb := s_axi_control.WSTRB
s_axi_control.BVALID := shell.io.host.b.valid
shell.io.host.b.ready := s_axi_control.BREADY
s_axi_control.BRESP := shell.io.host.b.bits.resp
shell.io.host.ar.valid := s_axi_control.ARVALID
s_axi_control.ARREADY := shell.io.host.ar.ready
shell.io.host.ar.bits.addr := s_axi_control.ARADDR
s_axi_control.RVALID := shell.io.host.r.valid
shell.io.host.r.ready := s_axi_control.RREADY
s_axi_control.RDATA := shell.io.host.r.bits.data
s_axi_control.RRESP := shell.io.host.r.bits.resp
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.test
import chisel3._
import vta.util.config._
import vta.shell._
/** Test. This generates a testbench file for simulation */
class Test(implicit p: Parameters) extends Module {
val io = IO(new Bundle {})
val sim_shell = Module(new SimShell)
val vta_shell = Module(new VTAShell)
vta_shell.io.host <> sim_shell.io.host
sim_shell.io.mem <> vta_shell.io.mem
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.util.config
// taken from https://github.com/vta.roject/rocket-chip
abstract class Field[T] private (val default: Option[T])
{
def this() = this(None)
def this(default: T) = this(Some(default))
}
abstract class View {
final def apply[T](pname: Field[T]): T = apply(pname, this)
final def apply[T](pname: Field[T], site: View): T = {
val out = find(pname, site)
require (out.isDefined, s"Key ${pname} is not defined in Parameters")
out.get
}
final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T])
protected[config] def find[T](pname: Field[T], site: View): Option[T]
}
abstract class Parameters extends View {
final def ++ (x: Parameters): Parameters =
new ChainParameters(this, x)
final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
Parameters(f) ++ this
final def alterPartial(f: PartialFunction[Any,Any]): Parameters =
Parameters((_,_,_) => f) ++ this
final def alterMap(m: Map[Any,Any]): Parameters =
new MapParameters(m) ++ this
protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]
protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname)
}
object Parameters {
def empty: Parameters = new EmptyParameters
def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f)
}
class Config(p: Parameters) extends Parameters {
def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname)
override def toString = this.getClass.getSimpleName
def toInstance = this
}
// Internal implementation:
private class TerminalView extends View {
def find[T](pname: Field[T], site: View): Option[T] = pname.default
}
private class ChainView(head: Parameters, tail: View) extends View {
def find[T](pname: Field[T], site: View) = head.chain(site, tail, pname)
}
private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname)
}
private class EmptyParameters extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
}
private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
val g = f(site, this, tail)
if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site)
}
}
private class MapParameters(map: Map[Any, Any]) extends Parameters {
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
val g = map.get(pname)
if (g.isDefined) Some(g.get.asInstanceOf[T]) else tail.find(pname, site)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.util.genericbundle
// taken from https://github.com/vta.roject/rocket-chip
import chisel3._
abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle
{
override def cloneType = {
try {
this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type]
} catch {
case e: java.lang.IllegalArgumentException =>
throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " +
this.getClass + ", probably because " + this.getClass +
"() takes more than one argument. Consider overriding " +
"cloneType() on " + this.getClass, e)
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta
import chisel3._
import vta.util.config._
import vta.shell._
import vta.core._
import vta.test._
/** VTA.
*
* This file contains all the configurations supported by VTA.
* These configurations are built in a mix/match form based on core
* and shell configurations.
*/
class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
object DefaultPynqConfig extends App {
implicit val p: Parameters = new DefaultPynqConfig
chisel3.Driver.execute(args, () => new XilinxShell)
}
object DefaultF1Config extends App {
implicit val p: Parameters = new DefaultF1Config
chisel3.Driver.execute(args, () => new XilinxShell)
}
object TestDefaultF1Config extends App {
implicit val p: Parameters = new DefaultF1Config
chisel3.Driver.execute(args, () => new Test)
}
......@@ -70,8 +70,18 @@ void VTADPIInit(VTAContextHandle handle,
_mem_dpi = mem_dpi;
}
// Override Verilator finish definition
// VL_USER_FINISH needs to be defined when compiling Verilator code
void vl_finish(const char* filename, int linenum, const char* hier) {
Verilated::gotFinish(true);
VL_PRINTF("[TSIM] exiting simulation\n");
}
int VTADPISim(uint64_t max_cycles) {
uint64_t trace_count = 0;
Verilated::flushCall();
Verilated::gotFinish(false);
#if VM_TRACE
uint64_t start = 0;
......
......@@ -53,7 +53,11 @@ extern "C" {
typedef void * VTADeviceHandle;
/*! \brief physical address */
#ifdef USE_TSIM
typedef uint64_t vta_phy_addr_t;
#else
typedef uint32_t vta_phy_addr_t;
#endif
/*!
* \brief Allocate a device resource handle
......@@ -76,10 +80,22 @@ void VTADeviceFree(VTADeviceHandle handle);
*
* \return 0 if running is successful, 1 if timeout.
*/
#ifdef USE_TSIM
int VTADeviceRun(VTADeviceHandle device,
vta_phy_addr_t insn_phy_addr,
vta_phy_addr_t uop_phy_addr,
vta_phy_addr_t inp_phy_addr,
vta_phy_addr_t wgt_phy_addr,
vta_phy_addr_t acc_phy_addr,
vta_phy_addr_t out_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles);
#else
int VTADeviceRun(VTADeviceHandle device,
vta_phy_addr_t insn_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles);
#endif
/*!
* \brief Allocates physically contiguous region in memory (limited by MAX_XFER).
......
......@@ -239,7 +239,7 @@ class Environment(object):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
if self.TARGET == "sim":
if self.TARGET == "sim" or self.TARGET == "tsim":
return "llvm"
raise ValueError("Unknown target %s" % self.TARGET)
......
......@@ -17,6 +17,8 @@
"""Utilities to start simulator."""
import ctypes
import json
import sys
import os
import tvm
from ..libinfo import find_libvta
......@@ -55,5 +57,22 @@ def stats():
x = tvm.get_global_func("vta.simulator.profiler_status")()
return json.loads(x)
def tsim_init(hw_lib):
"""Init hardware shared library for TSIM
Parameters
------------
hw_lib : str
Name of hardware shared library
"""
cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
vta_build_path = os.path.join(cur_path, "..", "..", "..", "build")
if not hw_lib.endswith(("dylib", "so")):
hw_lib += ".dylib" if sys.platform == "darwin" else ".so"
lib = os.path.join(vta_build_path, hw_lib)
f = tvm.get_global_func("tvm.vta.tsim.init")
m = tvm.module.load(lib, "vta-tsim")
f(m)
LIBS = _load_lib()
......@@ -31,7 +31,7 @@ def run(run_func):
"""
env = get_env()
if env.TARGET == "sim":
if env.TARGET in ["sim", "tsim"]:
# Talk to local RPC if necessary to debug RPC server.
# Compile vta on your host with make at the root.
......@@ -48,7 +48,8 @@ def run(run_func):
# Make sure simulation library exists
# If this fails, build vta on host (make)
# with TARGET="sim" in the json.config file.
assert simulator.enabled()
if env.TARGET == "sim":
assert simulator.enabled()
run_func(env, rpc.LocalSession())
elif env.TARGET == "pynq":
......
......@@ -56,7 +56,7 @@ struct DataBuffer {
return data_;
}
/*! \return Physical address of the data. */
uint32_t phy_addr() const {
vta_phy_addr_t phy_addr() const {
return phy_addr_;
}
/*!
......@@ -113,7 +113,7 @@ struct DataBuffer {
/*! \brief The internal data. */
void* data_;
/*! \brief The physical address of the buffer, excluding header. */
uint32_t phy_addr_;
vta_phy_addr_t phy_addr_;
};
/*!
......@@ -302,7 +302,7 @@ class BaseQueue {
return dram_buffer_;
}
/*! \return Physical address of DRAM. */
uint32_t dram_phy_addr() const {
vta_phy_addr_t dram_phy_addr() const {
return dram_phy_addr_;
}
/*! \return Whether there is pending information. */
......@@ -367,7 +367,7 @@ class BaseQueue {
// The buffer in DRAM
char* dram_buffer_{nullptr};
// Physics address of the buffer
uint32_t dram_phy_addr_;
vta_phy_addr_t dram_phy_addr_;
};
/*!
......@@ -424,7 +424,11 @@ class UopQueue : public BaseQueue {
CHECK((dram_end_ - dram_begin_) == (sram_end_ - sram_begin_));
insn->memory_type = VTA_MEM_ID_UOP;
insn->sram_base = sram_begin_;
#ifdef USE_TSIM
insn->dram_base = (uint32_t) dram_phy_addr_ + dram_begin_*kElemBytes;
#else
insn->dram_base = dram_phy_addr_ / kElemBytes + dram_begin_;
#endif
insn->y_size = 1;
insn->x_size = (dram_end_ - dram_begin_);
insn->x_stride = (dram_end_ - dram_begin_);
......@@ -958,7 +962,11 @@ class CommandQueue {
insn->memory_type = dst_memory_type;
insn->sram_base = dst_sram_index;
DataBuffer* src = DataBuffer::FromHandle(src_dram_addr);
#ifdef USE_TSIM
insn->dram_base = (uint32_t) src->phy_addr() + src_elem_offset*GetElemBytes(dst_memory_type);
#else
insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset;
#endif
insn->y_size = y_size;
insn->x_size = x_size;
insn->x_stride = x_stride;
......@@ -981,7 +989,11 @@ class CommandQueue {
insn->memory_type = src_memory_type;
insn->sram_base = src_sram_index;
DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr);
#ifdef USE_TSIM
insn->dram_base = (uint32_t) dst->phy_addr() + dst_elem_offset*GetElemBytes(src_memory_type);
#else
insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset;
#endif
insn->y_size = y_size;
insn->x_size = x_size;
insn->x_stride = x_stride;
......@@ -1046,11 +1058,24 @@ class CommandQueue {
// Make sure that we don't exceed contiguous physical memory limits
CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER);
#ifdef USE_TSIM
int timeout = VTADeviceRun(
device_,
insn_queue_.dram_phy_addr(),
uop_queue_.dram_phy_addr(),
inp_phy_addr_,
wgt_phy_addr_,
acc_phy_addr_,
out_phy_addr_,
insn_queue_.count(),
wait_cycles);
#else
int timeout = VTADeviceRun(
device_,
insn_queue_.dram_phy_addr(),
insn_queue_.count(),
wait_cycles);
#endif
CHECK_EQ(timeout, 0);
// Reset buffers
uop_queue_.Reset();
......@@ -1125,6 +1150,18 @@ class CommandQueue {
ThreadLocal().reset();
}
#ifdef USE_TSIM
void SetBufPhyAddr(uint32_t type, vta_phy_addr_t addr) {
switch (type) {
case VTA_MEM_ID_INP: inp_phy_addr_ = addr;
case VTA_MEM_ID_WGT: wgt_phy_addr_ = addr;
case VTA_MEM_ID_ACC: acc_phy_addr_ = addr;
case VTA_MEM_ID_OUT: out_phy_addr_ = addr;
default: break;
}
}
#endif
private:
// Push GEMM uop to the command buffer
void PushGEMMOp(UopKernel* kernel) {
......@@ -1229,6 +1266,16 @@ class CommandQueue {
InsnQueue<VTA_MAX_XFER, true, true> insn_queue_;
// Device handle
VTADeviceHandle device_{nullptr};
#ifdef USE_TSIM
// Input phy addr
vta_phy_addr_t inp_phy_addr_{0};
// Weight phy addr
vta_phy_addr_t wgt_phy_addr_{0};
// Accumulator phy addr
vta_phy_addr_t acc_phy_addr_{0};
// Output phy addr
vta_phy_addr_t out_phy_addr_{0};
#endif
};
} // namespace vta
......@@ -1317,6 +1364,10 @@ void VTALoadBuffer2D(VTACommandHandle cmd,
uint32_t y_pad_after,
uint32_t dst_sram_index,
uint32_t dst_memory_type) {
#ifdef USE_TSIM
vta::DataBuffer* src = vta::DataBuffer::FromHandle(src_dram_addr);
static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(dst_memory_type, src->phy_addr());
#endif
static_cast<vta::CommandQueue*>(cmd)->
LoadBuffer2D(src_dram_addr, src_elem_offset,
x_size, y_size, x_stride,
......@@ -1333,6 +1384,10 @@ void VTAStoreBuffer2D(VTACommandHandle cmd,
uint32_t x_size,
uint32_t y_size,
uint32_t x_stride) {
#ifdef USE_TSIM
vta::DataBuffer* dst = vta::DataBuffer::FromHandle(dst_dram_addr);
static_cast<vta::CommandQueue*>(cmd)->SetBufPhyAddr(src_memory_type, dst->phy_addr());
#endif
static_cast<vta::CommandQueue*>(cmd)->
StoreBuffer2D(src_sram_index, src_memory_type,
dst_dram_addr, dst_elem_offset,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <vta/driver.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/dpi/module.h>
namespace vta {
namespace tsim {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class DPILoader {
public:
void Init(Module module) {
mod_ = module;
}
DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}
static DPILoader* Global() {
static DPILoader inst;
return &inst;
}
Module mod_;
};
class Device {
public:
Device() {
dpi_ = DPILoader::Global();
}
int Run(vta_phy_addr_t insn_phy_addr,
vta_phy_addr_t uop_phy_addr,
vta_phy_addr_t inp_phy_addr,
vta_phy_addr_t wgt_phy_addr,
vta_phy_addr_t acc_phy_addr,
vta_phy_addr_t out_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
this->Init();
this->Launch(insn_phy_addr,
uop_phy_addr,
inp_phy_addr,
wgt_phy_addr,
acc_phy_addr,
out_phy_addr,
insn_count,
wait_cycles);
this->WaitForCompletion(wait_cycles);
dev_->Finish();
return 0;
}
private:
void Init() {
dev_ = dpi_->Get();
}
void Launch(vta_phy_addr_t insn_phy_addr,
vta_phy_addr_t uop_phy_addr,
vta_phy_addr_t inp_phy_addr,
vta_phy_addr_t wgt_phy_addr,
vta_phy_addr_t acc_phy_addr,
vta_phy_addr_t out_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
// launch simulation thread
dev_->Launch(wait_cycles);
dev_->WriteReg(0x10, insn_count);
dev_->WriteReg(0x14, insn_phy_addr);
dev_->WriteReg(0x18, insn_phy_addr >> 32);
dev_->WriteReg(0x1c, 0);
dev_->WriteReg(0x20, uop_phy_addr >> 32);
dev_->WriteReg(0x24, 0);
dev_->WriteReg(0x28, inp_phy_addr >> 32);
dev_->WriteReg(0x2c, 0);
dev_->WriteReg(0x30, wgt_phy_addr >> 32);
dev_->WriteReg(0x34, 0);
dev_->WriteReg(0x38, acc_phy_addr >> 32);
dev_->WriteReg(0x3c, 0);
dev_->WriteReg(0x40, out_phy_addr >> 32);
// start
dev_->WriteReg(0x00, 0x1);
}
void WaitForCompletion(uint32_t wait_cycles) {
uint32_t i, val;
for (i = 0; i < wait_cycles; i++) {
val = dev_->ReadReg(0x00);
val &= 0x2;
if (val == 0x2) break; // finish
}
}
DPILoader* dpi_;
DPIModuleNode* dev_;
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
} // namespace tsim
} // namespace vta
void* VTAMemAlloc(size_t size, int cached) {
void *p = malloc(size);
return p;
}
void VTAMemFree(void* buf) {
free(buf);
}
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
return reinterpret_cast<uint64_t>(reinterpret_cast<uint64_t*>(buf));
}
void VTAFlushCache(vta_phy_addr_t buf, int size) {
}
void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
}
VTADeviceHandle VTADeviceAlloc() {
return new vta::tsim::Device();
}
void VTADeviceFree(VTADeviceHandle handle) {
delete static_cast<vta::tsim::Device*>(handle);
}
int VTADeviceRun(VTADeviceHandle handle,
vta_phy_addr_t insn_phy_addr,
vta_phy_addr_t uop_phy_addr,
vta_phy_addr_t inp_phy_addr,
vta_phy_addr_t wgt_phy_addr,
vta_phy_addr_t acc_phy_addr,
vta_phy_addr_t out_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
return static_cast<vta::tsim::Device*>(handle)->Run(
insn_phy_addr,
uop_phy_addr,
inp_phy_addr,
wgt_phy_addr,
acc_phy_addr,
out_phy_addr,
insn_count,
wait_cycles);
}
......@@ -68,6 +68,10 @@ def test_save_load_out():
y_np = x_np.astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
......@@ -126,6 +130,10 @@ def test_padded_load():
:] = x_np
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
......@@ -197,6 +205,9 @@ def test_gemm():
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET == "sim":
simulator.clear_stats()
f(x_nd, w_nd, y_nd)
......@@ -351,6 +362,10 @@ def test_alu():
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if use_imm:
f(a_nd, res_nd)
else:
......@@ -420,6 +435,10 @@ def test_relu():
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
......@@ -479,6 +498,10 @@ def test_shift_and_scale():
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
......@@ -503,11 +526,12 @@ if __name__ == "__main__":
print("Load/store test")
test_save_load_out()
print("Padded load test")
#test_padded_load()
test_padded_load()
print("GEMM test")
test_gemm()
test_alu()
print("ALU test")
test_alu()
print("Relu test")
test_relu()
print("Shift and scale")
test_shift_and_scale()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment