Commit 5fe61fd1 by Luis Vega Committed by Thierry Moreau

[VTA][Chisel] add scalafmt and format existing scala codebase (#3880)

* [VTA][Chisel] add scalafmt and format existing scala codebase

* change column width to 100

* add scalafmt conf file as a valid file type

* add asf header to scalafmt conf file and rerun formatter
parent f4a28c4b
......@@ -87,6 +87,7 @@ ALLOW_FILE_NAME = {
".clang-format",
".gitmodules",
"CODEOWNERS",
".scalafmt.conf",
}
# List of specific files allowed in relpath to <proj_root>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
maxColumn = 100
rewrite.rules = [SortModifiers, SortImports]
......@@ -91,7 +91,10 @@ else
lib_path = $(build_dir)/$(LIBNAME).so
endif
default: lib
default: lint lib
lint:
sbt scalafmt
lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp
......
......@@ -18,3 +18,4 @@
*/
logLevel := Level.Warn
addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
......@@ -41,7 +41,7 @@ case class AccelConfig() {
val nVals = 2
val nPtrs = 2
val regBits = 32
val ptrBits = 2*regBits
val ptrBits = 2 * regBits
}
class Accel extends Module {
......
......@@ -54,27 +54,27 @@ class Compute(implicit config: AccelConfig) extends Module {
val raddr = Reg(UInt(config.ptrBits.W))
val waddr = Reg(UInt(config.ptrBits.W))
switch (state) {
is (sIdle) {
when (io.launch) {
switch(state) {
is(sIdle) {
when(io.launch) {
state := sReadReq
}
}
is (sReadReq) {
is(sReadReq) {
state := sReadData
}
is (sReadData) {
when (io.mem.rd.valid) {
is(sReadData) {
when(io.mem.rd.valid) {
state := sWriteReq
}
}
is (sWriteReq) {
is(sWriteReq) {
state := sWriteData
}
is (sWriteData) {
when (cnt === (length - 1.U)) {
is(sWriteData) {
when(cnt === (length - 1.U)) {
state := sIdle
} .otherwise {
}.otherwise {
state := sReadReq
}
}
......@@ -83,9 +83,9 @@ class Compute(implicit config: AccelConfig) extends Module {
val last = state === sWriteData && cnt === (length - 1.U)
// cycle counter
when (state === sIdle) {
when(state === sIdle) {
cycles := 0.U
} .otherwise {
}.otherwise {
cycles := cycles + 1.U
}
......@@ -93,10 +93,10 @@ class Compute(implicit config: AccelConfig) extends Module {
io.ecnt(0).bits := cycles
// calculate next address
when (state === sIdle) {
when(state === sIdle) {
raddr := io.ptrs(0)
waddr := io.ptrs(1)
} .elsewhen (state === sWriteData) { // increment by 8-bytes
}.elsewhen(state === sWriteData) { // increment by 8-bytes
raddr := raddr + 8.U
waddr := waddr + 8.U
}
......@@ -108,7 +108,7 @@ class Compute(implicit config: AccelConfig) extends Module {
io.mem.req.addr := Mux(state === sReadReq, raddr, waddr)
// read
when (state === sReadData && io.mem.rd.valid) {
when(state === sReadData && io.mem.rd.valid) {
reg := io.mem.rd.bits + const
}
io.mem.rd.ready := state === sReadData
......@@ -118,9 +118,9 @@ class Compute(implicit config: AccelConfig) extends Module {
io.mem.wr.bits := reg
// count read/write
when (state === sIdle) {
when(state === sIdle) {
cnt := 0.U
} .elsewhen (state === sWriteData) {
}.elsewhen(state === sWriteData) {
cnt := cnt + 1.U
}
......
......@@ -59,52 +59,54 @@ class RegFile(implicit config: AccelConfig) extends Module {
val sIdle :: sRead :: Nil = Enum(2)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.host.req.valid && !io.host.req.opcode) {
switch(state) {
is(sIdle) {
when(io.host.req.valid && !io.host.req.opcode) {
state := sRead
}
}
is (sRead) {
is(sRead) {
state := sIdle
}
}
io.host.req.deq := state === sIdle & io.host.req.valid
val nTotal = config.nCtrl + config.nECnt + config.nVals + (2*config.nPtrs)
val reg = Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val nTotal = config.nCtrl + config.nECnt + config.nVals + (2 * config.nPtrs)
val reg =
Seq.fill(nTotal)(RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value))))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = config.nCtrl
val vo = eo + config.nECnt
val po = vo + config.nVals
when (io.finish) {
when(io.finish) {
reg(0) := "b_10".U
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(0).U === io.host.req.addr) {
}.elsewhen(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(0).U === io.host.req.addr) {
reg(0) := io.host.req.value
}
for (i <- 0 until config.nECnt) {
when (io.ecnt(i).valid) {
when(io.ecnt(i).valid) {
reg(eo + i) := io.ecnt(i).bits
} .elsewhen (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
}.elsewhen(state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(eo + i).U === io.host.req.addr) {
reg(eo + i) := io.host.req.value
}
}
for (i <- 0 until (config.nVals + (2*config.nPtrs))) {
when (state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
for (i <- 0 until (config.nVals + (2 * config.nPtrs))) {
when(
state === sIdle && io.host.req.valid &&
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value
}
}
val rdata = RegInit(0.U.asTypeOf(chiselTypeOf(io.host.req.value)))
when (state === sIdle && io.host.req.valid && !io.host.req.opcode) {
when(state === sIdle && io.host.req.valid && !io.host.req.opcode) {
rdata := MuxLookup(io.host.req.addr, 0.U, reg_map)
}
......@@ -118,6 +120,6 @@ class RegFile(implicit config: AccelConfig) extends Module {
}
for (i <- 0 until config.nPtrs) {
io.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
io.ptrs(i) := Cat(reg(po + (2 * i) + 1), reg(po + (2 * i)))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
maxColumn = 100
rewrite.rules = [SortModifiers, SortImports]
......@@ -102,7 +102,10 @@ else
lib_path = $(vta_dir)/$(BUILD_NAME)/$(VTA_LIBNAME).so
endif
default: lib
default: lint lib
lint:
sbt scalafmt
lib: $(lib_path)
......
......@@ -18,3 +18,4 @@
*/
logLevel := Level.Warn
addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1")
......@@ -49,7 +49,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val sIdle :: sSync :: sExe :: Nil = Enum(3)
val state = RegInit(sIdle)
val s = Seq.tabulate(2)(_ => Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
val s = Seq.tabulate(2)(_ =>
Module(new Semaphore(counterBits = 8, counterInitValue = 0)))
val loadUop = Module(new LoadUop)
val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
......@@ -62,18 +63,20 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val dec = Module(new ComputeDecode)
dec.io.inst := inst_q.io.deq.bits
val inst_type = Cat(dec.io.isFinish,
dec.io.isAlu,
dec.io.isGemm,
dec.io.isLoadAcc,
dec.io.isLoadUop).asUInt
val inst_type =
Cat(dec.io.isFinish,
dec.io.isAlu,
dec.io.isGemm,
dec.io.isLoadAcc,
dec.io.isLoadUop).asUInt
val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
val start = snext & sprev
val done =
MuxLookup(inst_type,
false.B, // default
MuxLookup(
inst_type,
false.B, // default
Array(
"h_01".U -> loadUop.io.done,
"h_02".U -> tensorAcc.io.done,
......@@ -84,21 +87,21 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
)
// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
switch(state) {
is(sIdle) {
when(start) {
when(dec.io.isSync) {
state := sSync
} .elsewhen (inst_type.orR) {
}.elsewhen(inst_type.orR) {
state := sExe
}
}
}
is (sSync) {
is(sSync) {
state := sIdle
}
is (sExe) {
when (done) {
is(sExe) {
when(done) {
state := sIdle
}
}
......@@ -109,22 +112,28 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
inst_q.io.deq.ready := (state === sExe & done) | (state === sSync)
// uop
loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
loadUop.io.start := state === sIdle & start & dec.io.isLoadUop
loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
loadUop.io.uop.idx <> Mux(dec.io.isGemm,
tensorGemm.io.uop.idx,
tensorAlu.io.uop.idx)
// acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm,
tensorGemm.io.acc.rd.idx,
tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm,
tensorGemm.io.acc.wr,
tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd
// gemm
tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
tensorGemm.io.start := state === sIdle & start & dec.io.isGemm
tensorGemm.io.inst := inst_q.io.deq.bits
tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
......@@ -136,7 +145,7 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
// alu
tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
tensorAlu.io.start := state === sIdle & start & dec.io.isAlu
tensorAlu.io.inst := inst_q.io.deq.bits
tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
......@@ -146,7 +155,9 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits
// out
io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
io.out.rd.idx <> Mux(dec.io.isGemm,
tensorGemm.io.out.rd.idx,
tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
// semaphore
......@@ -163,38 +174,45 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
when(state === sIdle && start) {
when(dec.io.isSync) {
printf("[Compute] start sync\n")
} .elsewhen (dec.io.isLoadUop) {
printf("[Compute] start load uop\n")
} .elsewhen (dec.io.isLoadAcc) {
printf("[Compute] start load acc\n")
} .elsewhen (dec.io.isGemm) {
printf("[Compute] start gemm\n")
} .elsewhen (dec.io.isAlu) {
printf("[Compute] start alu\n")
} .elsewhen (dec.io.isFinish) {
printf("[Compute] start finish\n")
}
}.elsewhen(dec.io.isLoadUop) {
printf("[Compute] start load uop\n")
}
.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] start load acc\n")
}
.elsewhen(dec.io.isGemm) {
printf("[Compute] start gemm\n")
}
.elsewhen(dec.io.isAlu) {
printf("[Compute] start alu\n")
}
.elsewhen(dec.io.isFinish) {
printf("[Compute] start finish\n")
}
}
// done
when (state === sSync) {
when(state === sSync) {
printf("[Compute] done sync\n")
}
when (state === sExe) {
when (done) {
when (dec.io.isLoadUop) {
when(state === sExe) {
when(done) {
when(dec.io.isLoadUop) {
printf("[Compute] done load uop\n")
} .elsewhen (dec.io.isLoadAcc) {
printf("[Compute] done load acc\n")
} .elsewhen (dec.io.isGemm) {
printf("[Compute] done gemm\n")
} .elsewhen (dec.io.isAlu) {
printf("[Compute] done alu\n")
} .elsewhen (dec.io.isFinish) {
printf("[Compute] done finish\n")
}
}.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] done load acc\n")
}
.elsewhen(dec.io.isGemm) {
printf("[Compute] done gemm\n")
}
.elsewhen(dec.io.isAlu) {
printf("[Compute] done alu\n")
}
.elsewhen(dec.io.isFinish) {
printf("[Compute] done finish\n")
}
}
}
}
......
......@@ -27,20 +27,23 @@ import vta.util.config._
* be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends.
*/
class CoreConfig extends Config((site, here, up) => {
case CoreKey => CoreParams(
batch = 1,
blockOut = 16,
blockIn = 16,
inpBits = 8,
wgtBits = 8,
uopBits = 32,
accBits = 32,
outBits = 8,
uopMemDepth = 2048,
inpMemDepth = 2048,
wgtMemDepth = 1024,
accMemDepth = 2048,
outMemDepth = 2048,
instQueueEntries = 512)
})
class CoreConfig
extends Config((site, here, up) => {
case CoreKey =>
CoreParams(
batch = 1,
blockOut = 16,
blockIn = 16,
inpBits = 8,
wgtBits = 8,
uopBits = 32,
accBits = 32,
outBits = 8,
uopMemDepth = 2048,
inpMemDepth = 2048,
wgtMemDepth = 1024,
accMemDepth = 2048,
outMemDepth = 2048,
instQueueEntries = 512
)
})
......@@ -24,24 +24,24 @@ import vta.util.config._
import vta.shell._
/** Core parameters */
case class CoreParams (
batch: Int = 1,
blockOut: Int = 16,
blockIn: Int = 16,
inpBits: Int = 8,
wgtBits: Int = 8,
uopBits: Int = 32,
accBits: Int = 32,
outBits: Int = 8,
uopMemDepth: Int = 512,
inpMemDepth: Int = 512,
wgtMemDepth: Int = 512,
accMemDepth: Int = 512,
outMemDepth: Int = 512,
instQueueEntries: Int = 32
)
{
require (uopBits % 8 == 0, s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
case class CoreParams(
batch: Int = 1,
blockOut: Int = 16,
blockIn: Int = 16,
inpBits: Int = 8,
wgtBits: Int = 8,
uopBits: Int = 32,
accBits: Int = 32,
outBits: Int = 8,
uopMemDepth: Int = 512,
inpMemDepth: Int = 512,
wgtMemDepth: Int = 512,
accMemDepth: Int = 512,
outMemDepth: Int = 512,
instQueueEntries: Int = 32
) {
require(uopBits % 8 == 0,
s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
}
case object CoreKey extends Field[CoreParams]
......
......@@ -133,8 +133,9 @@ class FetchDecode extends Module {
val isStore = Output(Bool())
})
val csignals =
ListLookup(io.inst,
List(N, OP_X),
ListLookup(
io.inst,
List(N, OP_X),
Array(
LUOP -> List(Y, OP_G),
LWGT -> List(Y, OP_L),
......
......@@ -38,17 +38,18 @@ import vta.shell._
* If one would like to add an event counter, then the value of nECnt must be
* changed in VCRParams together with the corresponding counting logic here.
*/
class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module {
class EventCounters(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val vp = p(ShellKey).vcrParams
val io = IO(new Bundle{
val io = IO(new Bundle {
val launch = Input(Bool())
val finish = Input(Bool())
val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
})
val cycle_cnt = RegInit(0.U(vp.regBits.W))
when (io.launch && !io.finish) {
when(io.launch && !io.finish) {
cycle_cnt := cycle_cnt + 1.U
} .otherwise {
}.otherwise {
cycle_cnt := 0.U
}
io.ecnt(0).valid := io.finish
......
......@@ -67,69 +67,70 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val xrem = Reg(chiselTypeOf(io.ins_count))
val xsize = (io.ins_count << 1.U) - 1.U
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val sIdle :: sReadCmd :: sReadLSB :: sReadMSB :: sDrain :: Nil = Enum(5)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (pulse) {
switch(state) {
is(sIdle) {
when(pulse) {
state := sReadCmd
when (xsize < xmax) {
when(xsize < xmax) {
rlen := xsize
ilen := xsize >> 1.U
xrem := 0.U
} .otherwise {
}.otherwise {
rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U
xrem := xsize - xmax
}
}
}
is (sReadCmd) {
when (io.vme_rd.cmd.ready) {
is(sReadCmd) {
when(io.vme_rd.cmd.ready) {
state := sReadLSB
}
}
is (sReadLSB) {
when (io.vme_rd.data.valid) {
is(sReadLSB) {
when(io.vme_rd.data.valid) {
state := sReadMSB
}
}
is (sReadMSB) {
when (io.vme_rd.data.valid) {
when (inst_q.io.count === ilen) {
is(sReadMSB) {
when(io.vme_rd.data.valid) {
when(inst_q.io.count === ilen) {
state := sDrain
} .otherwise {
}.otherwise {
state := sReadLSB
}
}
}
is (sDrain) {
when (inst_q.io.count === 0.U) {
when (xrem === 0.U) {
is(sDrain) {
when(inst_q.io.count === 0.U) {
when(xrem === 0.U) {
state := sIdle
} .elsewhen (xrem < xmax) {
state := sReadCmd
rlen := xrem
ilen := xrem >> 1.U
xrem := 0.U
} .otherwise {
state := sReadCmd
rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U
xrem := xrem - xmax
}
}.elsewhen(xrem < xmax) {
state := sReadCmd
rlen := xrem
ilen := xrem >> 1.U
xrem := 0.U
}
.otherwise {
state := sReadCmd
rlen := xmax - 1.U
ilen := (xmax >> 1.U) - 1.U
xrem := xrem - xmax
}
}
}
}
// read instructions from dram
when (state === sIdle) {
when(state === sIdle) {
raddr := io.ins_baddr
} .elsewhen (state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
}.elsewhen(state === sDrain && inst_q.io.count === 0.U && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
}
......@@ -143,7 +144,7 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val msb = io.vme_rd.data.bits
val inst = Cat(msb, lsb)
when (state === sReadLSB) { lsb := io.vme_rd.data.bits }
when(state === sReadLSB) { lsb := io.vme_rd.data.bits }
inst_q.io.enq.valid := io.vme_rd.data.valid & state === sReadMSB
inst_q.io.enq.bits := inst
......@@ -164,32 +165,30 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
val deq_ready =
MuxLookup(deq_sel,
false.B, // default
Array(
"h_01".U -> io.inst.ld.ready,
"h_02".U -> io.inst.st.ready,
"h_04".U -> io.inst.co.ready
)
)
false.B, // default
Array(
"h_01".U -> io.inst.ld.ready,
"h_02".U -> io.inst.st.ready,
"h_04".U -> io.inst.co.ready
))
// dequeue instruction
inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
// debug
if (debug) {
when (state === sIdle && pulse) {
when(state === sIdle && pulse) {
printf("[Fetch] Launch\n")
}
// instruction
when (inst_q.io.deq.fire()) {
when (dec.io.isLoad) {
when(inst_q.io.deq.fire()) {
when(dec.io.isLoad) {
printf("[Fetch] [instruction decode] [L] %x\n", inst_q.io.deq.bits)
}
when (dec.io.isCompute) {
when(dec.io.isCompute) {
printf("[Fetch] [instruction decode] [C] %x\n", inst_q.io.deq.bits)
}
when (dec.io.isStore) {
when(dec.io.isStore) {
printf("[Fetch] [instruction decode] [S] %x\n", inst_q.io.deq.bits)
}
}
......
......@@ -26,47 +26,46 @@ import chisel3.util._
*
* These constants are used for decoding (parsing) fields on instructions.
*/
trait ISAConstants
{
val INST_BITS = 128
trait ISAConstants {
val INST_BITS = 128
val OP_BITS = 3
val OP_BITS = 3
val M_DEP_BITS = 4
val M_ID_BITS = 2
val M_SRAM_OFFSET_BITS = 16
val M_DRAM_OFFSET_BITS = 32
val M_SIZE_BITS = 16
val M_STRIDE_BITS = 16
val M_PAD_BITS = 4
val M_DEP_BITS = 4
val M_ID_BITS = 2
val M_SRAM_OFFSET_BITS = 16
val M_DRAM_OFFSET_BITS = 32
val M_SIZE_BITS = 16
val M_STRIDE_BITS = 16
val M_PAD_BITS = 4
val C_UOP_BGN_BITS = 13
val C_UOP_END_BITS = 14
val C_ITER_BITS = 14
val C_AIDX_BITS = 11
val C_IIDX_BITS = 11
val C_WIDX_BITS = 10
val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
val C_ALU_OP_BITS = 3
val C_ALU_IMM_BITS = 16
val C_UOP_BGN_BITS = 13
val C_UOP_END_BITS = 14
val C_ITER_BITS = 14
val C_AIDX_BITS = 11
val C_IIDX_BITS = 11
val C_WIDX_BITS = 10
val C_ALU_DEC_BITS = 2 // FIXME: there should be a SHL and SHR instruction
val C_ALU_OP_BITS = 3
val C_ALU_IMM_BITS = 16
val Y = true.B
val N = false.B
val Y = true.B
val N = false.B
val OP_L = 0.asUInt(OP_BITS.W)
val OP_S = 1.asUInt(OP_BITS.W)
val OP_G = 2.asUInt(OP_BITS.W)
val OP_F = 3.asUInt(OP_BITS.W)
val OP_A = 4.asUInt(OP_BITS.W)
val OP_X = 5.asUInt(OP_BITS.W)
val OP_L = 0.asUInt(OP_BITS.W)
val OP_S = 1.asUInt(OP_BITS.W)
val OP_G = 2.asUInt(OP_BITS.W)
val OP_F = 3.asUInt(OP_BITS.W)
val OP_A = 4.asUInt(OP_BITS.W)
val OP_X = 5.asUInt(OP_BITS.W)
val ALU_OP_NUM = 5
val ALU_OP = Enum(ALU_OP_NUM)
val ALU_OP_NUM = 5
val ALU_OP = Enum(ALU_OP_NUM)
val M_ID_U = 0.asUInt(M_ID_BITS.W)
val M_ID_W = 1.asUInt(M_ID_BITS.W)
val M_ID_I = 2.asUInt(M_ID_BITS.W)
val M_ID_A = 3.asUInt(M_ID_BITS.W)
val M_ID_U = 0.asUInt(M_ID_BITS.W)
val M_ID_W = 1.asUInt(M_ID_BITS.W)
val M_ID_I = 2.asUInt(M_ID_BITS.W)
val M_ID_A = 3.asUInt(M_ID_BITS.W)
}
/** ISA.
......@@ -79,15 +78,37 @@ trait ISAConstants
* TODO: Add VXOR to clear accumulator
*/
object ISA {
def LUOP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
def LWGT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
def LINP = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
def LACC = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
def SOUT = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
def GEMM = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
def VMIN = BitPat("b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VMAX = BitPat("b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VADD = BitPat("b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VSHX = BitPat("b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def FNSH = BitPat("b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
def LUOP =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_0????000")
def LWGT =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????0_1????000")
def LINP =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_0????000")
def LACC =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_???????1_1????000")
def SOUT =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????001")
def GEMM =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????010")
def VMIN =
BitPat(
"b_????????_????????_??00????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VMAX =
BitPat(
"b_????????_????????_??01????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VADD =
BitPat(
"b_????????_????????_??10????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def VSHX =
BitPat(
"b_????????_????????_??11????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????100")
def FNSH =
BitPat(
"b_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_????????_?????011")
}
......@@ -54,27 +54,28 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
val tensorType = Seq("inp", "wgt")
val tensorDec = Seq(dec.io.isInput, dec.io.isWeight)
val tensorLoad = Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
val tensorLoad =
Seq.tabulate(2)(i => Module(new TensorLoad(tensorType = tensorType(i))))
val start = inst_q.io.deq.valid & Mux(dec.io.pop_next, s.io.sready, true.B)
val done = Mux(dec.io.isInput, tensorLoad(0).io.done, tensorLoad(1).io.done)
// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
switch(state) {
is(sIdle) {
when(start) {
when(dec.io.isSync) {
state := sSync
} .elsewhen (dec.io.isInput || dec.io.isWeight) {
}.elsewhen(dec.io.isInput || dec.io.isWeight) {
state := sExe
}
}
}
is (sSync) {
is(sSync) {
state := sIdle
}
is (sExe) {
when (done) {
is(sExe) {
when(done) {
state := sIdle
}
}
......@@ -105,24 +106,25 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
when(state === sIdle && start) {
when(dec.io.isSync) {
printf("[Load] start sync\n")
} .elsewhen (dec.io.isInput) {
printf("[Load] start input\n")
} .elsewhen (dec.io.isWeight) {
printf("[Load] start weight\n")
}
}.elsewhen(dec.io.isInput) {
printf("[Load] start input\n")
}
.elsewhen(dec.io.isWeight) {
printf("[Load] start weight\n")
}
}
// done
when (state === sSync) {
when(state === sSync) {
printf("[Load] done sync\n")
}
when (state === sExe) {
when (done) {
when (dec.io.isInput) {
when(state === sExe) {
when(done) {
when(dec.io.isInput) {
printf("[Load] done input\n")
} .elsewhen (dec.io.isWeight) {
}.elsewhen(dec.io.isWeight) {
printf("[Load] done weight\n")
}
}
......
......@@ -77,9 +77,9 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val xcnt = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_rd.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
val xsize = (dec.xsize >> log2Ceil(numUop)) + dec.xsize(0) + (dec.sram_offset % 2.U) - 1.U
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val offsetIsEven = (dec.sram_offset % 2.U) === 0.U
val sizeIsEven = (dec.xsize % 2.U) === 0.U
......@@ -88,38 +88,39 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (io.start) {
switch(state) {
is(sIdle) {
when(io.start) {
state := sReadCmd
when (xsize < xmax) {
when(xsize < xmax) {
xlen := xsize
xrem := 0.U
} .otherwise {
}.otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
}
is (sReadCmd) {
when (io.vme_rd.cmd.ready) {
is(sReadCmd) {
when(io.vme_rd.cmd.ready) {
state := sReadData
}
}
is (sReadData) {
when (io.vme_rd.data.valid) {
is(sReadData) {
when(io.vme_rd.data.valid) {
when(xcnt === xlen) {
when (xrem === 0.U) {
when(xrem === 0.U) {
state := sIdle
} .elsewhen (xrem < xmax) {
state := sReadCmd
xlen := xrem
xrem := 0.U
} .otherwise {
state := sReadCmd
xlen := xmax - 1.U
xrem := xrem - xmax
}
}.elsewhen(xrem < xmax) {
state := sReadCmd
xlen := xrem
xrem := 0.U
}
.otherwise {
state := sReadCmd
xlen := xmax - 1.U
xrem := xrem - xmax
}
}
}
}
......@@ -127,13 +128,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
// read-from-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
when (state === sIdle) {
when (offsetIsEven) {
when(state === sIdle) {
when(offsetIsEven) {
raddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))
} .otherwise {
raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(uopBytes)))) - uopBytes.U
}.otherwise {
raddr := (io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
uopBytes)))) - uopBytes.U
}
} .elsewhen (state === sReadData && xcnt === xlen && xrem =/= 0.U) {
}.elsewhen(state === sReadData && xcnt === xlen && xrem =/= 0.U) {
raddr := raddr + xmax_bytes
}
......@@ -143,16 +145,16 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
io.vme_rd.data.ready := state === sReadData
when (state =/= sReadData) {
when(state =/= sReadData) {
xcnt := 0.U
} .elsewhen (io.vme_rd.data.fire()) {
}.elsewhen(io.vme_rd.data.fire()) {
xcnt := xcnt + 1.U
}
val waddr = Reg(UInt(log2Ceil(uopDepth).W))
when (state === sIdle) {
when(state === sIdle) {
waddr := dec.sram_offset >> log2Ceil(numUop)
} .elsewhen (io.vme_rd.data.fire()) {
}.elsewhen(io.vme_rd.data.fire()) {
waddr := waddr + 1.U
}
......@@ -160,36 +162,37 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mem = SyncReadMem(uopDepth, chiselTypeOf(wdata))
val wmask = Reg(Vec(numUop, Bool()))
when (offsetIsEven) {
when (sizeIsEven) {
when(offsetIsEven) {
when(sizeIsEven) {
wmask := "b_11".U.asTypeOf(wmask)
} .elsewhen (io.vme_rd.cmd.fire()) {
when (dec.xsize === 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
} .otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}.elsewhen(io.vme_rd.cmd.fire()) {
when(dec.xsize === 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
}.otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
}
} .elsewhen (io.vme_rd.data.fire()) {
when (xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
} .otherwise {
wmask := "b_11".U.asTypeOf(wmask)
.elsewhen(io.vme_rd.data.fire()) {
when(xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
}.otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
}
}
} .otherwise {
when (io.vme_rd.cmd.fire()) {
}.otherwise {
when(io.vme_rd.cmd.fire()) {
wmask := "b_10".U.asTypeOf(wmask)
} .elsewhen (io.vme_rd.data.fire()) {
when (sizeIsEven && xcnt === xlen - 1.U) {
}.elsewhen(io.vme_rd.data.fire()) {
when(sizeIsEven && xcnt === xlen - 1.U) {
wmask := "b_01".U.asTypeOf(wmask)
} .otherwise {
}.otherwise {
wmask := "b_11".U.asTypeOf(wmask)
}
}
}
wdata := io.vme_rd.data.bits.asTypeOf(wdata)
when (io.vme_rd.data.fire()) {
when(io.vme_rd.data.fire()) {
mem.write(waddr, wdata, wmask)
}
......@@ -209,7 +212,7 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug
if (debug) {
when (io.vme_rd.cmd.fire()) {
when(io.vme_rd.cmd.fire()) {
printf("[LoadUop] cmd addr:%x len:%x rem:%x\n", raddr, xlen, xrem)
}
}
......
......@@ -29,14 +29,17 @@ import chisel3.util._
* depending on the push and pop fields on instructions to prevent RAW and WAR
* hazards.
*/
class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1)
extends Module {
val io = IO(new Bundle {
val spost = Input(Bool())
val swait = Input(Bool())
val sready = Output(Bool())
})
val cnt = RegInit(counterInitValue.U(counterBits.W))
when (io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) { cnt := cnt + 1.U }
when (!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
when(io.spost && !io.swait && cnt =/= ((1 << counterBits) - 1).asUInt) {
cnt := cnt + 1.U
}
when(!io.spost && io.swait && cnt =/= 0.U) { cnt := cnt - 1.U }
io.sready := cnt =/= 0.U
}
......@@ -55,21 +55,21 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
val done = tensorStore.io.done
// control
switch (state) {
is (sIdle) {
when (start) {
when (dec.io.isSync) {
switch(state) {
is(sIdle) {
when(start) {
when(dec.io.isSync) {
state := sSync
} .elsewhen (dec.io.isStore) {
}.elsewhen(dec.io.isStore) {
state := sExe
}
}
}
is (sSync) {
is(sSync) {
state := sIdle
}
is (sExe) {
when (done) {
is(sExe) {
when(done) {
state := sIdle
}
}
......@@ -94,19 +94,19 @@ class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
// debug
if (debug) {
// start
when (state === sIdle && start) {
when (dec.io.isSync) {
when(state === sIdle && start) {
when(dec.io.isSync) {
printf("[Store] start sync\n")
} .elsewhen (dec.io.isStore) {
}.elsewhen(dec.io.isStore) {
printf("[Store] start\n")
}
}
// done
when (state === sSync) {
when(state === sSync) {
printf("[Store] done sync\n")
}
when (state === sExe) {
when (done) {
when(state === sExe) {
when(done) {
printf("[Store] done\n")
}
}
......
......@@ -116,7 +116,8 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil = Enum(6)
val sIdle :: sReadUop :: sComputeIdx :: sReadTensorA :: sReadTensorB :: sExe :: Nil =
Enum(6)
val state = RegInit(sIdle)
val alu = Module(new AluVector)
val dec = io.inst.asTypeOf(new AluDecode)
......@@ -132,81 +133,86 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
val src_i = Reg(chiselTypeOf(dec.uop_end))
val done =
state === sExe &
alu.io.out.data.valid &
(cnt_o === dec.lp_0 - 1.U) &
(cnt_i === dec.lp_1 - 1.U) &
(uop_idx === uop_end - 1.U)
switch (state) {
is (sIdle) {
when (io.start) {
alu.io.out.data.valid &
(cnt_o === dec.lp_0 - 1.U) &
(cnt_i === dec.lp_1 - 1.U) &
(uop_idx === uop_end - 1.U)
switch(state) {
is(sIdle) {
when(io.start) {
state := sReadUop
}
}
is (sReadUop) {
is(sReadUop) {
state := sComputeIdx
}
is (sComputeIdx) {
is(sComputeIdx) {
state := sReadTensorA
}
is (sReadTensorA) {
is(sReadTensorA) {
state := sReadTensorB
}
is (sReadTensorB) {
is(sReadTensorB) {
state := sExe
}
is (sExe) {
when (alu.io.out.data.valid) {
when ((cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
is(sExe) {
when(alu.io.out.data.valid) {
when(
(cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
state := sIdle
} .otherwise {
}.otherwise {
state := sReadUop
}
}
}
}
when (state === sIdle ||
(state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U)) {
when(
state === sIdle ||
(state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
} .elsewhen (state === sExe && alu.io.out.data.valid) {
}.elsewhen(state === sExe && alu.io.out.data.valid) {
uop_idx := uop_idx + 1.U
}
when (state === sIdle) {
when(state === sIdle) {
cnt_o := 0.U
dst_o := 0.U
src_o := 0.U
} .elsewhen (state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
}.elsewhen(
state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
dst_o := dst_o + dec.dst_0
src_o := src_o + dec.src_0
}
when (state === sIdle) {
when(state === sIdle) {
cnt_i := 0.U
dst_i := 0.U
src_i := 0.U
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
dst_i := dst_o
src_i := src_o
} .elsewhen (state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
dst_i := dst_i + dec.dst_1
src_i := src_i + dec.src_1
}
}.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
dst_i := dst_o
src_i := src_o
}
.elsewhen(
state === sExe &&
alu.io.out.data.valid &&
uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
dst_i := dst_i + dec.dst_1
src_i := src_i + dec.src_1
}
when (state === sComputeIdx && io.uop.data.valid) {
when(state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i
uop_src := io.uop.data.bits.u1 + src_i
}
......@@ -222,17 +228,25 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
// imm
val tensorImm = Wire(new TensorClientData(tensorType = "acc"))
tensorImm.data.valid := state === sReadTensorB
tensorImm.data.bits.foreach { b => b.foreach { c => c := dec.alu_imm } }
tensorImm.data.bits.foreach { b =>
b.foreach { c =>
c := dec.alu_imm
}
}
// alu
val isSHR = dec.alu_op === ALU_OP(3)
val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS-1)
val neg_shift = isSHR & dec.alu_imm(C_ALU_IMM_BITS - 1)
val fixme_alu_op = Cat(neg_shift, Mux(neg_shift, 0.U, dec.alu_op))
alu.io.opcode := fixme_alu_op
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, tensorImm.data.valid, io.acc.rd.data.valid & state === sExe)
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, tensorImm.data.bits, io.acc.rd.data.bits)
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
tensorImm.data.valid,
io.acc.rd.data.valid & state === sExe)
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
tensorImm.data.bits,
io.acc.rd.data.bits)
// acc_o
io.acc.wr.valid := alu.io.acc_y.data.valid
......@@ -249,47 +263,51 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
if (debug) {
when (state === sReadUop) {
when(state === sReadUop) {
printf("[TensorAlu] [uop] idx:%x\n", uop_idx)
}
when (state === sReadTensorA) {
when(state === sReadTensorA) {
printf("[TensorAlu] [uop] dst:%x src:%x\n", uop_dst, uop_src)
}
when (state === sIdle && io.start) {
when(state === sIdle && io.start) {
printf(p"[TensorAlu] decode:$dec\n")
}
alu.io.acc_a.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.acc_a.data.valid) {
printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.acc_a.data.valid) {
printf("[TensorAlu] [a] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.acc_b.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.acc_b.data.valid) {
printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.acc_b.data.valid) {
printf("[TensorAlu] [b] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.acc_y.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.acc_y.data.valid) {
printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.acc_y.data.valid) {
printf("[TensorAlu] [y] i:%x val:%x\n", i.U, elem)
}
}
}
alu.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (alu.io.out.data.valid) {
printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(alu.io.out.data.valid) {
printf("[TensorAlu] [out] i:%x val:%x\n", i.U, elem)
}
}
}
}
......
......@@ -62,8 +62,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
}
/** Pipelined DotProduct based on MAC and PipeAdder */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module {
val errorMsg = s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
extends Module {
val errorMsg =
s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
require(size >= 2 && isPow2(size), errorMsg)
val b = aBits + bBits
val outBits = b + log2Ceil(size) + 1
......@@ -72,12 +74,15 @@ class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module
val b = Input(Vec(size, SInt(bBits.W)))
val y = Output(SInt(outBits.W))
})
val s = Seq.tabulate(log2Ceil(size + 1))(i => pow(2, log2Ceil(size) - i).toInt) // # of total layers
val s = Seq.tabulate(log2Ceil(size + 1))(i =>
pow(2, log2Ceil(size) - i).toInt) // # of total layers
val p = log2Ceil(size / 2) + 1 // # of adder layers
val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
val a = Seq.tabulate(p)(i =>
Seq.fill(s(i + 1))(Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1))))
) // # adders within each layer
val a = Seq.tabulate(p)(
i =>
Seq.fill(s(i + 1))(Module(new PipeAdder(
aBits = (b + i + 1),
bBits = (b + i + 1))))) // # adders within each layer
// Vector MACs
for (i <- 0 until s(0)) {
......@@ -111,7 +116,7 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
val inpBits = p(CoreKey).inpBits
val wgtBits = p(CoreKey).wgtBits
val outBits = p(CoreKey).outBits
val io = IO(new Bundle{
val io = IO(new Bundle {
val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
val inp = new TensorMasterData(tensorType = "inp")
val wgt = new TensorMasterData(tensorType = "wgt")
......@@ -119,8 +124,10 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
val acc_o = new TensorClientData(tensorType = "acc")
val out = new TensorClientData(tensorType = "out")
})
val dot = Seq.fill(size)(Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
val dot = Seq.fill(size)(
Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
val acc = Seq.fill(size)(
Module(new Pipe(UInt(accBits.W), latency = log2Ceil(size) + 1)))
val add = Seq.fill(size)(Wire(SInt(accBits.W)))
val vld = Wire(Vec(size, Bool()))
......@@ -149,7 +156,8 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
* Also, the TensorGemm uses the reset field in the Gemm instruction to
* clear or zero-out the acc-scratchpad locations based on the micro-ops.
*/
class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val io = IO(new Bundle {
val start = Input(Bool())
val done = Output(Bool())
......@@ -160,7 +168,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
val acc = new TensorMaster(tensorType = "acc")
val out = new TensorMaster(tensorType = "out")
})
val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil = Enum(6)
val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
Enum(6)
val state = RegInit(sIdle)
val mvc = Module(new MatrixVectorMultiplication)
val dec = io.inst.asTypeOf(new GemmDecode)
......@@ -181,99 +190,103 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
val inflight = Reg(UInt(pBits.W))
val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = pBits))
val done = inflight === 0.U &
((state === sExe &
cnt_o === dec.lp_0 - 1.U &
cnt_i === dec.lp_1 - 1.U &
uop_idx === uop_end - 1.U &
inflight === 0.U) |
state === sWait)
switch (state) {
is (sIdle) {
when (io.start) {
((state === sExe &
cnt_o === dec.lp_0 - 1.U &
cnt_i === dec.lp_1 - 1.U &
uop_idx === uop_end - 1.U &
inflight === 0.U) |
state === sWait)
switch(state) {
is(sIdle) {
when(io.start) {
state := sReadUop
}
}
is (sReadUop) {
is(sReadUop) {
state := sComputeIdx
}
is (sComputeIdx) {
is(sComputeIdx) {
state := sReadTensor
}
is (sReadTensor) {
is(sReadTensor) {
state := sExe
}
is (sExe) {
when ((cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
when (inflight =/= 0.U) {
is(sExe) {
when(
(cnt_o === dec.lp_0 - 1.U) &&
(cnt_i === dec.lp_1 - 1.U) &&
(uop_idx === uop_end - 1.U)) {
when(inflight =/= 0.U) {
state := sWait
} .otherwise {
}.otherwise {
state := sIdle
}
} .otherwise {
}.otherwise {
state := sReadUop
}
}
is (sWait) {
when (inflight === 0.U) {
is(sWait) {
when(inflight === 0.U) {
state := sIdle
}
}
}
when (state === sIdle) {
when(state === sIdle) {
inflight := 0.U
} .elsewhen (!dec.reset) {
when (state === sReadTensor) { // issue a tensor
}.elsewhen(!dec.reset) {
when(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U
} .elsewhen (mvc.io.acc_o.data.valid) { // commit a tensor
}.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
inflight := inflight - 1.U
}
}
when (state === sIdle ||
(state === sExe &&
uop_idx === uop_end - 1.U)) {
when(
state === sIdle ||
(state === sExe &&
uop_idx === uop_end - 1.U)) {
uop_idx := dec.uop_begin
} .elsewhen (state === sExe) {
}.elsewhen(state === sExe) {
uop_idx := uop_idx + 1.U
}
when (state === sIdle) {
when(state === sIdle) {
cnt_o := 0.U
acc_o := 0.U
inp_o := 0.U
wgt_o := 0.U
} .elsewhen (state === sExe &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
}.elsewhen(
state === sExe &&
uop_idx === uop_end - 1.U &&
cnt_i === dec.lp_1 - 1.U) {
cnt_o := cnt_o + 1.U
acc_o := acc_o + dec.acc_0
inp_o := inp_o + dec.inp_0
wgt_o := wgt_o + dec.wgt_0
}
when (state === sIdle) {
when(state === sIdle) {
cnt_i := 0.U
acc_i := 0.U
inp_i := 0.U
wgt_i := 0.U
} .elsewhen (state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
acc_i := acc_o
inp_i := inp_o
wgt_i := wgt_o
} .elsewhen (state === sExe &&
uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
acc_i := acc_i + dec.acc_1
inp_i := inp_i + dec.inp_1
wgt_i := wgt_i + dec.wgt_1
}
}.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U
acc_i := acc_o
inp_i := inp_o
wgt_i := wgt_o
}
.elsewhen(state === sExe &&
uop_idx === uop_end - 1.U) {
cnt_i := cnt_i + 1.U
acc_i := acc_i + dec.acc_1
inp_i := inp_i + dec.inp_1
wgt_i := wgt_i + dec.wgt_1
}
when (state === sComputeIdx && io.uop.data.valid) {
when(state === sComputeIdx && io.uop.data.valid) {
uop_acc := io.uop.data.bits.u0 + acc_i
uop_inp := io.uop.data.bits.u1 + inp_i
uop_wgt := io.uop.data.bits.u2 + wgt_i
......@@ -307,7 +320,9 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
mvc.io.acc_i.data <> io.acc.rd.data
// acc_o
io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, true.B, wrpipe.io.deq.valid)
io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset,
true.B,
wrpipe.io.deq.valid)
io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
......@@ -320,47 +335,55 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module
io.done := done
if (debug) {
when (state === sReadUop && ~dec.reset) {
when(state === sReadUop && ~dec.reset) {
printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
}
when (state === sReadTensor && ~dec.reset) {
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
when(state === sReadTensor && ~dec.reset) {
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n",
uop_acc,
uop_inp,
uop_wgt)
}
io.inp.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
when (io.inp.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
}
io.inp.rd.data.bits.zipWithIndex.foreach {
case (r, i) =>
when(io.inp.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
}
}
io.wgt.rd.data.bits.zipWithIndex.foreach { case(r, i) =>
when (io.wgt.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
}
io.wgt.rd.data.bits.zipWithIndex.foreach {
case (r, i) =>
when(io.wgt.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
}
}
io.acc.rd.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (io.acc.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(io.acc.rd.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
}
}
}
mvc.io.acc_o.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (mvc.io.acc_o.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(mvc.io.acc_o.data.valid && ~dec.reset) {
printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
}
}
}
mvc.io.out.data.bits.foreach { tensor =>
tensor.zipWithIndex.foreach { case(elem, i) =>
when (mvc.io.out.data.valid && ~dec.reset) {
printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
}
tensor.zipWithIndex.foreach {
case (elem, i) =>
when(mvc.io.out.data.valid && ~dec.reset) {
printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
}
}
}
}
......
......@@ -32,8 +32,9 @@ import vta.shell._
* managed by TensorPadCtrl. The TensorDataCtrl is in charge of
* handling the way tensors are stored on the scratchpads.
*/
class TensorLoad(tensorType: String = "none", debug: Boolean = false)
(implicit p: Parameters) extends Module {
class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
......@@ -48,7 +49,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
val strideFactor = tp.tensorLength * tp.tensorWidth
val dec = io.inst.asTypeOf(new MemDecode)
val dataCtrl = Module(new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
val dataCtrl = Module(
new TensorDataCtrl(tensorType, sizeFactor, strideFactor))
val dataCtrlDone = RegInit(false.B)
val yPadCtrl0 = Module(new TensorPadCtrl(padType = "YPad0", sizeFactor))
val yPadCtrl1 = Module(new TensorPadCtrl(padType = "YPad1", sizeFactor))
......@@ -58,81 +60,85 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
val tag = Reg(UInt(log2Ceil(tp.numMemBlock).W))
val set = Reg(UInt(log2Ceil(tp.tensorLength).W))
val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil = Enum(7)
val sIdle :: sYPad0 :: sXPad0 :: sReadCmd :: sReadData :: sXPad1 :: sYPad1 :: Nil =
Enum(7)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (io.start) {
when (dec.ypad_0 =/= 0.U) {
switch(state) {
is(sIdle) {
when(io.start) {
when(dec.ypad_0 =/= 0.U) {
state := sYPad0
} .elsewhen (dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
state := sReadCmd
}
}.elsewhen(dec.xpad_0 =/= 0.U) {
state := sXPad0
}
.otherwise {
state := sReadCmd
}
}
}
is (sYPad0) {
when (yPadCtrl0.io.done) {
when (dec.xpad_0 =/= 0.U) {
is(sYPad0) {
when(yPadCtrl0.io.done) {
when(dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
}.otherwise {
state := sReadCmd
}
}
}
is (sXPad0) {
when (xPadCtrl0.io.done) {
is(sXPad0) {
when(xPadCtrl0.io.done) {
state := sReadCmd
}
}
is (sReadCmd) {
when (io.vme_rd.cmd.ready) {
is(sReadCmd) {
when(io.vme_rd.cmd.ready) {
state := sReadData
}
}
is (sReadData) {
when (io.vme_rd.data.valid) {
when (dataCtrl.io.done) {
when (dec.xpad_1 =/= 0.U) {
is(sReadData) {
when(io.vme_rd.data.valid) {
when(dataCtrl.io.done) {
when(dec.xpad_1 =/= 0.U) {
state := sXPad1
} .elsewhen (dec.ypad_1 =/= 0.U) {
state := sYPad1
} .otherwise {
state := sIdle
}
} .elsewhen (dataCtrl.io.stride || dataCtrl.io.split) {
when (dec.xpad_1 =/= 0.U) {
}.elsewhen(dec.ypad_1 =/= 0.U) {
state := sYPad1
}
.otherwise {
state := sIdle
}
}.elsewhen(dataCtrl.io.stride || dataCtrl.io.split) {
when(dec.xpad_1 =/= 0.U) {
state := sXPad1
} .elsewhen (dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
}.elsewhen(dec.xpad_0 =/= 0.U) {
state := sXPad0
}
.otherwise {
state := sReadCmd
}
}
}
}
}
is (sXPad1) {
when (xPadCtrl1.io.done) {
when (dataCtrlDone) {
when (dec.ypad_1 =/= 0.U) {
is(sXPad1) {
when(xPadCtrl1.io.done) {
when(dataCtrlDone) {
when(dec.ypad_1 =/= 0.U) {
state := sYPad1
} .otherwise {
}.otherwise {
state := sIdle
}
} .otherwise {
when (dec.xpad_0 =/= 0.U) {
}.otherwise {
when(dec.xpad_0 =/= 0.U) {
state := sXPad0
} .otherwise {
}.otherwise {
state := sReadCmd
}
}
}
}
is (sYPad1) {
when (yPadCtrl1.io.done && dataCtrlDone) {
is(sYPad1) {
when(yPadCtrl1.io.done && dataCtrlDone) {
state := sIdle
}
}
......@@ -146,9 +152,9 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
dataCtrl.io.xupdate := io.vme_rd.data.fire()
dataCtrl.io.yupdate := io.vme_rd.data.fire()
when (state === sIdle) {
when(state === sIdle) {
dataCtrlDone := false.B
} .elsewhen (io.vme_rd.data.fire() && dataCtrl.io.done) {
}.elsewhen(io.vme_rd.data.fire() && dataCtrl.io.done) {
dataCtrlDone := true.B
}
......@@ -156,18 +162,19 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
yPadCtrl0.io.start := dec.ypad_0 =/= 0.U & state === sIdle & io.start
yPadCtrl1.io.start := dec.ypad_1 =/= 0.U &
((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
((io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone))
xPadCtrl0.io.start := dec.xpad_0 =/= 0.U &
((state === sIdle & io.start) |
(state === sYPad0 & yPadCtrl0.io.done) |
(io.vme_rd.data.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
((state === sIdle & io.start) |
(state === sYPad0 & yPadCtrl0.io.done) |
(io.vme_rd.data
.fire() & ~dataCtrlDone & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 === 0.U) |
(state === sXPad1 & xPadCtrl1.io.done & ~dataCtrlDone))
xPadCtrl1.io.start := dec.xpad_1 =/= 0.U & io.vme_rd.data.fire() &
((dataCtrl.io.done) |
(~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
((dataCtrl.io.done) |
(~dataCtrl.io.done & (dataCtrl.io.stride | dataCtrl.io.split) & dec.xpad_1 =/= 0.U))
yPadCtrl0.io.inst := io.inst
yPadCtrl1.io.inst := io.inst
......@@ -183,39 +190,49 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
// write-to-sram
val isZeroPad = state === sYPad0 |
state === sXPad0 |
state === sXPad1 |
state === sYPad1
state === sXPad0 |
state === sXPad1 |
state === sYPad1
when (state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
when(state === sIdle || state === sReadCmd || tag === (tp.numMemBlock - 1).U) {
tag := 0.U
} .elsewhen (io.vme_rd.data.fire() || isZeroPad) {
}.elsewhen(io.vme_rd.data.fire() || isZeroPad) {
tag := tag + 1.U
}
when (state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
when(
state === sIdle || dataCtrlDone || (set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U)) {
set := 0.U
} .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
}.elsewhen(
(io.vme_rd.data.fire() || isZeroPad) && tag === (tp.numMemBlock - 1).U) {
set := set + 1.U
}
val waddr_cur = Reg(UInt(tp.memAddrBits.W))
val waddr_nxt = Reg(UInt(tp.memAddrBits.W))
when (state === sIdle) {
when(state === sIdle) {
waddr_cur := dec.sram_offset
waddr_nxt := dec.sram_offset
} .elsewhen ((io.vme_rd.data.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
waddr_cur := waddr_cur + 1.U
} .elsewhen (dataCtrl.io.stride) {
waddr_cur := waddr_nxt + dec.xsize
waddr_nxt := waddr_nxt + dec.xsize
}
}.elsewhen((io.vme_rd.data
.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
waddr_cur := waddr_cur + 1.U
}
.elsewhen(dataCtrl.io.stride) {
waddr_cur := waddr_nxt + dec.xsize
waddr_nxt := waddr_nxt + dec.xsize
}
val tensorFile = Seq.fill(tp.tensorLength) { SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
val tensorFile = Seq.fill(tp.tensorLength) {
SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
}
val wmask = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, Bool())) }
val wdata = Seq.fill(tp.tensorLength) { Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) }
val wdata = Seq.fill(tp.tensorLength) {
Wire(Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
}
val no_mask = Wire(Vec(tp.numMemBlock, Bool()))
no_mask.foreach { m => m := true.B }
no_mask.foreach { m =>
m := true.B
}
for (i <- 0 until tp.tensorLength) {
for (j <- 0 until tp.numMemBlock) {
......@@ -223,11 +240,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
wdata(i)(j) := Mux(isZeroPad, 0.U, io.vme_rd.data.bits)
}
val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
val muxWen = Mux(state === sIdle, io.tensor.wr.valid, (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
val muxWen =
Mux(state === sIdle,
io.tensor.wr.valid,
(io.vme_rd.data.fire() | isZeroPad) & set === i.U)
val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
val muxWdata = Mux(state === sIdle, tdata, wdata(i))
val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
when (muxWen) {
when(muxWen) {
tensorFile(i).write(muxWaddr, muxWdata, muxWmask)
}
}
......@@ -236,13 +256,16 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
val rvalid = RegNext(io.tensor.rd.idx.valid)
io.tensor.rd.data.valid := rvalid
val rdata = tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
rdata.zipWithIndex.foreach { case(r, i) =>
io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
val rdata =
tensorFile.map(_.read(io.tensor.rd.idx.bits, io.tensor.rd.idx.valid))
rdata.zipWithIndex.foreach {
case (r, i) =>
io.tensor.rd.data.bits(i) := r.asUInt.asTypeOf(io.tensor.rd.data.bits(i))
}
// done
val done_no_pad = io.vme_rd.data.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
val done_no_pad = io.vme_rd.data
.fire() & dataCtrl.io.done & dec.xpad_1 === 0.U & dec.ypad_1 === 0.U
val done_x_pad = state === sXPad1 & xPadCtrl1.io.done & dataCtrlDone & dec.ypad_1 === 0.U
val done_y_pad = state === sYPad1 & dataCtrlDone & yPadCtrl1.io.done
io.done := done_no_pad | done_x_pad | done_y_pad
......@@ -250,28 +273,34 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)
// debug
if (debug) {
if (tensorType == "inp") {
when (io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [inp] cmd addr:%x len:%x\n",
dataCtrl.io.addr,
dataCtrl.io.len)
}
when (state === sYPad0) {
when(state === sYPad0) {
printf("[TensorLoad] [inp] sYPad0\n")
}
when (state === sYPad1) {
when(state === sYPad1) {
printf("[TensorLoad] [inp] sYPad1\n")
}
when (state === sXPad0) {
when(state === sXPad0) {
printf("[TensorLoad] [inp] sXPad0\n")
}
when (state === sXPad1) {
when(state === sXPad1) {
printf("[TensorLoad] [inp] sXPad1\n")
}
} else if (tensorType == "wgt") {
when (io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n",
dataCtrl.io.addr,
dataCtrl.io.len)
}
} else if (tensorType == "acc") {
when (io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", dataCtrl.io.addr, dataCtrl.io.len)
when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
dataCtrl.io.addr,
dataCtrl.io.len)
}
}
}
......
......@@ -28,8 +28,9 @@ import vta.shell._
*
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/
class TensorStore(tensorType: String = "none", debug: Boolean = false)
(implicit p: Parameters) extends Module {
class TensorStore(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters)
extends Module {
val tp = new TensorParams(tensorType)
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
......@@ -53,9 +54,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
val xcnt = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xlen = Reg(chiselTypeOf(io.vme_wr.cmd.bits.len))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(tensorLength*numMemBlock)) - 1.U
val xsize = (dec.xsize << log2Ceil(tensorLength * numMemBlock)) - 1.U
val xmax = (1 << mp.lenBits).U
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val ycnt = Reg(chiselTypeOf(dec.ysize))
val ysize = dec.ysize
val tag = Reg(UInt(8.W))
......@@ -65,132 +66,147 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
val state = RegInit(sIdle)
// control
switch (state) {
is (sIdle) {
when (io.start) {
switch(state) {
is(sIdle) {
when(io.start) {
state := sWriteCmd
when (xsize < xmax) {
when(xsize < xmax) {
xlen := xsize
xrem := 0.U
} .otherwise {
}.otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
}
is (sWriteCmd) {
when (io.vme_wr.cmd.ready) {
is(sWriteCmd) {
when(io.vme_wr.cmd.ready) {
state := sWriteData
}
}
is (sWriteData) {
when (io.vme_wr.data.ready) {
when (xcnt === xlen) {
is(sWriteData) {
when(io.vme_wr.data.ready) {
when(xcnt === xlen) {
state := sWriteAck
} .elsewhen (tag === (numMemBlock - 1).U) {
}.elsewhen(tag === (numMemBlock - 1).U) {
state := sReadMem
}
}
}
is (sReadMem) {
is(sReadMem) {
state := sWriteData
}
is (sWriteAck) {
when (io.vme_wr.ack) {
when (xrem === 0.U) {
when (ycnt === ysize - 1.U) {
is(sWriteAck) {
when(io.vme_wr.ack) {
when(xrem === 0.U) {
when(ycnt === ysize - 1.U) {
state := sIdle
} .otherwise {
}.otherwise {
state := sWriteCmd
when (xsize < xmax) {
when(xsize < xmax) {
xlen := xsize
xrem := 0.U
} .otherwise {
}.otherwise {
xlen := xmax - 1.U
xrem := xsize - xmax
}
}
} .elsewhen (xrem < xmax) {
state := sWriteCmd
xlen := xrem
xrem := 0.U
} .otherwise {
state := sWriteCmd
xlen := xmax - 1.U
xrem := xrem - xmax
}
}.elsewhen(xrem < xmax) {
state := sWriteCmd
xlen := xrem
xrem := 0.U
}
.otherwise {
state := sWriteCmd
xlen := xmax - 1.U
xrem := xrem - xmax
}
}
}
}
// write-to-sram
val tensorFile = Seq.fill(tensorLength) { SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W))) }
val tensorFile = Seq.fill(tensorLength) {
SyncReadMem(memDepth, Vec(numMemBlock, UInt(memBlockBits.W)))
}
val wdata_t = Wire(Vec(numMemBlock, UInt(memBlockBits.W)))
val no_mask = Wire(Vec(numMemBlock, Bool()))
wdata_t := DontCare
no_mask.foreach { m => m := true.B }
no_mask.foreach { m =>
m := true.B
}
for (i <- 0 until tensorLength) {
val inWrData = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata_t)
when (io.tensor.wr.valid) {
when(io.tensor.wr.valid) {
tensorFile(i).write(io.tensor.wr.bits.idx, inWrData, no_mask)
}
}
// read-from-sram
val stride = state === sWriteAck &
io.vme_wr.ack &
xcnt === xlen + 1.U &
xrem === 0.U &
ycnt =/= ysize - 1.U
io.vme_wr.ack &
xcnt === xlen + 1.U &
xrem === 0.U &
ycnt =/= ysize - 1.U
when (state === sIdle) {
when(state === sIdle) {
ycnt := 0.U
} .elsewhen (stride) {
}.elsewhen(stride) {
ycnt := ycnt + 1.U
}
when (state === sWriteCmd || tag === (numMemBlock - 1).U) {
when(state === sWriteCmd || tag === (numMemBlock - 1).U) {
tag := 0.U
} .elsewhen (io.vme_wr.data.fire()) {
}.elsewhen(io.vme_wr.data.fire()) {
tag := tag + 1.U
}
when (state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
when(
state === sWriteCmd || (set === (tensorLength - 1).U && tag === (numMemBlock - 1).U)) {
set := 0.U
} .elsewhen (io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
}.elsewhen(io.vme_wr.data.fire() && tag === (numMemBlock - 1).U) {
set := set + 1.U
}
val raddr_cur = Reg(UInt(tp.memAddrBits.W))
val raddr_nxt = Reg(UInt(tp.memAddrBits.W))
when (state === sIdle) {
when(state === sIdle) {
raddr_cur := dec.sram_offset
raddr_nxt := dec.sram_offset
} .elsewhen (io.vme_wr.data.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
raddr_cur := raddr_cur + 1.U
} .elsewhen (stride) {
raddr_cur := raddr_nxt + dec.xsize
raddr_nxt := raddr_nxt + dec.xsize
}
}.elsewhen(io.vme_wr.data
.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
raddr_cur := raddr_cur + 1.U
}
.elsewhen(stride) {
raddr_cur := raddr_nxt + dec.xsize
raddr_nxt := raddr_nxt + dec.xsize
}
val tread = Seq.tabulate(tensorLength) { i => i.U ->
tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem) }
val tread = Seq.tabulate(tensorLength) { i =>
i.U ->
tensorFile(i).read(raddr_cur, state === sWriteCmd | state === sReadMem)
}
val mdata = MuxLookup(set, 0.U.asTypeOf(chiselTypeOf(wdata_t)), tread)
// write-to-dram
val maskOffset = VecInit(Seq.fill(M_DRAM_OFFSET_BITS)(true.B)).asUInt
val elemBytes = (p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).outBits) / 8
when (state === sIdle) {
waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
} .elsewhen (state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
waddr_cur := waddr_cur + xmax_bytes
} .elsewhen (stride) {
waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength*tensorWidth))
}
when(state === sIdle) {
waddr_cur := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
elemBytes)))
waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
elemBytes)))
}.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
waddr_cur := waddr_cur + xmax_bytes
}
.elsewhen(stride) {
waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(
tensorLength * tensorWidth))
waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(
tensorLength * tensorWidth))
}
io.vme_wr.cmd.valid := state === sWriteCmd
io.vme_wr.cmd.bits.addr := waddr_cur
......@@ -199,9 +215,9 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
io.vme_wr.data.valid := state === sWriteData
io.vme_wr.data.bits := mdata(tag)
when (state === sWriteCmd) {
when(state === sWriteCmd) {
xcnt := 0.U
} .elsewhen (io.vme_wr.data.fire()) {
}.elsewhen(io.vme_wr.data.fire()) {
xcnt := xcnt + 1.U
}
......@@ -213,13 +229,19 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)
// debug
if (debug) {
when (io.vme_wr.cmd.fire()) {
printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
when(io.vme_wr.cmd.fire()) {
printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n",
ysize,
ycnt,
raddr_cur,
waddr_cur,
xlen,
xrem)
}
when (io.vme_wr.data.fire()) {
when(io.vme_wr.data.fire()) {
printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
}
when (io.vme_wr.ack) {
when(io.vme_wr.ack) {
printf("[TensorStore] ack\n")
}
}
......
......@@ -30,11 +30,14 @@ import vta.shell._
* weights (wgt), biases (acc), and outputs (out). This is used to avoid
* doing the same boring calculations over and over again.
*/
class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
val errorMsg = s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
class TensorParams(tensorType: String = "none")(implicit p: Parameters)
extends Bundle {
val errorMsg =
s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
require (tensorType == "inp" || tensorType == "wgt"
|| tensorType == "acc" || tensorType == "out", errorMsg)
require(tensorType == "inp" || tensorType == "wgt"
|| tensorType == "acc" || tensorType == "out",
errorMsg)
val (tensorLength, tensorWidth, tensorElemBits) =
if (tensorType == "inp")
......@@ -69,25 +72,30 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends
* biases (acc), and outputs (out).
*
*/
class TensorMaster(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val rd = new Bundle {
val idx = ValidIO(UInt(memAddrBits.W))
val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
}
val wr = ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
})
def tieoffRead() {
rd.idx.valid := false.B
rd.idx.bits := 0.U
}
def tieoffWrite() {
wr.valid := false.B
wr.bits.idx := 0.U
wr.bits.data.foreach { b => b.foreach { c => c := 0.U } }
class TensorMaster(tensorType: String = "none")(implicit p: Parameters)
extends TensorParams(tensorType) {
val rd = new Bundle {
val idx = ValidIO(UInt(memAddrBits.W))
val data = Flipped(
ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
}
val wr = ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
})
def tieoffRead() {
rd.idx.valid := false.B
rd.idx.bits := 0.U
}
def tieoffWrite() {
wr.valid := false.B
wr.bits.idx := 0.U
wr.bits.data.foreach { b =>
b.foreach { c =>
c := 0.U
}
}
}
override def cloneType =
new TensorMaster(tensorType).asInstanceOf[this.type]
}
......@@ -98,20 +106,25 @@ class TensorMaster(tensorType: String = "none")
* The TensorLoad unit uses this interface for receiving read and write requests from
* the TensorGemm unit.
*/
class TensorClient(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val rd = new Bundle {
val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
}
val wr = Flipped(ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
}))
def tieoffRead() {
rd.data.valid := false.B
rd.data.bits.foreach { b => b.foreach { c => c := 0.U } }
class TensorClient(tensorType: String = "none")(implicit p: Parameters)
extends TensorParams(tensorType) {
val rd = new Bundle {
val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
val data = ValidIO(
Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
}
val wr = Flipped(ValidIO(new Bundle {
val idx = UInt(memAddrBits.W)
val data = Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))
}))
def tieoffRead() {
rd.data.valid := false.B
rd.data.bits.foreach { b =>
b.foreach { c =>
c := 0.U
}
}
}
override def cloneType =
new TensorClient(tensorType).asInstanceOf[this.type]
}
......@@ -122,9 +135,10 @@ class TensorClient(tensorType: String = "none")
* is based on the TensorMaster interface, which means this is an input. This interface
* is used on datapath only module such MatrixVectorCore or AluVector.
*/
class TensorMasterData(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val data = Flipped(ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
class TensorMasterData(tensorType: String = "none")(implicit p: Parameters)
extends TensorParams(tensorType) {
val data = Flipped(
ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
override def cloneType =
new TensorMasterData(tensorType).asInstanceOf[this.type]
}
......@@ -135,18 +149,22 @@ class TensorMasterData(tensorType: String = "none")
* is based on the TensorClient interface, which means this is an output. This interface
* is used on datapath only module such MatrixVectorCore or AluVector.
*/
class TensorClientData(tensorType: String = "none")
(implicit p: Parameters) extends TensorParams(tensorType) {
val data = ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
class TensorClientData(tensorType: String = "none")(implicit p: Parameters)
extends TensorParams(tensorType) {
val data = ValidIO(
Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
override def cloneType =
new TensorClientData(tensorType).asInstanceOf[this.type]
}
/** TensorPadCtrl. Zero-padding controller for TensorLoad. */
class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
val errorMsg = s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
require (padType == "YPad0" || padType == "YPad1"
|| padType == "XPad0" || padType == "XPad1", errorMsg)
class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1)
extends Module {
val errorMsg =
s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
require(padType == "YPad0" || padType == "YPad1"
|| padType == "XPad0" || padType == "XPad1",
errorMsg)
val io = IO(new Bundle {
val start = Input(Bool())
......@@ -180,33 +198,33 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
val sIdle :: sActive :: Nil = Enum(2)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.start) {
switch(state) {
is(sIdle) {
when(io.start) {
state := sActive
}
}
is (sActive) {
when (ycnt === ymax && xcnt === xmax) {
is(sActive) {
when(ycnt === ymax && xcnt === xmax) {
state := sIdle
}
}
}
when (state === sIdle) {
when(state === sIdle) {
xmax := xval
ymax := yval
}
when (state === sIdle || xcnt === xmax) {
when(state === sIdle || xcnt === xmax) {
xcnt := 0.U
} .elsewhen (state === sActive) {
}.elsewhen(state === sActive) {
xcnt := xcnt + 1.U
}
when (state === sIdle || ymax === 0.U) {
when(state === sIdle || ymax === 0.U) {
ycnt := 0.U
} .elsewhen (state === sActive && xcnt === xmax) {
}.elsewhen(state === sActive && xcnt === xmax) {
ycnt := ycnt + 1.U
}
......@@ -214,7 +232,10 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Modul
}
/** TensorDataCtrl. Data controller for TensorLoad. */
class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
class TensorDataCtrl(tensorType: String = "none",
sizeFactor: Int = 1,
strideFactor: Int = 1)(implicit p: Parameters)
extends Module {
val mp = p(ShellKey).memParams
val io = IO(new Bundle {
val start = Input(Bool())
......@@ -238,7 +259,7 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
val len = Reg(UInt(mp.lenBits.W))
val xmax_bytes = ((1 << mp.lenBits)*mp.dataBits/8).U
val xmax_bytes = ((1 << mp.lenBits) * mp.dataBits / 8).U
val xcnt = Reg(UInt(mp.lenBits.W))
val xrem = Reg(chiselTypeOf(dec.xsize))
val xsize = (dec.xsize << log2Ceil(sizeFactor)) - 1.U
......@@ -246,38 +267,38 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
val ycnt = Reg(chiselTypeOf(dec.ysize))
val stride = xcnt === len &
xrem === 0.U &
ycnt =/= dec.ysize - 1.U
xrem === 0.U &
ycnt =/= dec.ysize - 1.U
val split = xcnt === len & xrem =/= 0.U
when (io.start || (io.xupdate && stride)) {
when (xsize < xmax) {
when(io.start || (io.xupdate && stride)) {
when(xsize < xmax) {
len := xsize
xrem := 0.U
} .otherwise {
}.otherwise {
len := xmax - 1.U
xrem := xsize - xmax
}
} .elsewhen (io.xupdate && split) {
when (xrem < xmax) {
}.elsewhen(io.xupdate && split) {
when(xrem < xmax) {
len := xrem
xrem := 0.U
} .otherwise {
}.otherwise {
len := xmax - 1.U
xrem := xrem - xmax
}
}
when (io.xinit) {
when(io.xinit) {
xcnt := 0.U
} .elsewhen (io.xupdate) {
}.elsewhen(io.xupdate) {
xcnt := xcnt + 1.U
}
when (io.start) {
when(io.start) {
ycnt := 0.U
} .elsewhen (io.yupdate && stride) {
}.elsewhen(io.yupdate && stride) {
ycnt := ycnt + 1.U
}
......@@ -291,13 +312,13 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
(p(CoreKey).batch * p(CoreKey).blockOut * p(CoreKey).accBits) / 8
}
when (io.start) {
when(io.start) {
caddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
baddr := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(elemBytes)))
} .elsewhen (io.yupdate) {
when (split) {
}.elsewhen(io.yupdate) {
when(split) {
caddr := caddr + xmax_bytes
} .elsewhen (stride) {
}.elsewhen(stride) {
caddr := baddr + (dec.xstride << log2Ceil(strideFactor))
baddr := baddr + (dec.xstride << log2Ceil(strideFactor))
}
......@@ -309,6 +330,6 @@ class TensorDataCtrl(tensorType: String = "none", sizeFactor: Int = 1, strideFac
io.addr := caddr
io.len := len
io.done := xcnt === len &
xrem === 0.U &
ycnt === dec.ysize - 1.U
xrem === 0.U &
ycnt === dec.ysize - 1.U
}
......@@ -78,55 +78,56 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
*
* Convert Host DPI to AXI for VTAShell
*/
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val io = IO(new Bundle {
val dpi = new VTAHostDPIClient
val axi = new AXILiteMaster(p(ShellKey).hostParams)
})
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val data = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.value)))
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
Enum(6)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.dpi.req.valid) {
when (io.dpi.req.opcode) {
switch(state) {
is(sIdle) {
when(io.dpi.req.valid) {
when(io.dpi.req.opcode) {
state := sWriteAddress
} .otherwise {
}.otherwise {
state := sReadAddress
}
}
}
is (sReadAddress) {
when (io.axi.ar.ready) {
is(sReadAddress) {
when(io.axi.ar.ready) {
state := sReadData
}
}
is (sReadData) {
when (io.axi.r.valid) {
is(sReadData) {
when(io.axi.r.valid) {
state := sIdle
}
}
is (sWriteAddress) {
when (io.axi.aw.ready) {
is(sWriteAddress) {
when(io.axi.aw.ready) {
state := sWriteData
}
}
is (sWriteData) {
when (io.axi.w.ready) {
is(sWriteData) {
when(io.axi.w.ready) {
state := sWriteResponse
}
}
is (sWriteResponse) {
when (io.axi.b.valid) {
is(sWriteResponse) {
when(io.axi.b.valid) {
state := sIdle
}
}
}
when (state === sIdle && io.dpi.req.valid) {
when(state === sIdle && io.dpi.req.valid) {
addr := io.dpi.req.addr
data := io.dpi.req.value
}
......@@ -147,9 +148,17 @@ class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mo
io.dpi.resp.bits := io.axi.r.bits.data
if (debug) {
when (state === sWriteAddress && io.axi.aw.ready) { printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr) }
when (state === sReadAddress && io.axi.ar.ready) { printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr) }
when (io.axi.r.fire()) { printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data) }
when (io.axi.w.fire()) { printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data) }
when(state === sWriteAddress && io.axi.aw.ready) {
printf("[VTAHostDPIToAXI] [AW] addr:%x\n", addr)
}
when(state === sReadAddress && io.axi.ar.ready) {
printf("[VTAHostDPIToAXI] [AR] addr:%x\n", addr)
}
when(io.axi.r.fire()) {
printf("[VTAHostDPIToAXI] [R] value:%x\n", io.axi.r.bits.data)
}
when(io.axi.w.fire()) {
printf("[VTAHostDPIToAXI] [W] value:%x\n", io.axi.w.bits.data)
}
}
}
......@@ -75,7 +75,8 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
setResource("/verilog/VTAMemDPI.v")
}
class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
extends Module {
val io = IO(new Bundle {
val dpi = new VTAMemDPIMaster
val axi = new AXIClient(p(ShellKey).memParams)
......@@ -83,56 +84,57 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod
val opcode = RegInit(false.B)
val len = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.len)))
val addr = RegInit(0.U.asTypeOf(chiselTypeOf(io.dpi.req.addr)))
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(6)
val sIdle :: sReadAddress :: sReadData :: sWriteAddress :: sWriteData :: sWriteResponse :: Nil =
Enum(6)
val state = RegInit(sIdle)
switch (state) {
is (sIdle) {
when (io.axi.ar.valid) {
switch(state) {
is(sIdle) {
when(io.axi.ar.valid) {
state := sReadAddress
} .elsewhen (io.axi.aw.valid) {
}.elsewhen(io.axi.aw.valid) {
state := sWriteAddress
}
}
is (sReadAddress) {
when (io.axi.ar.valid) {
is(sReadAddress) {
when(io.axi.ar.valid) {
state := sReadData
}
}
is (sReadData) {
when (io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
is(sReadData) {
when(io.axi.r.ready && io.dpi.rd.valid && len === 0.U) {
state := sIdle
}
}
is (sWriteAddress) {
when (io.axi.aw.valid) {
is(sWriteAddress) {
when(io.axi.aw.valid) {
state := sWriteData
}
}
is (sWriteData) {
when (io.axi.w.valid && io.axi.w.bits.last) {
is(sWriteData) {
when(io.axi.w.valid && io.axi.w.bits.last) {
state := sWriteResponse
}
}
is (sWriteResponse) {
when (io.axi.b.ready) {
is(sWriteResponse) {
when(io.axi.b.ready) {
state := sIdle
}
}
}
when (state === sIdle) {
when (io.axi.ar.valid) {
when(state === sIdle) {
when(io.axi.ar.valid) {
opcode := false.B
len := io.axi.ar.bits.len
addr := io.axi.ar.bits.addr
} .elsewhen (io.axi.aw.valid) {
}.elsewhen(io.axi.aw.valid) {
opcode := true.B
len := io.axi.aw.bits.len
addr := io.axi.aw.bits.addr
}
} .elsewhen (state === sReadData) {
when (io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
}.elsewhen(state === sReadData) {
when(io.axi.r.ready && io.dpi.rd.valid && len =/= 0.U) {
len := len - 1.U
}
}
......@@ -163,9 +165,21 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Mod
io.axi.b.bits.id := 0.U
if (debug) {
when (state === sReadAddress && io.axi.ar.valid) { printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len) }
when (state === sWriteAddress && io.axi.aw.valid) { printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len) }
when (io.axi.r.fire()) { printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", io.axi.r.bits.last, io.axi.r.bits.data) }
when (io.axi.w.fire()) { printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", io.axi.w.bits.last, io.axi.w.bits.data) }
when(state === sReadAddress && io.axi.ar.valid) {
printf("[VTAMemDPIToAXI] [AR] addr:%x len:%x\n", addr, len)
}
when(state === sWriteAddress && io.axi.aw.valid) {
printf("[VTAMemDPIToAXI] [AW] addr:%x len:%x\n", addr, len)
}
when(io.axi.r.fire()) {
printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n",
io.axi.r.bits.last,
io.axi.r.bits.data)
}
when(io.axi.w.fire()) {
printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n",
io.axi.w.bits.last,
io.axi.w.bits.data)
}
}
}
......@@ -24,18 +24,17 @@ import chisel3.util._
import vta.util.genericbundle._
case class AXIParams(
coherent: Boolean = false,
idBits: Int = 1,
addrBits: Int = 32,
dataBits: Int = 64,
lenBits: Int = 8,
userBits: Int = 1
)
{
require (addrBits > 0)
require (dataBits >= 8 && dataBits % 2 == 0)
coherent: Boolean = false,
idBits: Int = 1,
addrBits: Int = 32,
dataBits: Int = 64,
lenBits: Int = 8,
userBits: Int = 1
) {
require(addrBits > 0)
require(dataBits >= 8 && dataBits % 2 == 0)
val strbBits = dataBits/8
val strbBits = dataBits / 8
val sizeBits = 3
val burstBits = 2
val lockBits = 2
......@@ -44,7 +43,7 @@ case class AXIParams(
val qosBits = 4
val regionBits = 4
val respBits = 2
val sizeConst = log2Ceil(dataBits/8)
val sizeConst = log2Ceil(dataBits / 8)
val idConst = 0
val userConst = if (coherent) 1 else 0
val burstConst = 1
......@@ -56,7 +55,7 @@ case class AXIParams(
}
abstract class AXIBase(params: AXIParams)
extends GenericParameterizedBundle(params)
extends GenericParameterizedBundle(params)
// AXILite
......
......@@ -25,52 +25,59 @@ import vta.util.config._
import vta.interface.axi._
/** PynqConfig. Shell configuration for Pynq */
class PynqConfig extends Config((site, here, up) => {
case ShellKey => ShellParams(
hostParams = AXIParams(
coherent = false,
addrBits = 16,
dataBits = 32,
lenBits = 8,
userBits = 1),
memParams = AXIParams(
coherent = true,
addrBits = 32,
dataBits = 64,
lenBits = 8,
userBits = 1),
vcrParams = VCRParams(),
vmeParams = VMEParams())
})
class PynqConfig
extends Config((site, here, up) => {
case ShellKey =>
ShellParams(
hostParams = AXIParams(coherent = false,
addrBits = 16,
dataBits = 32,
lenBits = 8,
userBits = 1),
memParams = AXIParams(coherent = true,
addrBits = 32,
dataBits = 64,
lenBits = 8,
userBits = 1),
vcrParams = VCRParams(),
vmeParams = VMEParams()
)
})
/** F1Config. Shell configuration for F1 */
class F1Config extends Config((site, here, up) => {
case ShellKey => ShellParams(
hostParams = AXIParams(
coherent = false,
addrBits = 16,
dataBits = 32,
lenBits = 8,
userBits = 1),
memParams = AXIParams(
coherent = false,
addrBits = 64,
dataBits = 64,
lenBits = 8,
userBits = 1),
vcrParams = VCRParams(),
vmeParams = VMEParams())
})
class F1Config
extends Config((site, here, up) => {
case ShellKey =>
ShellParams(
hostParams = AXIParams(coherent = false,
addrBits = 16,
dataBits = 32,
lenBits = 8,
userBits = 1),
memParams = AXIParams(coherent = false,
addrBits = 64,
dataBits = 64,
lenBits = 8,
userBits = 1),
vcrParams = VCRParams(),
vmeParams = VMEParams()
)
})
/** De10Config. Shell configuration for De10 */
class De10Config extends Config((site, here, up) => {
case ShellKey => ShellParams(
hostParams = AXIParams(
addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
memParams = AXIParams(
addrBits = 32, dataBits = 64, userBits = 5,
lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
coherent = true),
vcrParams = VCRParams(),
vmeParams = VMEParams())
})
class De10Config
extends Config((site, here, up) => {
case ShellKey =>
ShellParams(
hostParams =
AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
memParams = AXIParams(
addrBits = 32,
dataBits = 64,
userBits = 5,
lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
coherent = true),
vcrParams = VCRParams(),
vmeParams = VMEParams()
)
})
......@@ -30,7 +30,7 @@ import vta.core._
* system that can be used for simulation or real hardware.
*/
class IntelShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle{
val io = IO(new Bundle {
val host = new AXIClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams)
})
......
......@@ -76,6 +76,7 @@ class VTASim(implicit p: Parameters) extends MultiIOModule {
sim.io.clock := clock
sim_wait := sim.io.dpi_wait
}
/** SimShell.
*
* The simulation shell instantiate the sim, host and memory DPI modules that
......
......@@ -29,8 +29,7 @@ import vta.interface.axi._
*
* These parameters are used on VCR interfaces and modules.
*/
case class VCRParams()
{
case class VCRParams() {
val nCtrl = 1
val nECnt = 1
val nVals = 1
......@@ -40,7 +39,7 @@ case class VCRParams()
/** VCRBase. Parametrize base class. */
abstract class VCRBase(implicit p: Parameters)
extends GenericParameterizedBundle(p)
extends GenericParameterizedBundle(p)
/** VCRMaster.
*
......@@ -80,7 +79,7 @@ class VCRClient(implicit p: Parameters) extends VCRBase {
* registers that could be used as event counters by the Core unit.
*/
class VCR(implicit p: Parameters) extends Module {
val io = IO(new Bundle{
val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams)
val vcr = new VCRMaster
})
......@@ -101,50 +100,49 @@ class VCR(implicit p: Parameters) extends Module {
val rdata = RegInit(0.U(vp.regBits.W))
// registers
val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2*vp.nPtrs
val nPtrs = if (mp.addrBits == 32) vp.nPtrs else 2 * vp.nPtrs
val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + nPtrs
val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = vp.nCtrl
val vo = eo + vp.nECnt
val po = vo + vp.nVals
switch (wstate) {
is (sWriteAddress) {
when (io.host.aw.valid) {
switch(wstate) {
is(sWriteAddress) {
when(io.host.aw.valid) {
wstate := sWriteData
}
}
is (sWriteData) {
when (io.host.w.valid) {
is(sWriteData) {
when(io.host.w.valid) {
wstate := sWriteResponse
}
}
is (sWriteResponse) {
when (io.host.b.ready) {
is(sWriteResponse) {
when(io.host.b.ready) {
wstate := sWriteAddress
}
}
}
when (io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
when(io.host.aw.fire()) { waddr := io.host.aw.bits.addr }
io.host.aw.ready := wstate === sWriteAddress
io.host.w.ready := wstate === sWriteData
io.host.b.valid := wstate === sWriteResponse
io.host.b.bits.resp := 0.U
switch (rstate) {
is (sReadAddress) {
when (io.host.ar.valid) {
switch(rstate) {
is(sReadAddress) {
when(io.host.ar.valid) {
rstate := sReadData
}
}
is (sReadData) {
when (io.host.r.ready) {
is(sReadData) {
when(io.host.r.ready) {
rstate := sReadAddress
}
}
......@@ -155,27 +153,27 @@ class VCR(implicit p: Parameters) extends Module {
io.host.r.bits.data := rdata
io.host.r.bits.resp := 0.U
when (io.vcr.finish) {
when(io.vcr.finish) {
reg(0) := "b_10".U
} .elsewhen (io.host.w.fire() && addr(0).U === waddr) {
}.elsewhen(io.host.w.fire() && addr(0).U === waddr) {
reg(0) := wdata
}
for (i <- 0 until vp.nECnt) {
when (io.vcr.ecnt(i).valid) {
when(io.vcr.ecnt(i).valid) {
reg(eo + i) := io.vcr.ecnt(i).bits
} .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) {
}.elsewhen(io.host.w.fire() && addr(eo + i).U === waddr) {
reg(eo + i) := wdata
}
}
for (i <- 0 until (vp.nVals + nPtrs)) {
when (io.host.w.fire() && addr(vo + i).U === waddr) {
when(io.host.w.fire() && addr(vo + i).U === waddr) {
reg(vo + i) := wdata
}
}
when (io.host.ar.fire()) {
when(io.host.ar.fire()) {
rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map)
}
......@@ -185,13 +183,13 @@ class VCR(implicit p: Parameters) extends Module {
io.vcr.vals(i) := reg(vo + i)
}
if (mp.addrBits == 32) { // 32-bit pointers
if (mp.addrBits == 32) { // 32-bit pointers
for (i <- 0 until nPtrs) {
io.vcr.ptrs(i) := reg(po + i)
}
} else { // 64-bits pointers
for (i <- 0 until (nPtrs/2)) {
io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
} else { // 64-bits pointers
for (i <- 0 until (nPtrs / 2)) {
io.vcr.ptrs(i) := Cat(reg(po + 2 * i + 1), reg(po + 2 * i))
}
}
}
......@@ -32,13 +32,16 @@ import vta.interface.axi._
case class VMEParams() {
val nReadClients: Int = 5
val nWriteClients: Int = 1
require (nReadClients > 0, s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
require (nWriteClients == 1, s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
require(nReadClients > 0,
s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
require(
nWriteClients == 1,
s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
}
/** VMEBase. Parametrize base class. */
abstract class VMEBase(implicit p: Parameters)
extends GenericParameterizedBundle(p)
extends GenericParameterizedBundle(p)
/** VMECmd.
*
......@@ -149,19 +152,19 @@ class VME(implicit p: Parameters) extends Module {
val sReadIdle :: sReadAddr :: sReadData :: Nil = Enum(3)
val rstate = RegInit(sReadIdle)
switch (rstate) {
is (sReadIdle) {
when (rd_arb.io.out.valid) {
switch(rstate) {
is(sReadIdle) {
when(rd_arb.io.out.valid) {
rstate := sReadAddr
}
}
is (sReadAddr) {
when (io.mem.ar.ready) {
is(sReadAddr) {
when(io.mem.ar.ready) {
rstate := sReadData
}
}
is (sReadData) {
when (io.mem.r.fire() && io.mem.r.bits.last) {
is(sReadData) {
when(io.mem.r.fire() && io.mem.r.bits.last) {
rstate := sReadIdle
}
}
......@@ -173,30 +176,34 @@ class VME(implicit p: Parameters) extends Module {
val lenBits = p(ShellKey).memParams.lenBits
val wr_cnt = RegInit(0.U(lenBits.W))
when (wstate === sWriteIdle) {
when(wstate === sWriteIdle) {
wr_cnt := 0.U
} .elsewhen (io.mem.w.fire()) {
}.elsewhen(io.mem.w.fire()) {
wr_cnt := wr_cnt + 1.U
}
switch (wstate) {
is (sWriteIdle) {
when (io.vme.wr(0).cmd.valid) {
switch(wstate) {
is(sWriteIdle) {
when(io.vme.wr(0).cmd.valid) {
wstate := sWriteAddr
}
}
is (sWriteAddr) {
when (io.mem.aw.ready) {
is(sWriteAddr) {
when(io.mem.aw.ready) {
wstate := sWriteData
}
}
is (sWriteData) {
when (io.vme.wr(0).data.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
is(sWriteData) {
when(
io.vme
.wr(0)
.data
.valid && io.mem.w.ready && wr_cnt === io.vme.wr(0).cmd.bits.len) {
wstate := sWriteResp
}
}
is (sWriteResp) {
when (io.mem.b.valid) {
is(sWriteResp) {
when(io.mem.b.valid) {
wstate := sWriteIdle
}
}
......@@ -209,12 +216,12 @@ class VME(implicit p: Parameters) extends Module {
val rd_addr = RegInit(0.U(addrBits.W))
val wr_addr = RegInit(0.U(addrBits.W))
when (rd_arb.io.out.fire()) {
when(rd_arb.io.out.fire()) {
rd_len := rd_arb.io.out.bits.len
rd_addr := rd_arb.io.out.bits.addr
}
when (io.vme.wr(0).cmd.fire()) {
when(io.vme.wr(0).cmd.fire()) {
wr_len := io.vme.wr(0).cmd.bits.len
wr_addr := io.vme.wr(0).cmd.bits.addr
}
......@@ -230,7 +237,7 @@ class VME(implicit p: Parameters) extends Module {
io.vme.wr(0).cmd.ready := wstate === sWriteIdle
io.vme.wr(0).ack := io.mem.b.fire()
io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
io.vme.wr(0).data.ready := wstate === sWriteData & io.mem.w.ready
// mem
io.mem.aw.valid := wstate === sWriteAddr
......
......@@ -26,10 +26,10 @@ import vta.core._
/** Shell parameters. */
case class ShellParams(
hostParams: AXIParams,
memParams: AXIParams,
vcrParams: VCRParams,
vmeParams: VMEParams
hostParams: AXIParams,
memParams: AXIParams,
vcrParams: VCRParams,
vmeParams: VMEParams
)
case object ShellKey extends Field[ShellParams]
......@@ -40,7 +40,7 @@ case object ShellKey extends Field[ShellParams]
* system that can be used for simulation or real hardware.
*/
class VTAShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle{
val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams)
val mem = new AXIMaster(p(ShellKey).memParams)
})
......
......@@ -20,7 +20,7 @@
package vta.shell
import chisel3._
import chisel3.experimental.{RawModule, withClockAndReset}
import chisel3.experimental.{withClockAndReset, RawModule}
import vta.util.config._
import vta.interface.axi._
......@@ -39,7 +39,9 @@ class XilinxShell(implicit p: Parameters) extends RawModule {
val m_axi_gmem = IO(new XilinxAXIMaster(mp))
val s_axi_control = IO(new XilinxAXILiteClient(hp))
val shell = withClockAndReset (clock = ap_clk, reset = ~ap_rst_n) { Module(new VTAShell) }
val shell = withClockAndReset(clock = ap_clk, reset = ~ap_rst_n) {
Module(new VTAShell)
}
// memory
m_axi_gmem.AWVALID := shell.io.mem.aw.valid
......
......@@ -21,8 +21,7 @@ package vta.util.config
// taken from https://github.com/vta.roject/rocket-chip
abstract class Field[T] private (val default: Option[T])
{
abstract class Field[T] private (val default: Option[T]) {
def this() = this(None)
def this(default: T) = this(Some(default))
}
......@@ -31,42 +30,50 @@ abstract class View {
final def apply[T](pname: Field[T]): T = apply(pname, this)
final def apply[T](pname: Field[T], site: View): T = {
val out = find(pname, site)
require (out.isDefined, s"Key ${pname} is not defined in Parameters")
require(out.isDefined, s"Key ${pname} is not defined in Parameters")
out.get
}
final def lift[T](pname: Field[T]): Option[T] = lift(pname, this)
final def lift[T](pname: Field[T], site: View): Option[T] = find(pname, site).map(_.asInstanceOf[T])
final def lift[T](pname: Field[T], site: View): Option[T] =
find(pname, site).map(_.asInstanceOf[T])
protected[config] def find[T](pname: Field[T], site: View): Option[T]
}
abstract class Parameters extends View {
final def ++ (x: Parameters): Parameters =
final def ++(x: Parameters): Parameters =
new ChainParameters(this, x)
final def alter(f: (View, View, View) => PartialFunction[Any,Any]): Parameters =
final def alter(
f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
Parameters(f) ++ this
final def alterPartial(f: PartialFunction[Any,Any]): Parameters =
Parameters((_,_,_) => f) ++ this
final def alterPartial(f: PartialFunction[Any, Any]): Parameters =
Parameters((_, _, _) => f) ++ this
final def alterMap(m: Map[Any,Any]): Parameters =
final def alterMap(m: Map[Any, Any]): Parameters =
new MapParameters(m) ++ this
protected[config] def chain[T](site: View, tail: View, pname: Field[T]): Option[T]
protected[config] def find[T](pname: Field[T], site: View) = chain(site, new TerminalView, pname)
protected[config] def chain[T](site: View,
tail: View,
pname: Field[T]): Option[T]
protected[config] def find[T](pname: Field[T], site: View) =
chain(site, new TerminalView, pname)
}
object Parameters {
def empty: Parameters = new EmptyParameters
def apply(f: (View, View, View) => PartialFunction[Any,Any]): Parameters = new PartialParameters(f)
def apply(f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
new PartialParameters(f)
}
class Config(p: Parameters) extends Parameters {
def this(f: (View, View, View) => PartialFunction[Any,Any]) = this(Parameters(f))
def this(f: (View, View, View) => PartialFunction[Any, Any]) =
this(Parameters(f))
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = p.chain(site, tail, pname)
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) =
p.chain(site, tail, pname)
override def toString = this.getClass.getSimpleName
def toInstance = this
}
......@@ -82,17 +89,21 @@ private class ChainView(head: Parameters, tail: View) extends View {
}
private class ChainParameters(x: Parameters, y: Parameters) extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = x.chain(site, new ChainView(y, tail), pname)
def chain[T](site: View, tail: View, pname: Field[T]) =
x.chain(site, new ChainView(y, tail), pname)
}
private class EmptyParameters extends Parameters {
def chain[T](site: View, tail: View, pname: Field[T]) = tail.find(pname, site)
}
private class PartialParameters(f: (View, View, View) => PartialFunction[Any,Any]) extends Parameters {
private class PartialParameters(
f: (View, View, View) => PartialFunction[Any, Any])
extends Parameters {
protected[config] def chain[T](site: View, tail: View, pname: Field[T]) = {
val g = f(site, this, tail)
if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T]) else tail.find(pname, site)
if (g.isDefinedAt(pname)) Some(g.apply(pname).asInstanceOf[T])
else tail.find(pname, site)
}
}
......
......@@ -23,18 +23,22 @@ package vta.util.genericbundle
import chisel3._
abstract class GenericParameterizedBundle[+T <: Object](val params: T) extends Bundle
{
abstract class GenericParameterizedBundle[+T <: Object](val params: T)
extends Bundle {
override def cloneType = {
try {
this.getClass.getConstructors.head.newInstance(params).asInstanceOf[this.type]
this.getClass.getConstructors.head
.newInstance(params)
.asInstanceOf[this.type]
} catch {
case e: java.lang.IllegalArgumentException =>
throw new Exception("Unable to use GenericParameterizedBundle.cloneType on " +
this.getClass + ", probably because " + this.getClass +
"() takes more than one argument. Consider overriding " +
"cloneType() on " + this.getClass, e)
throw new Exception(
"Unable to use GenericParameterizedBundle.cloneType on " +
this.getClass + ", probably because " + this.getClass +
"() takes more than one argument. Consider overriding " +
"cloneType() on " + this.getClass,
e
)
}
}
}
......@@ -31,7 +31,6 @@ import vta.test._
* These configurations are built in a mix/match form based on core
* and shell configurations.
*/
class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment