Commit 5fe61fd1 by Luis Vega Committed by Thierry Moreau

[VTA][Chisel] add scalafmt and format existing scala codebase (#3880)

* [VTA][Chisel] add scalafmt and format existing scala codebase

* change column width to 100

* add scalafmt conf file as a valid file type

* add asf header to scalafmt conf file and rerun formatter
parent f4a28c4b
...@@ -87,6 +87,7 @@ ALLOW_FILE_NAME = { ...@@ -87,6 +87,7 @@ ALLOW_FILE_NAME = {
".clang-format", ".clang-format",
".gitmodules", ".gitmodules",
"CODEOWNERS", "CODEOWNERS",
".scalafmt.conf",
} }
# List of specific files allowed in relpath to <proj_root> # List of specific files allowed in relpath to <proj_root>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
maxColumn = 100
rewrite.rules = [SortModifiers, SortImports]
...@@ -91,7 +91,10 @@ else ...@@ -91,7 +91,10 @@ else
lib_path = $(build_dir)/$(LIBNAME).so lib_path = $(build_dir)/$(LIBNAME).so
endif endif
default: lib default: lint lib
lint:
sbt scalafmt
lib: $(lib_path) lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp $(lib_path): $(verilator_build_dir)/V$(TOP).cpp
......
...@@ -18,3 +18,4 @@ ...@@ -18,3 +18,4 @@
*/ */
logLevel := Level.Warn logLevel := Level.Warn
addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
...@@ -41,7 +41,7 @@ case class AccelConfig() { ...@@ -41,7 +41,7 @@ case class AccelConfig() {
val nVals = 2 val nVals = 2
val nPtrs = 2 val nPtrs = 2
val regBits = 32 val regBits = 32
val ptrBits = 2*regBits val ptrBits = 2 * regBits
} }
class Accel extends Module { class Accel extends Module {
......
...@@ -54,27 +54,27 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -54,27 +54,27 @@ class Compute(implicit config: AccelConfig) extends Module {
val raddr = Reg(UInt(config.ptrBits.W)) val raddr = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W)) val waddr = Reg(UInt(config.ptrBits.W))
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.launch) { when(io.launch) {
state := sReadReq state := sReadReq
} }
} }
is (sReadReq) { is(sReadReq) {
state := sReadData state := sReadData
} }
is (sReadData) { is(sReadData) {
when (io.mem.rd.valid) { when(io.mem.rd.valid) {
state := sWriteReq state := sWriteReq
} }
} }
is (sWriteReq) { is(sWriteReq) {
state := sWriteData state := sWriteData
} }
is (sWriteData) { is(sWriteData) {
when (cnt === (length - 1.U)) { when(cnt === (length - 1.U)) {
state := sIdle state := sIdle
} .otherwise { }.otherwise {
state := sReadReq state := sReadReq
} }
} }
...@@ -83,9 +83,9 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -83,9 +83,9 @@ class Compute(implicit config: AccelConfig) extends Module {
val last = state === sWriteData && cnt === (length - 1.U) val last = state === sWriteData && cnt === (length - 1.U)
// cycle counter // cycle counter
when (state === sIdle) { when(state === sIdle) {
cycles := 0.U cycles := 0.U
} .otherwise { }.otherwise {
cycles := cycles + 1.U cycles := cycles + 1.U
} }
...@@ -93,10 +93,10 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -93,10 +93,10 @@ class Compute(implicit config: AccelConfig) extends Module {
io.ecnt(0).bits := cycles io.ecnt(0).bits := cycles
// calculate next address // calculate next address
when (state === sIdle) { when(state === sIdle) {
raddr := io.ptrs(0) raddr := io.ptrs(0)
waddr := io.ptrs(1) waddr := io.ptrs(1)
} .elsewhen (state === sWriteData) { // increment by 8-bytes }.elsewhen(state === sWriteData) { // increment by 8-bytes
raddr := raddr + 8.U raddr := raddr + 8.U
waddr := waddr + 8.U waddr := waddr + 8.U
} }
...@@ -108,7 +108,7 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -108,7 +108,7 @@ class Compute(implicit config: AccelConfig) extends Module {
io.mem.req.addr := Mux(state === sReadReq, raddr, waddr) io.mem.req.addr := Mux(state === sReadReq, raddr, waddr)
// read // read
when (state === sReadData && io.mem.rd.valid) { when(state === sReadData && io.mem.rd.valid) {
reg := io.mem.rd.bits + const reg := io.mem.rd.bits + const
} }
io.mem.rd.ready := state === sReadData io.mem.rd.ready := state === sReadData
...@@ -118,9 +118,9 @@ class Compute(implicit config: AccelConfig) extends Module { ...@@ -118,9 +118,9 @@ class Compute(implicit config: AccelConfig) extends Module {
io.mem.wr.bits := reg io.mem.wr.bits := reg
// count read/write // count read/write
when (state === sIdle) { when(state === sIdle) {
cnt := 0.U cnt := 0.U
} .elsewhen (state === sWriteData) { }.elsewhen(state === sWriteData) {
cnt := cnt + 1.U cnt := cnt + 1.U
} }
......
...@@ -59,52 +59,54 @@ class RegFile(implicit config: AccelConfig) extends Module { ...@@ -59,52 +59,54 @@ class RegFile(implicit config: AccelConfig) extends Module {
val sIdle :: sRead :: Nil = Enum(2) val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle) val state = RegInit(sIdle)
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.host.req.valid && !io.host.req.opcode) { when(io.host.req.valid && !io.host.req.opcode) {
state := sRead state := sRead
} }
} }
is (sRead) { is(sRead) {
state := sIdle state := sIdle
} }
} }
io.host.req.deq := state === sIdle & io.host.req.valid io.host.req.deq := state === sIdle & io.host.req.valid
val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs) val nTotal = config.nCtrl + config.nECnt + config.nVals + (2 * config.nPtrs)
val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))) val reg =
Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4) val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r } val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl val eo = config.nCtrl
val vo = eo + config.nECnt val vo = eo + config.nECnt
val po = vo + config.nVals val po = vo + config.nVals
when (io.finish) { when(io.finish) {
reg(0) := "b_10".U reg(0) := "b_10".U
} .elsewhen (state === sIdle && io.host.req.valid && }.elsewhen(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(0).U === io.host.req.addr) { io.host.req.opcode && addr(0).U === io.host.req.addr) {
reg(0) := io.host.req.value reg(0) := io.host.req.value
} }
for (i <- 0 until config.nECnt) { for (i <- 0 until config.nECnt) {
when (io.ecnt(i).valid) { when(io.ecnt(i).valid) {
reg(eo + i) := io.ecnt(i).bits reg(eo + i) := io.ecnt(i).bits
} .elsewhen (state === sIdle && io.host.req.valid && }.elsewhen(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(eo + i).U === io.host.req.addr) { io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
reg(eo + i) := io.host.req.value reg(eo + i) := io.host.req.value
} }
} }
for (i <- 0 until (config.nVals + (2*config.nPtrs))) { for (i <- 0 until (config.nVals + (2 * config.nPtrs))) {
when (state === sIdle && io.host.req.valid && when(
state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) { io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value reg(vo + i) := io.host.req.value
} }
} }
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))) val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
when (state === sIdle && io.host.req.valid && !io.host.req.opcode) { when(state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map) rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
} }
...@@ -118,6 +120,6 @@ class RegFile(implicit config: AccelConfig) extends Module { ...@@ -118,6 +120,6 @@ class RegFile(implicit config: AccelConfig) extends Module {
} }
for (i <- 0 until config.nPtrs) { for (i <- 0 until config.nPtrs) {
io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) io.ptrs(i) := Cat(reg(po + (2 * i) + 1), reg(po + (2 * i)))
} }
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
maxColumn = 100
rewrite.rules = [SortModifiers, SortImports]
...@@ -102,7 +102,10 @@ else ...@@ -102,7 +102,10 @@ else
lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
endif endif
default: lib default: lint lib
lint:
sbt scalafmt
lib: $(lib_path) lib: $(lib_path)
......
...@@ -18,3 +18,4 @@ ...@@ -18,3 +18,4 @@
*/ */
logLevel := Level.Warn logLevel := Level.Warn
addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
...@@ -49,7 +49,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -49,7 +49,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val sIdle :: sSync :: sExe :: Nil = Enum(3) val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle) val state = RegInit(sIdle)
val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0))) val s = Seq.tabulate(2)(_ =>
Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
val loadUop = Module(new LoadUop) val loadUop = Module(new LoadUop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc")) val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
...@@ -62,7 +63,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -62,7 +63,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val dec = Module(new ComputeDecode) val dec = Module(new ComputeDecode)
dec.io.inst := inst_q.io.deq.bits dec.io.inst := inst_q.io.deq.bits
val inst_type = Cat(dec.io.isFinish, val inst_type =
Cat(dec.io.isFinish,
dec.io.isAlu, dec.io.isAlu,
dec.io.isGemm, dec.io.isGemm,
dec.io.isLoadAcc, dec.io.isLoadAcc,
...@@ -72,7 +74,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -72,7 +74,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B) val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
val start = snext & sprev val start = snext & sprev
val done = val done =
MuxLookup(inst_type, MuxLookup(
inst_type,
false.B, // default false.B, // default
Array( Array(
"h_01".U -> loadUop.io.done, "h_01".U -> loadUop.io.done,
...@@ -84,21 +87,21 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -84,21 +87,21 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
) )
// control // control
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (start) { when(start) {
when (dec.io.isSync) { when(dec.io.isSync) {
state := sSync state := sSync
} .elsewhen (inst_type.orR) { }.elsewhen(inst_type.orR) {
state := sExe state := sExe
} }
} }
} }
is (sSync) { is(sSync) {
state := sIdle state := sIdle
} }
is (sExe) { is(sExe) {
when (done) { when(done) {
state := sIdle state := sIdle
} }
} }
...@@ -113,14 +116,20 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -113,14 +116,20 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
loadUop.io.inst := inst_q.io.deq.bits loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx) loadUop.io.uop.idx <> Mux(dec.io.isGemm,
tensorGemm.io.uop.idx,
tensorAlu.io.uop.idx)
// acc // acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr tensorAcc.io.baddr := io.acc_baddr
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx) tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm,
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr) tensorGemm.io.acc.rd.idx,
tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm,
tensorGemm.io.acc.wr,
tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd io.vme_rd(1) <> tensorAcc.io.vme_rd
// gemm // gemm
...@@ -146,7 +155,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -146,7 +155,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
// out // out
io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx) io.out.rd.idx <> Mux(dec.io.isGemm,
tensorGemm.io.out.rd.idx,
tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr) io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
// semaphore // semaphore
...@@ -163,36 +174,43 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -163,36 +174,43 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug // debug
if (debug) { if (debug) {
// start // start
when (state === sIdle && start) { when(state === sIdle && start) {
when (dec.io.isSync) { when(dec.io.isSync) {
printf("[Compute] start sync\n") printf("[Compute] start sync\n")
} .elsewhen (dec.io.isLoadUop) { }.elsewhen(dec.io.isLoadUop) {
printf("[Compute] start load uop\n") printf("[Compute] start load uop\n")
} .elsewhen (dec.io.isLoadAcc) { }
.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] start load acc\n") printf("[Compute] start load acc\n")
} .elsewhen (dec.io.isGemm) { }
.elsewhen(dec.io.isGemm) {
printf("[Compute] start gemm\n") printf("[Compute] start gemm\n")
} .elsewhen (dec.io.isAlu) { }
.elsewhen(dec.io.isAlu) {
printf("[Compute] start alu\n") printf("[Compute] start alu\n")
} .elsewhen (dec.io.isFinish) { }
.elsewhen(dec.io.isFinish) {
printf("[Compute] start finish\n") printf("[Compute] start finish\n")
} }
} }
// done // done
when (state === sSync) { when(state === sSync) {
printf("[Compute] done sync\n") printf("[Compute] done sync\n")
} }
when (state === sExe) { when(state === sExe) {
when (done) { when(done) {
when (dec.io.isLoadUop) { when(dec.io.isLoadUop) {
printf("[Compute] done load uop\n") printf("[Compute] done load uop\n")
} .elsewhen (dec.io.isLoadAcc) { }.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] done load acc\n") printf("[Compute] done load acc\n")
} .elsewhen (dec.io.isGemm) { }
.elsewhen(dec.io.isGemm) {
printf("[Compute] done gemm\n") printf("[Compute] done gemm\n")
} .elsewhen (dec.io.isAlu) { }
.elsewhen(dec.io.isAlu) {
printf("[Compute] done alu\n") printf("[Compute] done alu\n")
} .elsewhen (dec.io.isFinish) { }
.elsewhen(dec.io.isFinish) {
printf("[Compute] done finish\n") printf("[Compute] done finish\n")
} }
} }
......
...@@ -27,8 +27,10 @@ import vta.util.config._ ...@@ -27,8 +27,10 @@ import vta.util.config._
* be eventually filled out with class configurations that can be * be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends. * mixed/matched with Shell configurations for different backends.
*/ */
class CoreConfig extends Config((site, here, up) => { class CoreConfig
case CoreKey => CoreParams( extends Config((site, here, up) => {
case CoreKey =>
CoreParams(
batch = 1, batch = 1,
blockOut = 16, blockOut = 16,
blockIn = 16, blockIn = 16,
...@@ -42,5 +44,6 @@ class CoreConfig extends Config((site, here, up) => { ...@@ -42,5 +44,6 @@ class CoreConfig extends Config((site, here, up) => {
wgtMemDepth = 1024, wgtMemDepth = 1024,
accMemDepth = 2048, accMemDepth = 2048,
outMemDepth = 2048, outMemDepth = 2048,
instQueueEntries = 512) instQueueEntries = 512
}) )
})
...@@ -24,7 +24,7 @@ import vta.util.config._ ...@@ -24,7 +24,7 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** Core parameters */ /** Core parameters */
case class CoreParams ( case class CoreParams(
batch: Int = 1, batch: Int = 1,
blockOut: Int = 16, blockOut: Int = 16,
blockIn: Int = 16, blockIn: Int = 16,
...@@ -39,9 +39,9 @@ case class CoreParams ( ...@@ -39,9 +39,9 @@ case class CoreParams (
accMemDepth: Int = 512, accMemDepth: Int = 512,
outMemDepth: Int = 512, outMemDepth: Int = 512,
instQueueEntries: Int = 32 instQueueEntries: Int = 32
) ) {
{ require(uopBits % 8 == 0,
require (uopBits % 8 == 0, s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n") s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
} }
case object CoreKey extends Field[CoreParams] case object CoreKey extends Field[CoreParams]
......
...@@ -133,7 +133,8 @@ class FetchDecode extends Module { ...@@ -133,7 +133,8 @@ class FetchDecode extends Module {
val isStore = Output(Bool()) val isStore = Output(Bool())
}) })
val csignals = val csignals =
ListLookup(io.inst, ListLookup(
io.inst,
List(N, OP_X), List(N, OP_X),
Array( Array(
LUOP -> List(Y, OP_G), LUOP -> List(Y, OP_G),
......
...@@ -38,17 +38,18 @@ import vta.shell._ ...@@ -38,17 +38,18 @@ import vta.shell._
* If one would like to add an event counter, then the value of nECnt must be * If one would like to add an event counter, then the value of nECnt must be
* changed in VCRParams together with the corresponding counting logic here. * changed in VCRParams together with the corresponding counting logic here.
*/ */
class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module { class EventCounters(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val vp = p(ShellKey).vcrParams val vp = p(ShellKey).vcrParams
val io = IO(new Bundle{ val io = IO(new Bundle {
val launch = Input(Bool()) val launch = Input(Bool())
val finish = Input(Bool()) val finish = Input(Bool())
val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W))) val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
}) })
val cycle_cnt = RegInit(0.U(vp.regBits.W)) val cycle_cnt = RegInit(0.U(vp.regBits.W))
when (io.launch && !io.finish) { when(io.launch && !io.finish) {
cycle_cnt := cycle_cnt + 1.U cycle_cnt := cycle_cnt + 1.U
} .otherwise { }.otherwise {
cycle_cnt := 0.U cycle_cnt := 0.U
} }
io.ecnt(0).valid := io.finish io.ecnt(0).valid := io.finish
......
...@@ -67,56 +67,57 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -67,56 +67,57 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val xrem = Reg(chiselTypeOf(io.ins_count)) val xrem = Reg(chiselTypeOf(io.ins_count))
val xsize = (io.ins_count << 1.U) - 1.U val xsize = (io.ins_count << 1.U) - 1.U
val xmax = (1 << mp.lenBits).U val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5) val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
val state = RegInit(sIdle) val state = RegInit(sIdle)
// control // control
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (pulse) { when(pulse) {
state := sReadCmd state := sReadCmd
when (xsize < xmax) { when(xsize < xmax) {
rlen := xsize rlen := xsize
ilen := xsize >> 1.U ilen := xsize >> 1.U
xrem := 0.U xrem := 0.U
} .otherwise { }.otherwise {
rlen := xmax - 1.U rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U ilen := (xmax >> 1.U) - 1.U
xrem := xsize - xmax xrem := xsize - xmax
} }
} }
} }
is (sReadCmd) { is(sReadCmd) {
when (io.vme_rd.cmd.ready) { when(io.vme_rd.cmd.ready) {
state := sReadLSB state := sReadLSB
} }
} }
is (sReadLSB) { is(sReadLSB) {
when (io.vme_rd.data.valid) { when(io.vme_rd.data.valid) {
state := sReadMSB state := sReadMSB
} }
} }
is (sReadMSB) { is(sReadMSB) {
when (io.vme_rd.data.valid) { when(io.vme_rd.data.valid) {
when (inst_q.io.count === ilen) { when(inst_q.io.count === ilen) {
state := sDrain state := sDrain
} .otherwise { }.otherwise {
state := sReadLSB state := sReadLSB
} }
} }
} }
is (sDrain) { is(sDrain) {
when (inst_q.io.count === 0.U) { when(inst_q.io.count === 0.U) {
when (xrem === 0.U) { when(xrem === 0.U) {
state := sIdle state := sIdle
} .elsewhen (xrem < xmax) { }.elsewhen(xrem < xmax) {
state := sReadCmd state := sReadCmd
rlen := xrem rlen := xrem
ilen := xrem >> 1.U ilen := xrem >> 1.U
xrem := 0.U xrem := 0.U
} .otherwise { }
.otherwise {
state := sReadCmd state := sReadCmd
rlen := xmax - 1.U rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U ilen := (xmax >> 1.U) - 1.U
...@@ -127,9 +128,9 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -127,9 +128,9 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
} }
// read instructions from dram // read instructions from dram
when (state === sIdle) { when(state === sIdle) {
raddr := io.ins_baddr raddr := io.ins_baddr
} .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) { }.elsewhen(state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
raddr := raddr + xmax_bytes raddr := raddr + xmax_bytes
} }
...@@ -143,7 +144,7 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -143,7 +144,7 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val msb = io.vme_rd.data.bits val msb = io.vme_rd.data.bits
val inst = Cat(msb, lsb) val inst = Cat(msb, lsb)
when (state === sReadLSB) { lsb := io.vme_rd.data.bits } when(state === sReadLSB) { lsb := io.vme_rd.data.bits }
inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
inst_q.io.enq.bits := inst inst_q.io.enq.bits := inst
...@@ -169,27 +170,25 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -169,27 +170,25 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
"h_01".U -> io.inst.ld.ready, "h_01".U -> io.inst.ld.ready,
"h_02".U -> io.inst.st.ready, "h_02".U -> io.inst.st.ready,
"h_04".U -> io.inst.co.ready "h_04".U -> io.inst.co.ready
) ))
)
// dequeue instruction // dequeue instruction
inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
// debug // debug
if (debug) { if (debug) {
when (state === sIdle && pulse) { when(state === sIdle && pulse) {
printf("[Fetch] Launch\n") printf("[Fetch] Launch\n")
} }
// instruction // instruction
when (inst_q.io.deq.fire()) { when(inst_q.io.deq.fire()) {
when (dec.io.isLoad) { when(dec.io.isLoad) {
printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits) printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
} }
when (dec.io.isCompute) { when(dec.io.isCompute) {
printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits) printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
} }
when (dec.io.isStore) { when(dec.io.isStore) {
printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits) printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
} }
} }
......
...@@ -26,8 +26,7 @@ import chisel3.util._ ...@@ -26,8 +26,7 @@ import chisel3.util._
* *
* These constants are used for decoding (parsing) fields on instructions. * These constants are used for decoding (parsing) fields on instructions.
*/ */
trait ISAConstants trait ISAConstants {
{
val INST_BITS = 128 val INST_BITS = 128
val OP_BITS = 3 val OP_BITS = 3
...@@ -79,15 +78,37 @@ trait ISAConstants ...@@ -79,15 +78,37 @@ trait ISAConstants
* TODO: Add VXOR to clear accumulator * TODO: Add VXOR to clear accumulator
*/ */
object ISA { object ISA {
def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000") def LUOP =
def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000") BitPat(
def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000") "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000") def LWGT =
def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001") BitPat(
def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010") "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") def LINP =
def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") BitPat(
def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") "b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100") def LACC =
def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011") BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
def SOUT =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
def GEMM =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
def VMIN =
BitPat(
"b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VMAX =
BitPat(
"b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VADD =
BitPat(
"b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VSHX =
BitPat(
"b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def FNSH =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
} }
...@@ -54,27 +54,28 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -54,27 +54,28 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
val tensorType = Seq("inp", "wgt") val tensorType = Seq("inp", "wgt")
val tensorDec = Seq(dec.io.isInput, dec.io.isWeight) val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i)))) val tensorLoad =
Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B) val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done) val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
// control // control
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (start) { when(start) {
when (dec.io.isSync) { when(dec.io.isSync) {
state := sSync state := sSync
} .elsewhen (dec.io.isInput || dec.io.isWeight) { }.elsewhen(dec.io.isInput || dec.io.isWeight) {
state := sExe state := sExe
} }
} }
} }
is (sSync) { is(sSync) {
state := sIdle state := sIdle
} }
is (sExe) { is(sExe) {
when (done) { when(done) {
state := sIdle state := sIdle
} }
} }
...@@ -105,24 +106,25 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -105,24 +106,25 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug // debug
if (debug) { if (debug) {
// start // start
when (state === sIdle && start) { when(state === sIdle && start) {
when (dec.io.isSync) { when(dec.io.isSync) {
printf("[Load] start sync\n") printf("[Load] start sync\n")
} .elsewhen (dec.io.isInput) { }.elsewhen(dec.io.isInput) {
printf("[Load] start input\n") printf("[Load] start input\n")
} .elsewhen (dec.io.isWeight) { }
.elsewhen(dec.io.isWeight) {
printf("[Load] start weight\n") printf("[Load] start weight\n")
} }
} }
// done // done
when (state === sSync) { when(state === sSync) {
printf("[Load] done sync\n") printf("[Load] done sync\n")
} }
when (state === sExe) { when(state === sExe) {
when (done) { when(done) {
when (dec.io.isInput) { when(dec.io.isInput) {
printf("[Load] done input\n") printf("[Load] done input\n")
} .elsewhen (dec.io.isWeight) { }.elsewhen(dec.io.isWeight) {
printf("[Load] done weight\n") printf("[Load] done weight\n")
} }
} }
......
...@@ -79,7 +79,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -79,7 +79,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val xrem = Reg(chiselTypeOf(dec.xsize)) val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
val xmax = (1 << mp.lenBits).U val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val offsetIsEven = (dec.sram_offset % 2.U) === 0.U val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
val sizeIsEven = (dec.xsize % 2.U) === 0.U val sizeIsEven = (dec.xsize % 2.U) === 0.U
...@@ -88,34 +88,35 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -88,34 +88,35 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val state = RegInit(sIdle) val state = RegInit(sIdle)
// control // control
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.start) { when(io.start) {
state := sReadCmd state := sReadCmd
when (xsize < xmax) { when(xsize < xmax) {
xlen := xsize xlen := xsize
xrem := 0.U xrem := 0.U
} .otherwise { }.otherwise {
xlen := xmax - 1.U xlen := xmax - 1.U
xrem := xsize - xmax xrem := xsize - xmax
} }
} }
} }
is (sReadCmd) { is(sReadCmd) {
when (io.vme_rd.cmd.ready) { when(io.vme_rd.cmd.ready) {
state := sReadData state := sReadData
} }
} }
is (sReadData) { is(sReadData) {
when (io.vme_rd.data.valid) { when(io.vme_rd.data.valid) {
when(xcnt === xlen) { when(xcnt === xlen) {
when (xrem === 0.U) { when(xrem === 0.U) {
state := sIdle state := sIdle
} .elsewhen (xrem < xmax) { }.elsewhen(xrem < xmax) {
state := sReadCmd state := sReadCmd
xlen := xrem xlen := xrem
xrem := 0.U xrem := 0.U
} .otherwise { }
.otherwise {
state := sReadCmd state := sReadCmd
xlen := xmax - 1.U xlen := xmax - 1.U
xrem := xrem - xmax xrem := xrem - xmax
...@@ -127,13 +128,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -127,13 +128,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
// read-from-dram // read-from-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
when (state === sIdle) { when(state === sIdle) {
when (offsetIsEven) { when(offsetIsEven) {
raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes))) raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))
} .otherwise { }.otherwise {
raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
uopBytes)))) - uopBytes.U
} }
} .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) { }.elsewhen(state === sReadData && xcnt === xlen && xrem =/= 0.U) {
raddr := raddr + xmax_bytes raddr := raddr + xmax_bytes
} }
...@@ -143,16 +145,16 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -143,16 +145,16 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
io.vme_rd.data.ready := state === sReadData io.vme_rd.data.ready := state === sReadData
when (state =/= sReadData) { when(state =/= sReadData) {
xcnt := 0.U xcnt := 0.U
} .elsewhen (io.vme_rd.data.fire()) { }.elsewhen(io.vme_rd.data.fire()) {
xcnt := xcnt + 1.U xcnt := xcnt + 1.U
} }
val waddr = Reg(UInt(log2Ceil(uopDepth).W)) val waddr = Reg(UInt(log2Ceil(uopDepth).W))
when (state === sIdle) { when(state === sIdle) {
waddr := dec.sram_offset >> log2Ceil(numUop) waddr := dec.sram_offset >> log2Ceil(numUop)
} .elsewhen (io.vme_rd.data.fire()) { }.elsewhen(io.vme_rd.data.fire()) {
waddr := waddr + 1.U waddr := waddr + 1.U
} }
...@@ -160,36 +162,37 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -160,36 +162,37 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata)) val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
val wmask = Reg(Vec(numUop, Bool())) val wmask = Reg(Vec(numUop, Bool()))
when (offsetIsEven) { when(offsetIsEven) {
when (sizeIsEven) { when(sizeIsEven) {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
} .elsewhen (io.vme_rd.cmd.fire()) { }.elsewhen(io.vme_rd.cmd.fire()) {
when (dec.xsize === 1.U) { when(dec.xsize === 1.U) {
wmask := "b_01".U.asTypeOf(wmask) wmask := "b_01".U.asTypeOf(wmask)
} .otherwise { }.otherwise {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
} }
} .elsewhen (io.vme_rd.data.fire()) { }
when (xcnt === xlen - 1.U) { .elsewhen(io.vme_rd.data.fire()) {
when(xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask) wmask := "b_01".U.asTypeOf(wmask)
} .otherwise { }.otherwise {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
} }
} }
} .otherwise { }.otherwise {
when (io.vme_rd.cmd.fire()) { when(io.vme_rd.cmd.fire()) {
wmask := "b_10".U.asTypeOf(wmask) wmask := "b_10".U.asTypeOf(wmask)
} .elsewhen (io.vme_rd.data.fire()) { }.elsewhen(io.vme_rd.data.fire()) {
when (sizeIsEven && xcnt === xlen - 1.U) { when(sizeIsEven && xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask) wmask := "b_01".U.asTypeOf(wmask)
} .otherwise { }.otherwise {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
} }
} }
} }
wdata := io.vme_rd.data.bits.asTypeOf(wdata) wdata := io.vme_rd.data.bits.asTypeOf(wdata)
when (io.vme_rd.data.fire()) { when(io.vme_rd.data.fire()) {
mem.write(waddr, wdata, wmask) mem.write(waddr, wdata, wmask)
} }
...@@ -209,7 +212,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -209,7 +212,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug // debug
if (debug) { if (debug) {
when (io.vme_rd.cmd.fire()) { when(io.vme_rd.cmd.fire()) {
printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem) printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
} }
} }
......
...@@ -29,14 +29,17 @@ import chisel3.util._ ...@@ -29,14 +29,17 @@ import chisel3.util._
* depending on the push and pop fields on instructions to prevent RAW and WAR * depending on the push and pop fields on instructions to prevent RAW and WAR
* hazards. * hazards.
*/ */
class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module { class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1)
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val spost = Input(Bool()) val spost = Input(Bool())
val swait = Input(Bool()) val swait = Input(Bool())
val sready = Output(Bool()) val sready = Output(Bool())
}) })
val cnt = RegInit(counterInitValue.U(counterBits.W)) val cnt = RegInit(counterInitValue.U(counterBits.W))
when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U } when(io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) {
when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U } cnt := cnt + 1.U
}
when(!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
io.sready := cnt =/= 0.U io.sready := cnt =/= 0.U
} }
...@@ -55,21 +55,21 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -55,21 +55,21 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
val done = tensorStore.io.done val done = tensorStore.io.done
// control // control
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (start) { when(start) {
when (dec.io.isSync) { when(dec.io.isSync) {
state := sSync state := sSync
} .elsewhen (dec.io.isStore) { }.elsewhen(dec.io.isStore) {
state := sExe state := sExe
} }
} }
} }
is (sSync) { is(sSync) {
state := sIdle state := sIdle
} }
is (sExe) { is(sExe) {
when (done) { when(done) {
state := sIdle state := sIdle
} }
} }
...@@ -94,19 +94,19 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -94,19 +94,19 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug // debug
if (debug) { if (debug) {
// start // start
when (state === sIdle && start) { when(state === sIdle && start) {
when (dec.io.isSync) { when(dec.io.isSync) {
printf("[Store] start sync\n") printf("[Store] start sync\n")
} .elsewhen (dec.io.isStore) { }.elsewhen(dec.io.isStore) {
printf("[Store] start\n") printf("[Store] start\n")
} }
} }
// done // done
when (state === sSync) { when(state === sSync) {
printf("[Store] done sync\n") printf("[Store] done sync\n")
} }
when (state === sExe) { when(state === sExe) {
when (done) { when(done) {
printf("[Store] done\n") printf("[Store] done\n")
} }
} }
......
...@@ -116,7 +116,8 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -116,7 +116,8 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
val acc = new TensorMaster(tensorType = "acc") val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out") val out = new TensorMaster(tensorType = "out")
}) })
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6) val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
Enum(6)
val state = RegInit(sIdle) val state = RegInit(sIdle)
val alu = Module(new AluVector) val alu = Module(new AluVector)
val dec = io.inst.asTypeOf(new AluDecode) val dec = io.inst.asTypeOf(new AluDecode)
...@@ -137,51 +138,54 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -137,51 +138,54 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
(cnt_i === dec.lp_1 - 1.U) & (cnt_i === dec.lp_1 - 1.U) &
(uop_idx === uop_end - 1.U) (uop_idx === uop_end - 1.U)
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.start) { when(io.start) {
state := sReadUop state := sReadUop
} }
} }
is (sReadUop) { is(sReadUop) {
state := sComputeIdx state := sComputeIdx
} }
is (sComputeIdx) { is(sComputeIdx) {
state := sReadTensorA state := sReadTensorA
} }
is (sReadTensorA) { is(sReadTensorA) {
state := sReadTensorB state := sReadTensorB
} }
is (sReadTensorB) { is(sReadTensorB) {
state := sExe state := sExe
} }
is (sExe) { is(sExe) {
when (alu.io.out.data.valid) { when(alu.io.out.data.valid) {
when ((cnt_o === dec.lp_0 - 1.U) && when(
(cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) && (cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) { (uop_idx === uop_end - 1.U)) {
state := sIdle state := sIdle
} .otherwise { }.otherwise {
state := sReadUop state := sReadUop
} }
} }
} }
} }
when (state === sIdle || when(
state === sIdle ||
(state === sExe && (state === sExe &&
alu.io.out.data.valid && alu.io.out.data.valid &&
uop_idx === uop_end - 1.U)) { uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin uop_idx := dec.uop_begin
} .elsewhen (state === sExe && alu.io.out.data.valid) { }.elsewhen(state === sExe && alu.io.out.data.valid) {
uop_idx := uop_idx + 1.U uop_idx := uop_idx + 1.U
} }
when (state === sIdle) { when(state === sIdle) {
cnt_o := 0.U cnt_o := 0.U
dst_o := 0.U dst_o := 0.U
src_o := 0.U src_o := 0.U
} .elsewhen (state === sExe && }.elsewhen(
state === sExe &&
alu.io.out.data.valid && alu.io.out.data.valid &&
uop_idx === uop_end - 1.U && uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) { cnt_i === dec.lp_1 - 1.U) {
...@@ -190,15 +194,17 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -190,15 +194,17 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
src_o := src_o + dec.src_0 src_o := src_o + dec.src_0
} }
when (state === sIdle) { when(state === sIdle) {
cnt_i := 0.U cnt_i := 0.U
dst_i := 0.U dst_i := 0.U
src_i := 0.U src_i := 0.U
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U cnt_i := 0.U
dst_i := dst_o dst_i := dst_o
src_i := src_o src_i := src_o
} .elsewhen (state === sExe && }
.elsewhen(
state === sExe &&
alu.io.out.data.valid && alu.io.out.data.valid &&
uop_idx === uop_end - 1.U) { uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U cnt_i := cnt_i + 1.U
...@@ -206,7 +212,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -206,7 +212,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
src_i := src_i + dec.src_1 src_i := src_i + dec.src_1
} }
when (state === sComputeIdx && io.uop.data.valid) { when(state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i uop_dst := io.uop.data.bits.u0 + dst_i
uop_src := io.uop.data.bits.u1 + src_i uop_src := io.uop.data.bits.u1 + src_i
} }
...@@ -222,17 +228,25 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -222,17 +228,25 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
// imm // imm
val tensorImm = Wire(new TensorClientData(tensorType = "acc")) val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
tensorImm.data.valid := state === sReadTensorB tensorImm.data.valid := state === sReadTensorB
tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } } tensorImm.data.bits.foreach { b =>
b.foreach { c =>
c := dec.alu_imm
}
}
// alu // alu
val isSHR = dec.alu_op === ALU_OP(3) val isSHR = dec.alu_op === ALU_OP(3)
val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1) val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op)) val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
alu.io.opcode := fixme_alu_op alu.io.opcode := fixme_alu_op
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits alu.io.acc_a.data.bits <> io.acc.rd.data.bits
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe) alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits) tensorImm.data.valid,
io.acc.rd.data.valid & state === sExe)
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
tensorImm.data.bits,
io.acc.rd.data.bits)
// acc_o // acc_o
io.acc.wr.valid := alu.io.acc_y.data.valid io.acc.wr.valid := alu.io.acc_y.data.valid
...@@ -249,45 +263,49 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -249,45 +263,49 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
if (debug) { if (debug) {
when (state === sReadUop) { when(state === sReadUop) {
printf("[TensorAlu] [uop] idx:%x\n", uop_idx) printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
} }
when (state === sReadTensorA) { when(state === sReadTensorA) {
printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src) printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
} }
when (state === sIdle && io.start) { when(state === sIdle && io.start) {
printf(p"[TensorAlu] decode:$dec\n") printf(p"[TensorAlu] decode:$dec\n")
} }
alu.io.acc_a.data.bits.foreach { tensor => alu.io.acc_a.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (alu.io.acc_a.data.valid) { case (elem, i) =>
when(alu.io.acc_a.data.valid) {
printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem) printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
} }
} }
} }
alu.io.acc_b.data.bits.foreach { tensor => alu.io.acc_b.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (alu.io.acc_b.data.valid) { case (elem, i) =>
when(alu.io.acc_b.data.valid) {
printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem) printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
} }
} }
} }
alu.io.acc_y.data.bits.foreach { tensor => alu.io.acc_y.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (alu.io.acc_y.data.valid) { case (elem, i) =>
when(alu.io.acc_y.data.valid) {
printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem) printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
} }
} }
} }
alu.io.out.data.bits.foreach { tensor => alu.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (alu.io.out.data.valid) { case (elem, i) =>
when(alu.io.out.data.valid) {
printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem) printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
} }
} }
......
...@@ -62,8 +62,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { ...@@ -62,8 +62,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
} }
/** Pipelined DotProduct based on MAC and PipeAdder */ /** Pipelined DotProduct based on MAC and PipeAdder */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module { class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" extends Module {
val errorMsg =
s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
require(size >= 2 && isPow2(size), errorMsg) require(size >= 2 && isPow2(size), errorMsg)
val b = aBits + bBits val b = aBits + bBits
val outBits = b + log2Ceil(size) + 1 val outBits = b + log2Ceil(size) + 1
...@@ -72,12 +74,15 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module ...@@ -72,12 +74,15 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module
val b = Input(Vec(size, SInt(bBits.W))) val b = Input(Vec(size, SInt(bBits.W)))
val y = Output(SInt(outBits.W)) val y = Output(SInt(outBits.W))
}) })
val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers val s = Seq.tabulate(log2Ceil(size + 1))(i =>
pow(2, log2Ceil(size) - i).toInt) // # of total layers
val p = log2Ceil(size / 2) + 1 // # of adder layers val p = log2Ceil(size / 2) + 1 // # of adder layers
val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
val a = Seq.tabulate(p)(i => val a = Seq.tabulate(p)(
Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))) i =>
) // # adders within each layer Seq.fill(s(i + 1))(Module(new PipeAdder(
aBits = (b + i + 1),
bBits = (b + i + 1))))) // # adders within each layer
// Vector MACs // Vector MACs
for (i <- 0 until s(0)) { for (i <- 0 until s(0)) {
...@@ -111,7 +116,7 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { ...@@ -111,7 +116,7 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
val inpBits = p(CoreKey).inpBits val inpBits = p(CoreKey).inpBits
val wgtBits = p(CoreKey).wgtBits val wgtBits = p(CoreKey).wgtBits
val outBits = p(CoreKey).outBits val outBits = p(CoreKey).outBits
val io = IO(new Bundle{ val io = IO(new Bundle {
val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
val inp = new TensorMasterData(tensorType = "inp") val inp = new TensorMasterData(tensorType = "inp")
val wgt = new TensorMasterData(tensorType = "wgt") val wgt = new TensorMasterData(tensorType = "wgt")
...@@ -119,8 +124,10 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { ...@@ -119,8 +124,10 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
val acc_o = new TensorClientData(tensorType = "acc") val acc_o = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out") val out = new TensorClientData(tensorType = "out")
}) })
val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size))) val dot = Seq.fill(size)(
val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1))) Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
val acc = Seq.fill(size)(
Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
val add = Seq.fill(size)(Wire(SInt(accBits.W))) val add = Seq.fill(size)(Wire(SInt(accBits.W)))
val vld = Wire(Vec(size, Bool())) val vld = Wire(Vec(size, Bool()))
...@@ -149,7 +156,8 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { ...@@ -149,7 +156,8 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
* Also, the TensorGemm uses the reset field in the Gemm instruction to * Also, the TensorGemm uses the reset field in the Gemm instruction to
* clear or zero-out the acc-scratchpad locations based on the micro-ops. * clear or zero-out the acc-scratchpad locations based on the micro-ops.
*/ */
class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module { class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val start = Input(Bool()) val start = Input(Bool())
val done = Output(Bool()) val done = Output(Bool())
...@@ -160,7 +168,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module ...@@ -160,7 +168,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
val acc = new TensorMaster(tensorType = "acc") val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out") val out = new TensorMaster(tensorType = "out")
}) })
val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6) val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
Enum(6)
val state = RegInit(sIdle) val state = RegInit(sIdle)
val mvc = Module(new MatrixVectorMultiplication) val mvc = Module(new MatrixVectorMultiplication)
val dec = io.inst.asTypeOf(new GemmDecode) val dec = io.inst.asTypeOf(new GemmDecode)
...@@ -188,65 +197,68 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module ...@@ -188,65 +197,68 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
inflight === 0.U) | inflight === 0.U) |
state === sWait) state === sWait)
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.start) { when(io.start) {
state := sReadUop state := sReadUop
} }
} }
is (sReadUop) { is(sReadUop) {
state := sComputeIdx state := sComputeIdx
} }
is (sComputeIdx) { is(sComputeIdx) {
state := sReadTensor state := sReadTensor
} }
is (sReadTensor) { is(sReadTensor) {
state := sExe state := sExe
} }
is (sExe) { is(sExe) {
when ((cnt_o === dec.lp_0 - 1.U) && when(
(cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) && (cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) { (uop_idx === uop_end - 1.U)) {
when (inflight =/= 0.U) { when(inflight =/= 0.U) {
state := sWait state := sWait
} .otherwise { }.otherwise {
state := sIdle state := sIdle
} }
} .otherwise { }.otherwise {
state := sReadUop state := sReadUop
} }
} }
is (sWait) { is(sWait) {
when (inflight === 0.U) { when(inflight === 0.U) {
state := sIdle state := sIdle
} }
} }
} }
when (state === sIdle) { when(state === sIdle) {
inflight := 0.U inflight := 0.U
} .elsewhen (!dec.reset) { }.elsewhen(!dec.reset) {
when (state === sReadTensor) { // issue a tensor when(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U inflight := inflight + 1.U
} .elsewhen (mvc.io.acc_o.data.valid) { // commit a tensor }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
inflight := inflight - 1.U inflight := inflight - 1.U
} }
} }
when (state === sIdle || when(
state === sIdle ||
(state === sExe && (state === sExe &&
uop_idx === uop_end - 1.U)) { uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin uop_idx := dec.uop_begin
} .elsewhen (state === sExe) { }.elsewhen(state === sExe) {
uop_idx := uop_idx + 1.U uop_idx := uop_idx + 1.U
} }
when (state === sIdle) { when(state === sIdle) {
cnt_o := 0.U cnt_o := 0.U
acc_o := 0.U acc_o := 0.U
inp_o := 0.U inp_o := 0.U
wgt_o := 0.U wgt_o := 0.U
} .elsewhen (state === sExe && }.elsewhen(
state === sExe &&
uop_idx === uop_end - 1.U && uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) { cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U cnt_o := cnt_o + 1.U
...@@ -255,17 +267,18 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module ...@@ -255,17 +267,18 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
wgt_o := wgt_o + dec.wgt_0 wgt_o := wgt_o + dec.wgt_0
} }
when (state === sIdle) { when(state === sIdle) {
cnt_i := 0.U cnt_i := 0.U
acc_i := 0.U acc_i := 0.U
inp_i := 0.U inp_i := 0.U
wgt_i := 0.U wgt_i := 0.U
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) { }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U cnt_i := 0.U
acc_i := acc_o acc_i := acc_o
inp_i := inp_o inp_i := inp_o
wgt_i := wgt_o wgt_i := wgt_o
} .elsewhen (state === sExe && }
.elsewhen(state === sExe &&
uop_idx === uop_end - 1.U) { uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U cnt_i := cnt_i + 1.U
acc_i := acc_i + dec.acc_1 acc_i := acc_i + dec.acc_1
...@@ -273,7 +286,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module ...@@ -273,7 +286,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
wgt_i := wgt_i + dec.wgt_1 wgt_i := wgt_i + dec.wgt_1
} }
when (state === sComputeIdx && io.uop.data.valid) { when(state === sComputeIdx && io.uop.data.valid) {
uop_acc := io.uop.data.bits.u0 + acc_i uop_acc := io.uop.data.bits.u0 + acc_i
uop_inp := io.uop.data.bits.u1 + inp_i uop_inp := io.uop.data.bits.u1 + inp_i
uop_wgt := io.uop.data.bits.u2 + wgt_i uop_wgt := io.uop.data.bits.u2 + wgt_i
...@@ -307,7 +320,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module ...@@ -307,7 +320,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
mvc.io.acc_i.data <> io.acc.rd.data mvc.io.acc_i.data <> io.acc.rd.data
// acc_o // acc_o
io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid) io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset,
true.B,
wrpipe.io.deq.valid)
io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits) io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
io.acc.wr.bits.data <> mvc.io.acc_o.data.bits io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
...@@ -320,45 +335,53 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module ...@@ -320,45 +335,53 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
io.done := done io.done := done
if (debug) { if (debug) {
when (state === sReadUop && ~dec.reset) { when(state === sReadUop && ~dec.reset) {
printf("[TensorGemm] [uop] idx:%x\n", uop_idx) printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
} }
when (state === sReadTensor && ~dec.reset) { when(state === sReadTensor && ~dec.reset) {
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt) printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n",
uop_acc,
uop_inp,
uop_wgt)
} }
io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) => io.inp.rd.data.bits.zipWithIndex.foreach {
when (io.inp.rd.data.valid && ~dec.reset) { case (r, i) =>
when(io.inp.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt) printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
} }
} }
io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) => io.wgt.rd.data.bits.zipWithIndex.foreach {
when (io.wgt.rd.data.valid && ~dec.reset) { case (r, i) =>
when(io.wgt.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt) printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
} }
} }
io.acc.rd.data.bits.foreach { tensor => io.acc.rd.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (io.acc.rd.data.valid && ~dec.reset) { case (elem, i) =>
when(io.acc.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem) printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
} }
} }
} }
mvc.io.acc_o.data.bits.foreach { tensor => mvc.io.acc_o.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (mvc.io.acc_o.data.valid && ~dec.reset) { case (elem, i) =>
when(mvc.io.acc_o.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem) printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
} }
} }
} }
mvc.io.out.data.bits.foreach { tensor => mvc.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) => tensor.zipWithIndex.foreach {
when (mvc.io.out.data.valid && ~dec.reset) { case (elem, i) =>
when(mvc.io.out.data.valid && ~dec.reset) {
printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem) printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
} }
} }
......
...@@ -28,8 +28,9 @@ import vta.shell._ ...@@ -28,8 +28,9 @@ import vta.shell._
* *
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM). * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/ */
class TensorStore(tensorType: String = "none", debug: Boolean = false) class TensorStore(tensorType: String = "none", debug: Boolean = false)(
(implicit p: Parameters) extends Module { implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType) val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -53,9 +54,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) ...@@ -53,9 +54,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len)) val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize)) val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U
val xmax = (1 << mp.lenBits).U val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val ycnt = Reg(chiselTypeOf(dec.ysize)) val ycnt = Reg(chiselTypeOf(dec.ysize))
val ysize = dec.ysize val ysize = dec.ysize
val tag = Reg(UInt(8.W)) val tag = Reg(UInt(8.W))
...@@ -65,56 +66,57 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) ...@@ -65,56 +66,57 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
val state = RegInit(sIdle) val state = RegInit(sIdle)
// control // control
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.start) { when(io.start) {
state := sWriteCmd state := sWriteCmd
when (xsize < xmax) { when(xsize < xmax) {
xlen := xsize xlen := xsize
xrem := 0.U xrem := 0.U
} .otherwise { }.otherwise {
xlen := xmax - 1.U xlen := xmax - 1.U
xrem := xsize - xmax xrem := xsize - xmax
} }
} }
} }
is (sWriteCmd) { is(sWriteCmd) {
when (io.vme_wr.cmd.ready) { when(io.vme_wr.cmd.ready) {
state := sWriteData state := sWriteData
} }
} }
is (sWriteData) { is(sWriteData) {
when (io.vme_wr.data.ready) { when(io.vme_wr.data.ready) {
when (xcnt === xlen) { when(xcnt === xlen) {
state := sWriteAck state := sWriteAck
} .elsewhen (tag === (numMemBlock - 1).U) { }.elsewhen(tag === (numMemBlock - 1).U) {
state := sReadMem state := sReadMem
} }
} }
} }
is (sReadMem) { is(sReadMem) {
state := sWriteData state := sWriteData
} }
is (sWriteAck) { is(sWriteAck) {
when (io.vme_wr.ack) { when(io.vme_wr.ack) {
when (xrem === 0.U) { when(xrem === 0.U) {
when (ycnt === ysize - 1.U) { when(ycnt === ysize - 1.U) {
state := sIdle state := sIdle
} .otherwise { }.otherwise {
state := sWriteCmd state := sWriteCmd
when (xsize < xmax) { when(xsize < xmax) {
xlen := xsize xlen := xsize
xrem := 0.U xrem := 0.U
} .otherwise { }.otherwise {
xlen := xmax - 1.U xlen := xmax - 1.U
xrem := xsize - xmax xrem := xsize - xmax
} }
} }
} .elsewhen (xrem < xmax) { }.elsewhen(xrem < xmax) {
state := sWriteCmd state := sWriteCmd
xlen := xrem xlen := xrem
xrem := 0.U xrem := 0.U
} .otherwise { }
.otherwise {
state := sWriteCmd state := sWriteCmd
xlen := xmax - 1.U xlen := xmax - 1.U
xrem := xrem - xmax xrem := xrem - xmax
...@@ -124,16 +126,20 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) ...@@ -124,16 +126,20 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
} }
// write-to-sram // write-to-sram
val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) } val tensorFile = Seq.fill(tensorLength) {
SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W)))
}
val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W))) val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
val no_mask = Wire(Vec(numMemBlock, Bool())) val no_mask = Wire(Vec(numMemBlock, Bool()))
wdata_t := DontCare wdata_t := DontCare
no_mask.foreach { m => m := true.B } no_mask.foreach { m =>
m := true.B
}
for (i <- 0 until tensorLength) { for (i <- 0 until tensorLength) {
val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t) val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
when (io.tensor.wr.valid) { when(io.tensor.wr.valid) {
tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask) tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
} }
} }
...@@ -145,51 +151,61 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) ...@@ -145,51 +151,61 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
xrem === 0.U & xrem === 0.U &
ycnt =/= ysize - 1.U ycnt =/= ysize - 1.U
when (state === sIdle) { when(state === sIdle) {
ycnt := 0.U ycnt := 0.U
} .elsewhen (stride) { }.elsewhen(stride) {
ycnt := ycnt + 1.U ycnt := ycnt + 1.U
} }
when (state === sWriteCmd || tag === (numMemBlock - 1).U) { when(state === sWriteCmd || tag === (numMemBlock - 1).U) {
tag := 0.U tag := 0.U
} .elsewhen (io.vme_wr.data.fire()) { }.elsewhen(io.vme_wr.data.fire()) {
tag := tag + 1.U tag := tag + 1.U
} }
when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) { when(
state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
set := 0.U set := 0.U
} .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) { }.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
set := set + 1.U set := set + 1.U
} }
val raddr_cur = Reg(UInt(tp.memAddrBits.W)) val raddr_cur = Reg(UInt(tp.memAddrBits.W))
val raddr_nxt = Reg(UInt(tp.memAddrBits.W)) val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
when (state === sIdle) { when(state === sIdle) {
raddr_cur := dec.sram_offset raddr_cur := dec.sram_offset
raddr_nxt := dec.sram_offset raddr_nxt := dec.sram_offset
} .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) { }.elsewhen(io.vme_wr.data
.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
raddr_cur := raddr_cur + 1.U raddr_cur := raddr_cur + 1.U
} .elsewhen (stride) { }
.elsewhen(stride) {
raddr_cur := raddr_nxt + dec.xsize raddr_cur := raddr_nxt + dec.xsize
raddr_nxt := raddr_nxt + dec.xsize raddr_nxt := raddr_nxt + dec.xsize
} }
val tread = Seq.tabulate(tensorLength) { i => i.U -> val tread = Seq.tabulate(tensorLength) { i =>
tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) } i.U ->
tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem)
}
val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread) val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
// write-to-dram // write-to-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8 val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8
when (state === sIdle) { when(state === sIdle) {
waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) elemBytes)))
} .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) { waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
elemBytes)))
}.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
waddr_cur := waddr_cur + xmax_bytes waddr_cur := waddr_cur + xmax_bytes
} .elsewhen (stride) { }
waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) .elsewhen(stride) {
waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth)) waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(
tensorLength * tensorWidth))
waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(
tensorLength * tensorWidth))
} }
io.vme_wr.cmd.valid := state === sWriteCmd io.vme_wr.cmd.valid := state === sWriteCmd
...@@ -199,9 +215,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) ...@@ -199,9 +215,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
io.vme_wr.data.valid := state === sWriteData io.vme_wr.data.valid := state === sWriteData
io.vme_wr.data.bits := mdata(tag) io.vme_wr.data.bits := mdata(tag)
when (state === sWriteCmd) { when(state === sWriteCmd) {
xcnt := 0.U xcnt := 0.U
} .elsewhen (io.vme_wr.data.fire()) { }.elsewhen(io.vme_wr.data.fire()) {
xcnt := xcnt + 1.U xcnt := xcnt + 1.U
} }
...@@ -213,13 +229,19 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false) ...@@ -213,13 +229,19 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
// debug // debug
if (debug) { if (debug) {
when (io.vme_wr.cmd.fire()) { when(io.vme_wr.cmd.fire()) {
printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem) printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n",
} ysize,
when (io.vme_wr.data.fire()) { ycnt,
raddr_cur,
waddr_cur,
xlen,
xrem)
}
when(io.vme_wr.data.fire()) {
printf("[TensorStore] data:%x\n", io.vme_wr.data.bits) printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
} }
when (io.vme_wr.ack) { when(io.vme_wr.ack) {
printf("[TensorStore] ack\n") printf("[TensorStore] ack\n")
} }
} }
......
...@@ -30,11 +30,14 @@ import vta.shell._ ...@@ -30,11 +30,14 @@ import vta.shell._
* weights (wgt), biases (acc), and outputs (out). This is used to avoid * weights (wgt), biases (acc), and outputs (out). This is used to avoid
* doing the same boring calculations over and over again. * doing the same boring calculations over and over again.
*/ */
class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle { class TensorParams(tensorType: String = "none")(implicit p: Parameters)
val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n" extends Bundle {
val errorMsg =
s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
require (tensorType == "inp" || tensorType == "wgt" require(tensorType == "inp" || tensorType == "wgt"
|| tensorType == "acc" || tensorType == "out", errorMsg) || tensorType == "acc" || tensorType == "out",
errorMsg)
val (tensorLength, tensorWidth, tensorElemBits) = val (tensorLength, tensorWidth, tensorElemBits) =
if (tensorType == "inp") if (tensorType == "inp")
...@@ -69,11 +72,12 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends ...@@ -69,11 +72,12 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
* biases (acc), and outputs (out). * biases (acc), and outputs (out).
* *
*/ */
class TensorMaster(tensorType: String = "none") class TensorMaster(tensorType: String = "none")(implicit p: Parameters)
(implicit p: Parameters) extends TensorParams(tensorType) { extends TensorParams(tensorType) {
val rd = new Bundle { val rd = new Bundle {
val idx = ValidIO(UInt(memAddrBits.W)) val idx = ValidIO(UInt(memAddrBits.W))
val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) val data = Flipped(
ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
} }
val wr = ValidIO(new Bundle { val wr = ValidIO(new Bundle {
val idx = UInt(memAddrBits.W) val idx = UInt(memAddrBits.W)
...@@ -86,7 +90,11 @@ class TensorMaster(tensorType: String = "none") ...@@ -86,7 +90,11 @@ class TensorMaster(tensorType: String = "none")
def tieoffWrite() { def tieoffWrite() {
wr.valid := false.B wr.valid := false.B
wr.bits.idx := 0.U wr.bits.idx := 0.U
wr.bits.data.foreach { b => b.foreach { c => c := 0.U } } wr.bits.data.foreach { b =>
b.foreach { c =>
c := 0.U
}
}
} }
override def cloneType = override def cloneType =
new TensorMaster(tensorType).asInstanceOf[this.type] new TensorMaster(tensorType).asInstanceOf[this.type]
...@@ -98,11 +106,12 @@ class TensorMaster(tensorType: String = "none") ...@@ -98,11 +106,12 @@ class TensorMaster(tensorType: String = "none")
* The TensorLoad unit uses this interface for receiving read and write requests from * The TensorLoad unit uses this interface for receiving read and write requests from
* the TensorGemm unit. * the TensorGemm unit.
*/ */
class TensorClient(tensorType: String = "none") class TensorClient(tensorType: String = "none")(implicit p: Parameters)
(implicit p: Parameters) extends TensorParams(tensorType) { extends TensorParams(tensorType) {
val rd = new Bundle { val rd = new Bundle {
val idx = Flipped(ValidIO(UInt(memAddrBits.W))) val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) val data = ValidIO(
Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
} }
val wr = Flipped(ValidIO(new Bundle { val wr = Flipped(ValidIO(new Bundle {
val idx = UInt(memAddrBits.W) val idx = UInt(memAddrBits.W)
...@@ -110,7 +119,11 @@ class TensorClient(tensorType: String = "none") ...@@ -110,7 +119,11 @@ class TensorClient(tensorType: String = "none")
})) }))
def tieoffRead() { def tieoffRead() {
rd.data.valid := false.B rd.data.valid := false.B
rd.data.bits.foreach { b => b.foreach { c => c := 0.U } } rd.data.bits.foreach { b =>
b.foreach { c =>
c := 0.U
}
}
} }
override def cloneType = override def cloneType =
new TensorClient(tensorType).asInstanceOf[this.type] new TensorClient(tensorType).asInstanceOf[this.type]
...@@ -122,9 +135,10 @@ class TensorClient(tensorType: String = "none") ...@@ -122,9 +135,10 @@ class TensorClient(tensorType: String = "none")
* is based on the TensorMaster interface, which means this is an input. This interface * is based on the TensorMaster interface, which means this is an input. This interface
* is used on datapath only module such MatrixVectorCore or AluVector. * is used on datapath only module such MatrixVectorCore or AluVector.
*/ */
class TensorMasterData(tensorType: String = "none") class TensorMasterData(tensorType: String = "none")(implicit p: Parameters)
(implicit p: Parameters) extends TensorParams(tensorType) { extends TensorParams(tensorType) {
val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) val data = Flipped(
ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
override def cloneType = override def cloneType =
new TensorMasterData(tensorType).asInstanceOf[this.type] new TensorMasterData(tensorType).asInstanceOf[this.type]
} }
...@@ -135,18 +149,22 @@ class TensorMasterData(tensorType: String = "none") ...@@ -135,18 +149,22 @@ class TensorMasterData(tensorType: String = "none")
* is based on the TensorClient interface, which means this is an output. This interface * is based on the TensorClient interface, which means this is an output. This interface
* is used on datapath only module such MatrixVectorCore or AluVector. * is used on datapath only module such MatrixVectorCore or AluVector.
*/ */
class TensorClientData(tensorType: String = "none") class TensorClientData(tensorType: String = "none")(implicit p: Parameters)
(implicit p: Parameters) extends TensorParams(tensorType) { extends TensorParams(tensorType) {
val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) val data = ValidIO(
Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
override def cloneType = override def cloneType =
new TensorClientData(tensorType).asInstanceOf[this.type] new TensorClientData(tensorType).asInstanceOf[this.type]
} }
/** TensorPadCtrl. Zero-padding controller for TensorLoad. */ /** TensorPadCtrl. Zero-padding controller for TensorLoad. */
class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module { class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1)
val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n" extends Module {
require (padType == "YPad0" || padType == "YPad1" val errorMsg =
|| padType == "XPad0" || padType == "XPad1", errorMsg) s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
require(padType == "YPad0" || padType == "YPad1"
|| padType == "XPad0" || padType == "XPad1",
errorMsg)
val io = IO(new Bundle { val io = IO(new Bundle {
val start = Input(Bool()) val start = Input(Bool())
...@@ -180,33 +198,33 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul ...@@ -180,33 +198,33 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
val sIdle :: sActive :: Nil = Enum(2) val sIdle :: sActive :: Nil = Enum(2)
val state = RegInit(sIdle) val state = RegInit(sIdle)
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.start) { when(io.start) {
state := sActive state := sActive
} }
} }
is (sActive) { is(sActive) {
when (ycnt === ymax && xcnt === xmax) { when(ycnt === ymax && xcnt === xmax) {
state := sIdle state := sIdle
} }
} }
} }
when (state === sIdle) { when(state === sIdle) {
xmax := xval xmax := xval
ymax := yval ymax := yval
} }
when (state === sIdle || xcnt === xmax) { when(state === sIdle || xcnt === xmax) {
xcnt := 0.U xcnt := 0.U
} .elsewhen (state === sActive) { }.elsewhen(state === sActive) {
xcnt := xcnt + 1.U xcnt := xcnt + 1.U
} }
when (state === sIdle || ymax === 0.U) { when(state === sIdle || ymax === 0.U) {
ycnt := 0.U ycnt := 0.U
} .elsewhen (state === sActive && xcnt === xmax) { }.elsewhen(state === sActive && xcnt === xmax) {
ycnt := ycnt + 1.U ycnt := ycnt + 1.U
} }
...@@ -214,7 +232,10 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul ...@@ -214,7 +232,10 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
} }
/** TensorDataCtrl. Data controller for TensorLoad. */ /** TensorDataCtrl. Data controller for TensorLoad. */
class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module { class TensorDataCtrl(tensorType: String = "none",
sizeFactor: Int = 1,
strideFactor: Int = 1)(implicit p: Parameters)
extends Module {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
val start = Input(Bool()) val start = Input(Bool())
...@@ -238,7 +259,7 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac ...@@ -238,7 +259,7 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
val len = Reg(UInt(mp.lenBits.W)) val len = Reg(UInt(mp.lenBits.W))
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val xcnt = Reg(UInt(mp.lenBits.W)) val xcnt = Reg(UInt(mp.lenBits.W))
val xrem = Reg(chiselTypeOf(dec.xsize)) val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
...@@ -251,33 +272,33 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac ...@@ -251,33 +272,33 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
val split = xcnt === len & xrem =/= 0.U val split = xcnt === len & xrem =/= 0.U
when (io.start || (io.xupdate && stride)) { when(io.start || (io.xupdate && stride)) {
when (xsize < xmax) { when(xsize < xmax) {
len := xsize len := xsize
xrem := 0.U xrem := 0.U
} .otherwise { }.otherwise {
len := xmax - 1.U len := xmax - 1.U
xrem := xsize - xmax xrem := xsize - xmax
} }
} .elsewhen (io.xupdate && split) { }.elsewhen(io.xupdate && split) {
when (xrem < xmax) { when(xrem < xmax) {
len := xrem len := xrem
xrem := 0.U xrem := 0.U
} .otherwise { }.otherwise {
len := xmax - 1.U len := xmax - 1.U
xrem := xrem - xmax xrem := xrem - xmax
} }
} }
when (io.xinit) { when(io.xinit) {
xcnt := 0.U xcnt := 0.U
} .elsewhen (io.xupdate) { }.elsewhen(io.xupdate) {
xcnt := xcnt + 1.U xcnt := xcnt + 1.U
} }
when (io.start) { when(io.start) {
ycnt := 0.U ycnt := 0.U
} .elsewhen (io.yupdate && stride) { }.elsewhen(io.yupdate && stride) {
ycnt := ycnt + 1.U ycnt := ycnt + 1.U
} }
...@@ -291,13 +312,13 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac ...@@ -291,13 +312,13 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
(p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8 (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8
} }
when (io.start) { when(io.start) {
caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes))) baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
} .elsewhen (io.yupdate) { }.elsewhen(io.yupdate) {
when (split) { when(split) {
caddr := caddr + xmax_bytes caddr := caddr + xmax_bytes
} .elsewhen (stride) { }.elsewhen(stride) {
caddr := baddr + (dec.xstride << log2Ceil(strideFactor)) caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
baddr := baddr + (dec.xstride << log2Ceil(strideFactor)) baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
} }
......
...@@ -78,55 +78,56 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource { ...@@ -78,55 +78,56 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
* *
* Convert Host DPI to AXI for VTAShell * Convert Host DPI to AXI for VTAShell
*/ */
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val dpi = new VTAHostDPIClient val dpi = new VTAHostDPIClient
val axi = new AXILiteMaster(p(ShellKey).hostParams) val axi = new AXILiteMaster(p(ShellKey).hostParams)
}) })
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value))) val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
Enum(6)
val state = RegInit(sIdle) val state = RegInit(sIdle)
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.dpi.req.valid) { when(io.dpi.req.valid) {
when (io.dpi.req.opcode) { when(io.dpi.req.opcode) {
state := sWriteAddress state := sWriteAddress
} .otherwise { }.otherwise {
state := sReadAddress state := sReadAddress
} }
} }
} }
is (sReadAddress) { is(sReadAddress) {
when (io.axi.ar.ready) { when(io.axi.ar.ready) {
state := sReadData state := sReadData
} }
} }
is (sReadData) { is(sReadData) {
when (io.axi.r.valid) { when(io.axi.r.valid) {
state := sIdle state := sIdle
} }
} }
is (sWriteAddress) { is(sWriteAddress) {
when (io.axi.aw.ready) { when(io.axi.aw.ready) {
state := sWriteData state := sWriteData
} }
} }
is (sWriteData) { is(sWriteData) {
when (io.axi.w.ready) { when(io.axi.w.ready) {
state := sWriteResponse state := sWriteResponse
} }
} }
is (sWriteResponse) { is(sWriteResponse) {
when (io.axi.b.valid) { when(io.axi.b.valid) {
state := sIdle state := sIdle
} }
} }
} }
when (state === sIdle && io.dpi.req.valid) { when(state === sIdle && io.dpi.req.valid) {
addr := io.dpi.req.addr addr := io.dpi.req.addr
data := io.dpi.req.value data := io.dpi.req.value
} }
...@@ -147,9 +148,17 @@ class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mo ...@@ -147,9 +148,17 @@ class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mo
io.dpi.resp.bits := io.axi.r.bits.data io.dpi.resp.bits := io.axi.r.bits.data
if (debug) { if (debug) {
when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) } when(state === sWriteAddress && io.axi.aw.ready) {
when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) } printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr)
when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) } }
when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) } when(state === sReadAddress && io.axi.ar.ready) {
printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr)
}
when(io.axi.r.fire()) {
printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data)
}
when(io.axi.w.fire()) {
printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data)
}
} }
} }
...@@ -75,7 +75,8 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource { ...@@ -75,7 +75,8 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
setResource("/verilog/VTAMemDPI.v") setResource("/verilog/VTAMemDPI.v")
} }
class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module { class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val dpi = new VTAMemDPIMaster val dpi = new VTAMemDPIMaster
val axi = new AXIClient(p(ShellKey).memParams) val axi = new AXIClient(p(ShellKey).memParams)
...@@ -83,56 +84,57 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod ...@@ -83,56 +84,57 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod
val opcode = RegInit(false.B) val opcode = RegInit(false.B)
val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len))) val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr))) val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6) val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
Enum(6)
val state = RegInit(sIdle) val state = RegInit(sIdle)
switch (state) { switch(state) {
is (sIdle) { is(sIdle) {
when (io.axi.ar.valid) { when(io.axi.ar.valid) {
state := sReadAddress state := sReadAddress
} .elsewhen (io.axi.aw.valid) { }.elsewhen(io.axi.aw.valid) {
state := sWriteAddress state := sWriteAddress
} }
} }
is (sReadAddress) { is(sReadAddress) {
when (io.axi.ar.valid) { when(io.axi.ar.valid) {
state := sReadData state := sReadData
} }
} }
is (sReadData) { is(sReadData) {
when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) { when(io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
state := sIdle state := sIdle
} }
} }
is (sWriteAddress) { is(sWriteAddress) {
when (io.axi.aw.valid) { when(io.axi.aw.valid) {
state := sWriteData state := sWriteData
} }
} }
is (sWriteData) { is(sWriteData) {
when (io.axi.w.valid && io.axi.w.bits.last) { when(io.axi.w.valid && io.axi.w.bits.last) {
state := sWriteResponse state := sWriteResponse
} }
} }
is (sWriteResponse) { is(sWriteResponse) {
when (io.axi.b.ready) { when(io.axi.b.ready) {
state := sIdle state := sIdle
} }
} }
} }
when (state === sIdle) { when(state === sIdle) {
when (io.axi.ar.valid) { when(io.axi.ar.valid) {
opcode := false.B opcode := false.B
len := io.axi.ar.bits.len len := io.axi.ar.bits.len
addr := io.axi.ar.bits.addr addr := io.axi.ar.bits.addr
} .elsewhen (io.axi.aw.valid) { }.elsewhen(io.axi.aw.valid) {
opcode := true.B opcode := true.B
len := io.axi.aw.bits.len len := io.axi.aw.bits.len
addr := io.axi.aw.bits.addr addr := io.axi.aw.bits.addr
} }
} .elsewhen (state === sReadData) { }.elsewhen(state === sReadData) {
when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) { when(io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
len := len - 1.U len := len - 1.U
} }
} }
...@@ -163,9 +165,21 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod ...@@ -163,9 +165,21 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod
io.axi.b.bits.id := 0.U io.axi.b.bits.id := 0.U
if (debug) { if (debug) {
when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) } when(state === sReadAddress && io.axi.ar.valid) {
when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) } printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len)
when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) } }
when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) } when(state === sWriteAddress && io.axi.aw.valid) {
printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len)
}
when(io.axi.r.fire()) {
printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n",
io.axi.r.bits.last,
io.axi.r.bits.data)
}
when(io.axi.w.fire()) {
printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n",
io.axi.w.bits.last,
io.axi.w.bits.data)
}
} }
} }
...@@ -30,12 +30,11 @@ case class AXIParams( ...@@ -30,12 +30,11 @@ case class AXIParams(
dataBits: Int = 64, dataBits: Int = 64,
lenBits: Int = 8, lenBits: Int = 8,
userBits: Int = 1 userBits: Int = 1
) ) {
{ require(addrBits > 0)
require (addrBits > 0) require(dataBits >= 8 && dataBits % 2 == 0)
require (dataBits >= 8 && dataBits % 2 == 0)
val strbBits = dataBits/8 val strbBits = dataBits / 8
val sizeBits = 3 val sizeBits = 3
val burstBits = 2 val burstBits = 2
val lockBits = 2 val lockBits = 2
...@@ -44,7 +43,7 @@ case class AXIParams( ...@@ -44,7 +43,7 @@ case class AXIParams(
val qosBits = 4 val qosBits = 4
val regionBits = 4 val regionBits = 4
val respBits = 2 val respBits = 2
val sizeConst = log2Ceil(dataBits/8) val sizeConst = log2Ceil(dataBits / 8)
val idConst = 0 val idConst = 0
val userConst = if (coherent) 1 else 0 val userConst = if (coherent) 1 else 0
val burstConst = 1 val burstConst = 1
......
...@@ -25,52 +25,59 @@ import vta.util.config._ ...@@ -25,52 +25,59 @@ import vta.util.config._
import vta.interface.axi._ import vta.interface.axi._
/** PynqConfig. Shell configuration for Pynq */ /** PynqConfig. Shell configuration for Pynq */
class PynqConfig extends Config((site, here, up) => { class PynqConfig
case ShellKey => ShellParams( extends Config((site, here, up) => {
hostParams = AXIParams( case ShellKey =>
coherent = false, ShellParams(
hostParams = AXIParams(coherent = false,
addrBits = 16, addrBits = 16,
dataBits = 32, dataBits = 32,
lenBits = 8, lenBits = 8,
userBits = 1), userBits = 1),
memParams = AXIParams( memParams = AXIParams(coherent = true,
coherent = true,
addrBits = 32, addrBits = 32,
dataBits = 64, dataBits = 64,
lenBits = 8, lenBits = 8,
userBits = 1), userBits = 1),
vcrParams = VCRParams(), vcrParams = VCRParams(),
vmeParams = VMEParams()) vmeParams = VMEParams()
}) )
})
/** F1Config. Shell configuration for F1 */ /** F1Config. Shell configuration for F1 */
class F1Config extends Config((site, here, up) => { class F1Config
case ShellKey => ShellParams( extends Config((site, here, up) => {
hostParams = AXIParams( case ShellKey =>
coherent = false, ShellParams(
hostParams = AXIParams(coherent = false,
addrBits = 16, addrBits = 16,
dataBits = 32, dataBits = 32,
lenBits = 8, lenBits = 8,
userBits = 1), userBits = 1),
memParams = AXIParams( memParams = AXIParams(coherent = false,
coherent = false,
addrBits = 64, addrBits = 64,
dataBits = 64, dataBits = 64,
lenBits = 8, lenBits = 8,
userBits = 1), userBits = 1),
vcrParams = VCRParams(), vcrParams = VCRParams(),
vmeParams = VMEParams()) vmeParams = VMEParams()
}) )
})
/** De10Config. Shell configuration for De10 */ /** De10Config. Shell configuration for De10 */
class De10Config extends Config((site, here, up) => { class De10Config
case ShellKey => ShellParams( extends Config((site, here, up) => {
hostParams = AXIParams( case ShellKey =>
addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4), ShellParams(
hostParams =
AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
memParams = AXIParams( memParams = AXIParams(
addrBits = 32, dataBits = 64, userBits = 5, addrBits = 32,
dataBits = 64,
userBits = 5,
lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4 lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
coherent = true), coherent = true),
vcrParams = VCRParams(), vcrParams = VCRParams(),
vmeParams = VMEParams()) vmeParams = VMEParams()
}) )
})
...@@ -30,7 +30,7 @@ import vta.core._ ...@@ -30,7 +30,7 @@ import vta.core._
* system that can be used for simulation or real hardware. * system that can be used for simulation or real hardware.
*/ */
class IntelShell(implicit p: Parameters) extends Module { class IntelShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle{ val io = IO(new Bundle {
val host = new AXIClient(p(ShellKey).hostParams) val host = new AXIClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams) val mem = new AXIMaster(p(ShellKey).memParams)
}) })
......
...@@ -76,6 +76,7 @@ class VTASim(implicit p: Parameters) extends MultiIOModule { ...@@ -76,6 +76,7 @@ class VTASim(implicit p: Parameters) extends MultiIOModule {
sim.io.clock := clock sim.io.clock := clock
sim_wait := sim.io.dpi_wait sim_wait := sim.io.dpi_wait
} }
/** SimShell. /** SimShell.
* *
* The simulation shell instantiate the sim, host and memory DPI modules that * The simulation shell instantiate the sim, host and memory DPI modules that
......
...@@ -29,8 +29,7 @@ import vta.interface.axi._ ...@@ -29,8 +29,7 @@ import vta.interface.axi._
* *
* These parameters are used on VCR interfaces and modules. * These parameters are used on VCR interfaces and modules.
*/ */
case class VCRParams() case class VCRParams() {
{
val nCtrl = 1 val nCtrl = 1
val nECnt = 1 val nECnt = 1
val nVals = 1 val nVals = 1
...@@ -80,7 +79,7 @@ class VCRClient(implicit p: Parameters) extends VCRBase { ...@@ -80,7 +79,7 @@ class VCRClient(implicit p: Parameters) extends VCRBase {
* registers that could be used as event counters by the Core unit. * registers that could be used as event counters by the Core unit.
*/ */
class VCR(implicit p: Parameters) extends Module { class VCR(implicit p: Parameters) extends Module {
val io = IO(new Bundle{ val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams) val host = new AXILiteClient(p(ShellKey).hostParams)
val vcr = new VCRMaster val vcr = new VCRMaster
}) })
...@@ -101,7 +100,7 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -101,7 +100,7 @@ class VCR(implicit p: Parameters) extends Module {
val rdata = RegInit(0.U(vp.regBits.W)) val rdata = RegInit(0.U(vp.regBits.W))
// registers // registers
val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2*vp.nPtrs val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2 * vp.nPtrs
val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + nPtrs val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + nPtrs
val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W))) val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
...@@ -111,40 +110,39 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -111,40 +110,39 @@ class VCR(implicit p: Parameters) extends Module {
val vo = eo + vp.nECnt val vo = eo + vp.nECnt
val po = vo + vp.nVals val po = vo + vp.nVals
switch (wstate) { switch(wstate) {
is (sWriteAddress) { is(sWriteAddress) {
when (io.host.aw.valid) { when(io.host.aw.valid) {
wstate := sWriteData wstate := sWriteData
} }
} }
is (sWriteData) { is(sWriteData) {
when (io.host.w.valid) { when(io.host.w.valid) {
wstate := sWriteResponse wstate := sWriteResponse
} }
} }
is (sWriteResponse) { is(sWriteResponse) {
when (io.host.b.ready) { when(io.host.b.ready) {
wstate := sWriteAddress wstate := sWriteAddress
} }
} }
} }
when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr } when(io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
io.host.aw.ready := wstate === sWriteAddress io.host.aw.ready := wstate === sWriteAddress
io.host.w.ready := wstate === sWriteData io.host.w.ready := wstate === sWriteData
io.host.b.valid := wstate === sWriteResponse io.host.b.valid := wstate === sWriteResponse
io.host.b.bits.resp := 0.U io.host.b.bits.resp := 0.U
switch(rstate) {
switch (rstate) { is(sReadAddress) {
is (sReadAddress) { when(io.host.ar.valid) {
when (io.host.ar.valid) {
rstate := sReadData rstate := sReadData
} }
} }
is (sReadData) { is(sReadData) {
when (io.host.r.ready) { when(io.host.r.ready) {
rstate := sReadAddress rstate := sReadAddress
} }
} }
...@@ -155,27 +153,27 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -155,27 +153,27 @@ class VCR(implicit p: Parameters) extends Module {
io.host.r.bits.data := rdata io.host.r.bits.data := rdata
io.host.r.bits.resp := 0.U io.host.r.bits.resp := 0.U
when (io.vcr.finish) { when(io.vcr.finish) {
reg(0) := "b_10".U reg(0) := "b_10".U
} .elsewhen (io.host.w.fire() && addr(0).U === waddr) { }.elsewhen(io.host.w.fire() && addr(0).U === waddr) {
reg(0) := wdata reg(0) := wdata
} }
for (i <- 0 until vp.nECnt) { for (i <- 0 until vp.nECnt) {
when (io.vcr.ecnt(i).valid) { when(io.vcr.ecnt(i).valid) {
reg(eo + i) := io.vcr.ecnt(i).bits reg(eo + i) := io.vcr.ecnt(i).bits
} .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) { }.elsewhen(io.host.w.fire() && addr(eo + i).U === waddr) {
reg(eo + i) := wdata reg(eo + i) := wdata
} }
} }
for (i <- 0 until (vp.nVals + nPtrs)) { for (i <- 0 until (vp.nVals + nPtrs)) {
when (io.host.w.fire() && addr(vo + i).U === waddr) { when(io.host.w.fire() && addr(vo + i).U === waddr) {
reg(vo + i) := wdata reg(vo + i) := wdata
} }
} }
when (io.host.ar.fire()) { when(io.host.ar.fire()) {
rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map) rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map)
} }
...@@ -190,8 +188,8 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -190,8 +188,8 @@ class VCR(implicit p: Parameters) extends Module {
io.vcr.ptrs(i) := reg(po + i) io.vcr.ptrs(i) := reg(po + i)
} }
} else { // 64-bits pointers } else { // 64-bits pointers
for (i <- 0 until (nPtrs/2)) { for (i <- 0 until (nPtrs / 2)) {
io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i)) io.vcr.ptrs(i) := Cat(reg(po + 2 * i + 1), reg(po + 2 * i))
} }
} }
} }
...@@ -32,8 +32,11 @@ import vta.interface.axi._ ...@@ -32,8 +32,11 @@ import vta.interface.axi._
case class VMEParams() { case class VMEParams() {
val nReadClients: Int = 5 val nReadClients: Int = 5
val nWriteClients: Int = 1 val nWriteClients: Int = 1
require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n") require(nReadClients > 0,
require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n") s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
require(
nWriteClients == 1,
s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
} }
/** VMEBase. Parametrize base class. */ /** VMEBase. Parametrize base class. */
...@@ -149,19 +152,19 @@ class VME(implicit p: Parameters) extends Module { ...@@ -149,19 +152,19 @@ class VME(implicit p: Parameters) extends Module {
val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3) val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
val rstate = RegInit(sReadIdle) val rstate = RegInit(sReadIdle)
switch (rstate) { switch(rstate) {
is (sReadIdle) { is(sReadIdle) {
when (rd_arb.io.out.valid) { when(rd_arb.io.out.valid) {
rstate := sReadAddr rstate := sReadAddr
} }
} }
is (sReadAddr) { is(sReadAddr) {
when (io.mem.ar.ready) { when(io.mem.ar.ready) {
rstate := sReadData rstate := sReadData
} }
} }
is (sReadData) { is(sReadData) {
when (io.mem.r.fire() && io.mem.r.bits.last) { when(io.mem.r.fire() && io.mem.r.bits.last) {
rstate := sReadIdle rstate := sReadIdle
} }
} }
...@@ -173,30 +176,34 @@ class VME(implicit p: Parameters) extends Module { ...@@ -173,30 +176,34 @@ class VME(implicit p: Parameters) extends Module {
val lenBits = p(ShellKey).memParams.lenBits val lenBits = p(ShellKey).memParams.lenBits
val wr_cnt = RegInit(0.U(lenBits.W)) val wr_cnt = RegInit(0.U(lenBits.W))
when (wstate === sWriteIdle) { when(wstate === sWriteIdle) {
wr_cnt := 0.U wr_cnt := 0.U
} .elsewhen (io.mem.w.fire()) { }.elsewhen(io.mem.w.fire()) {
wr_cnt := wr_cnt + 1.U wr_cnt := wr_cnt + 1.U
} }
switch (wstate) { switch(wstate) {
is (sWriteIdle) { is(sWriteIdle) {
when (io.vme.wr(0).cmd.valid) { when(io.vme.wr(0).cmd.valid) {
wstate := sWriteAddr wstate := sWriteAddr
} }
} }
is (sWriteAddr) { is(sWriteAddr) {
when (io.mem.aw.ready) { when(io.mem.aw.ready) {
wstate := sWriteData wstate := sWriteData
} }
} }
is (sWriteData) { is(sWriteData) {
when (io.vme.wr(0).data.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) { when(
io.vme
.wr(0)
.data
.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
wstate := sWriteResp wstate := sWriteResp
} }
} }
is (sWriteResp) { is(sWriteResp) {
when (io.mem.b.valid) { when(io.mem.b.valid) {
wstate := sWriteIdle wstate := sWriteIdle
} }
} }
...@@ -209,12 +216,12 @@ class VME(implicit p: Parameters) extends Module { ...@@ -209,12 +216,12 @@ class VME(implicit p: Parameters) extends Module {
val rd_addr = RegInit(0.U(addrBits.W)) val rd_addr = RegInit(0.U(addrBits.W))
val wr_addr = RegInit(0.U(addrBits.W)) val wr_addr = RegInit(0.U(addrBits.W))
when (rd_arb.io.out.fire()) { when(rd_arb.io.out.fire()) {
rd_len := rd_arb.io.out.bits.len rd_len := rd_arb.io.out.bits.len
rd_addr := rd_arb.io.out.bits.addr rd_addr := rd_arb.io.out.bits.addr
} }
when (io.vme.wr(0).cmd.fire()) { when(io.vme.wr(0).cmd.fire()) {
wr_len := io.vme.wr(0).cmd.bits.len wr_len := io.vme.wr(0).cmd.bits.len
wr_addr := io.vme.wr(0).cmd.bits.addr wr_addr := io.vme.wr(0).cmd.bits.addr
} }
......
...@@ -40,7 +40,7 @@ case object ShellKey extends Field[ShellParams] ...@@ -40,7 +40,7 @@ case object ShellKey extends Field[ShellParams]
* system that can be used for simulation or real hardware. * system that can be used for simulation or real hardware.
*/ */
class VTAShell(implicit p: Parameters) extends Module { class VTAShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle{ val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams) val host = new AXILiteClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams) val mem = new AXIMaster(p(ShellKey).memParams)
}) })
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
package vta.shell package vta.shell
import chisel3._ import chisel3._
import chisel3.experimental.{RawModule, withClockAndReset} import chisel3.experimental.{withClockAndReset, RawModule}
import vta.util.config._ import vta.util.config._
import vta.interface.axi._ import vta.interface.axi._
...@@ -39,7 +39,9 @@ class XilinxShell(implicit p: Parameters) extends RawModule { ...@@ -39,7 +39,9 @@ class XilinxShell(implicit p: Parameters) extends RawModule {
val m_axi_gmem = IO(new XilinxAXIMaster(mp)) val m_axi_gmem = IO(new XilinxAXIMaster(mp))
val s_axi_control = IO(new XilinxAXILiteClient(hp)) val s_axi_control = IO(new XilinxAXILiteClient(hp))
val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) } val shell = withClockAndReset(clock = ap_clk, reset = ~ap_rst_n) {
Module(new VTAShell)
}
// memory // memory
m_axi_gmem.AWVALID := shell.io.mem.aw.valid m_axi_gmem.AWVALID := shell.io.mem.aw.valid
......
...@@ -21,8 +21,7 @@ package vta.util.config ...@@ -21,8 +21,7 @@ package vta.util.config
// taken from https://github.com/vta.roject/rocket-chip // taken from https://github.com/vta.roject/rocket-chip
abstract class Field[T] private (val default: Option[T]) abstract class Field[T] private (val default: Option[T]) {
{
def this() = this(None) def this() = this(None)
def this(default: T) = this(Some(default)) def this(default: T) = this(Some(default))
} }
...@@ -31,42 +30,50 @@ abstract class View { ...@@ -31,42 +30,50 @@ abstract class View {
final def apply[T](pname: Field[T]): T = apply(pname, this) final def apply[T](pname: Field[T]): T = apply(pname, this)
final def apply[T](pname: Field[T], site: View): T = { final def apply[T](pname: Field[T], site: View): T = {
val out = find(pname, site) val out = find(pname, site)
require (out.isDefined, s"Key ${pname} is not defined in Parameters") require(out.isDefined, s"Key ${pname} is not defined in Parameters")
out.get out.get
} }
final def lift[T](pname: Field[T]): Option[T] = lift(pname, this) final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T]) final def lift[T](pname: Field[T], site: View): Option[T] =
find(pname, site).map(_.asInstanceOf[T])
protected[config] def find[T](pname: Field[T], site: View): Option[T] protected[config] def find[T](pname: Field[T], site: View): Option[T]
} }
abstract class Parameters extends View { abstract class Parameters extends View {
final def ++ (x: Parameters): Parameters = final def ++(x: Parameters): Parameters =
new ChainParameters(this, x) new ChainParameters(this, x)
final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = final def alter(
f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
Parameters(f) ++ this Parameters(f) ++ this
final def alterPartial(f: PartialFunction[Any,Any]): Parameters = final def alterPartial(f: PartialFunction[Any, Any]): Parameters =
Parameters((_,_,_) => f) ++ this Parameters((_, _, _) => f) ++ this
final def alterMap(m: Map[Any,Any]): Parameters = final def alterMap(m: Map[Any, Any]): Parameters =
new MapParameters(m) ++ this new MapParameters(m) ++ this
protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T] protected[config] def chain[T](site: View,
protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname) tail: View,
pname: Field[T]): Option[T]
protected[config] def find[T](pname: Field[T], site: View) =
chain(site, new TerminalView, pname)
} }
object Parameters { object Parameters {
def empty: Parameters = new EmptyParameters def empty: Parameters = new EmptyParameters
def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f) def apply(f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
new PartialParameters(f)
} }
class Config(p: Parameters) extends Parameters { class Config(p: Parameters) extends Parameters {
def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f)) def this(f: (View, View, View) => PartialFunction[Any, Any]) =
this(Parameters(f))
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname) protected[config] def chain[T](site: View, tail: View, pname: Field[T]) =
p.chain(site, tail, pname)
override def toString = this.getClass.getSimpleName override def toString = this.getClass.getSimpleName
def toInstance = this def toInstance = this
} }
...@@ -82,17 +89,21 @@ private class ChainView(head: Parameters, tail: View) extends View { ...@@ -82,17 +89,21 @@ private class ChainView(head: Parameters, tail: View) extends View {
} }
private class ChainParameters(x: Parameters, y: Parameters) extends Parameters { private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname) def chain[T](site: View, tail: View, pname: Field[T]) =
x.chain(site, new ChainView(y, tail), pname)
} }
private class EmptyParameters extends Parameters { private class EmptyParameters extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site) def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
} }
private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters { private class PartialParameters(
f: (View, View, View) => PartialFunction[Any, Any])
extends Parameters {
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = { protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
val g = f(site, this, tail) val g = f(site, this, tail)
if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site) if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T])
else tail.find(pname, site)
} }
} }
......
...@@ -23,18 +23,22 @@ package vta.util.genericbundle ...@@ -23,18 +23,22 @@ package vta.util.genericbundle
import chisel3._ import chisel3._
abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle abstract class GenericParameterizedBundle[+T <: Object](val params: T)
{ extends Bundle {
override def cloneType = { override def cloneType = {
try { try {
this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type] this.getClass.getConstructors.head
.newInstance(params)
.asInstanceOf[this.type]
} catch { } catch {
case e: java.lang.IllegalArgumentException => case e: java.lang.IllegalArgumentException =>
throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " + throw new Exception(
"Unable to use GenericParameterizedBundle.cloneType on " +
this.getClass + ", probably because " + this.getClass + this.getClass + ", probably because " + this.getClass +
"() takes more than one argument. Consider overriding " + "() takes more than one argument. Consider overriding " +
"cloneType() on " + this.getClass, e) "cloneType() on " + this.getClass,
e
)
} }
} }
} }
...@@ -31,7 +31,6 @@ import vta.test._ ...@@ -31,7 +31,6 @@ import vta.test._
* These configurations are built in a mix/match form based on core * These configurations are built in a mix/match form based on core
* and shell configurations. * and shell configurations.
*/ */
class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig) class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
class DefaultF1Config extends Config(new CoreConfig ++ new F1Config) class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config) class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment