Unverified Commit 6a36fb40 by Liangfu Chen Committed by GitHub

[VTA][Chisel] Change Scala Linter scalafmt => scalastyle (#4998)

* scalafmt => scalastyle

Change-Id: Ifc590e7cb63585f35dfdc9efcf3c6287b1afb1dd

* scalafmt => scalastyle

Change-Id: I8aff2632dadda05d2896e28bdaf6f780a160a15a

* add indentation constraint

Change-Id: Ibeb00c11a5718ea47322ea2b82e757828af8af91

* trigger ci again
parent 62424611
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
maxColumn = 100
rewrite.rules = [SortModifiers, SortImports]
...@@ -94,7 +94,8 @@ endif ...@@ -94,7 +94,8 @@ endif
default: lint lib default: lint lib
lint: lint:
sbt scalafmt cp $(vta_dir)/hardware/chisel/scalastyle-config.xml .
sbt scalastyle
lib: $(lib_path) lib: $(lib_path)
$(lib_path): $(verilator_build_dir)/V$(TOP).cpp $(lib_path): $(verilator_build_dir)/V$(TOP).cpp
......
...@@ -18,4 +18,4 @@ ...@@ -18,4 +18,4 @@
*/ */
logLevel := Level.Warn logLevel := Level.Warn
addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
...@@ -23,18 +23,18 @@ import chisel3._ ...@@ -23,18 +23,18 @@ import chisel3._
import vta.dpi._ import vta.dpi._
/** Add-by-one accelerator. /** Add-by-one accelerator.
* *
* ___________ ___________ * ___________ ___________
* | | | | * | | | |
* | HostDPI | <--> | RegFile | <->| * | HostDPI | <--> | RegFile | <->|
* |_________| |_________| | * |_________| |_________| |
* | * |
* ___________ ___________ | * ___________ ___________ |
* | | | | | * | | | | |
* | MemDPI | <--> | Compute | <->| * | MemDPI | <--> | Compute | <->|
* |_________| |_________| * |_________| |_________|
* *
*/ */
case class AccelConfig() { case class AccelConfig() {
val nCtrl = 1 val nCtrl = 1
val nECnt = 1 val nECnt = 1
......
...@@ -24,17 +24,17 @@ import chisel3.util._ ...@@ -24,17 +24,17 @@ import chisel3.util._
import vta.dpi._ import vta.dpi._
/** Compute /** Compute
* *
* Add-by-one procedure: * Add-by-one procedure:
* *
* 1. Wait for launch to be asserted * 1. Wait for launch to be asserted
* 2. Issue a read request for 8-byte value at inp_baddr address * 2. Issue a read request for 8-byte value at inp_baddr address
* 3. Wait for the value * 3. Wait for the value
* 4. Issue a write request for 8-byte value at out_baddr address * 4. Issue a write request for 8-byte value at out_baddr address
* 5. Increment read-address and write-address for next value * 5. Increment read-address and write-address for next value
* 6. Check if counter (cnt) is equal to length to assert finish, * 6. Check if counter (cnt) is equal to length to assert finish,
* otherwise go to step 2. * otherwise go to step 2.
*/ */
class Compute(implicit config: AccelConfig) extends Module { class Compute(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val launch = Input(Bool()) val launch = Input(Bool())
......
...@@ -24,29 +24,29 @@ import chisel3.util._ ...@@ -24,29 +24,29 @@ import chisel3.util._
import vta.dpi._ import vta.dpi._
/** Register File. /** Register File.
* *
* Six 32-bit register file. * Six 32-bit register file.
* *
* ------------------------------- * -------------------------------
* Register description | addr * Register description | addr
* -------------------------|----- * -------------------------|-----
* Control status register | 0x00 * Control status register | 0x00
* Cycle counter | 0x04 * Cycle counter | 0x04
* Constant value | 0x08 * Constant value | 0x08
* Vector length | 0x0c * Vector length | 0x0c
* Input pointer lsb | 0x10 * Input pointer lsb | 0x10
* Input pointer msb | 0x14 * Input pointer msb | 0x14
* Output pointer lsb | 0x18 * Output pointer lsb | 0x18
* Output pointer msb | 0x1c * Output pointer msb | 0x1c
* ------------------------------- * -------------------------------
*
* ------------------------------ * ------------------------------
* Control status register | bit * Control status register | bit
* ------------------------------ * ------------------------------
* Launch | 0 * Launch | 0
* Finish | 1 * Finish | 1
* ------------------------------ * ------------------------------
*/ */
class RegFile(implicit config: AccelConfig) extends Module { class RegFile(implicit config: AccelConfig) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val launch = Output(Bool()) val launch = Output(Bool())
...@@ -98,9 +98,8 @@ class RegFile(implicit config: AccelConfig) extends Module { ...@@ -98,9 +98,8 @@ class RegFile(implicit config: AccelConfig) extends Module {
} }
for (i <- 0 until (config.nVals + (2 * config.nPtrs))) { for (i <- 0 until (config.nVals + (2 * config.nPtrs))) {
when( when(state === sIdle && io.host.req.valid &&
state === sIdle && io.host.req.valid && io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
io.host.req.opcode && addr(vo + i).U === io.host.req.addr) {
reg(vo + i) := io.host.req.value reg(vo + i) := io.host.req.value
} }
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
maxColumn = 100
rewrite.rules = [SortModifiers, SortImports]
...@@ -112,7 +112,7 @@ endif ...@@ -112,7 +112,7 @@ endif
default: lint lib default: lint lib
lint: lint:
sbt scalafmt --test sbt scalastyle
lib: $(lib_path) lib: $(lib_path)
......
...@@ -18,4 +18,4 @@ ...@@ -18,4 +18,4 @@
*/ */
logLevel := Level.Warn logLevel := Level.Warn
addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "1.0.0")
<scalastyle>
<name>Scalastyle standard configuration</name>
<check level="error" class="org.scalastyle.file.FileTabChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.FileLengthChecker" enabled="true">
<parameters>
<parameter name="maxFileLength"><![CDATA[800]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true">
<parameters>
<parameter name="header"><![CDATA[/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.SpacesAfterPlusChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SpacesBeforePlusChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true">
<parameters>
<parameter name="maxLineLength"><![CDATA[120]]></parameter>
<parameter name="tabSize"><![CDATA[2]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ClassNamesChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ObjectNamesChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[[A-Z][A-Za-z]*]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.PackageObjectNamesChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[^[a-z][A-Za-z]*$]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.IllegalImportsChecker" enabled="true">
<parameters>
<parameter name="illegalImports"><![CDATA[sun._,java.awt._]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.ParameterNumberChecker" enabled="true">
<parameters>
<parameter name="maxParameters"><![CDATA[8]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.MagicNumberChecker" enabled="false">
<parameters>
<parameter name="ignore"><![CDATA[-1,0,1,2,3,4,8,16,32,64,128]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceBeforeLeftBracketChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NoWhitespaceAfterLeftBracketChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.ReturnChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NullChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NoCloneChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.NoFinalizeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.CovariantEqualsChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.StructuralTypeChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters>
<parameter name="regex"><![CDATA[println]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NumberOfTypesChecker" enabled="true">
<parameters>
<parameter name="maxTypes"><![CDATA[30]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.CyclomaticComplexityChecker" enabled="true">
<parameters>
<parameter name="maximum"><![CDATA[10]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.UppercaseLChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SimplifyBooleanExpressionChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.IfBraceChecker" enabled="false">
<parameters>
<parameter name="singleLineAllowed"><![CDATA[true]]></parameter>
<parameter name="doubleLineAllowed"><![CDATA[false]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.MethodLengthChecker" enabled="true">
<parameters>
<parameter name="maxLength"><![CDATA[50]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.MethodNamesChecker" enabled="false">
<parameters>
<parameter name="regex"><![CDATA[^[a-z][A-Za-z0-9]*$]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.NumberOfMethodsInTypeChecker" enabled="true">
<parameters>
<parameter name="maxMethods"><![CDATA[30]]></parameter>
</parameters>
</check>
<check level="error" class="org.scalastyle.scalariform.PublicMethodsHaveTypeChecker" enabled="false"></check>
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.file.NoNewLineAtEofChecker" enabled="false"></check>
<check level="error" class="org.scalastyle.file.IndentationChecker" enabled="true">
<parameters>
<parameter name="tabSize">2</parameter>
<parameter name="methodParamIndentSize">2</parameter>
<parameter name="classParamIndentSize">4</parameter>
</parameters>
</check>
</scalastyle>
...@@ -25,13 +25,13 @@ import vta.util.config._ ...@@ -25,13 +25,13 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** Compute. /** Compute.
* *
* The compute unit is in charge of the following: * The compute unit is in charge of the following:
* - Loading micro-ops from memory (loadUop module) * - Loading micro-ops from memory (loadUop module)
* - Loading biases (acc) from memory (tensorAcc module) * - Loading biases (acc) from memory (tensorAcc module)
* - Compute ALU instructions (tensorAlu module) * - Compute ALU instructions (tensorAlu module)
* - Compute GEMM instructions (tensorGemm module) * - Compute GEMM instructions (tensorGemm module)
*/ */
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -65,10 +65,10 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -65,10 +65,10 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
val inst_type = val inst_type =
Cat(dec.io.isFinish, Cat(dec.io.isFinish,
dec.io.isAlu, dec.io.isAlu,
dec.io.isGemm, dec.io.isGemm,
dec.io.isLoadAcc, dec.io.isLoadAcc,
dec.io.isLoadUop).asUInt dec.io.isLoadUop).asUInt
val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B) val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B) val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
...@@ -116,20 +116,14 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -116,20 +116,14 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
loadUop.io.inst := inst_q.io.deq.bits loadUop.io.inst := inst_q.io.deq.bits
loadUop.io.baddr := io.uop_baddr loadUop.io.baddr := io.uop_baddr
io.vme_rd(0) <> loadUop.io.vme_rd io.vme_rd(0) <> loadUop.io.vme_rd
loadUop.io.uop.idx <> Mux(dec.io.isGemm, loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
tensorGemm.io.uop.idx,
tensorAlu.io.uop.idx)
// acc // acc
tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc tensorAcc.io.start := state === sIdle & start & dec.io.isLoadAcc
tensorAcc.io.inst := inst_q.io.deq.bits tensorAcc.io.inst := inst_q.io.deq.bits
tensorAcc.io.baddr := io.acc_baddr tensorAcc.io.baddr := io.acc_baddr
tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
tensorGemm.io.acc.rd.idx, tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
tensorAlu.io.acc.rd.idx)
tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm,
tensorGemm.io.acc.wr,
tensorAlu.io.acc.wr)
io.vme_rd(1) <> tensorAcc.io.vme_rd io.vme_rd(1) <> tensorAcc.io.vme_rd
// gemm // gemm
...@@ -156,8 +150,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -156,8 +150,8 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
// out // out
io.out.rd.idx <> Mux(dec.io.isGemm, io.out.rd.idx <> Mux(dec.io.isGemm,
tensorGemm.io.out.rd.idx, tensorGemm.io.out.rd.idx,
tensorAlu.io.out.rd.idx) tensorAlu.io.out.rd.idx)
io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr) io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)
// semaphore // semaphore
...@@ -178,20 +172,16 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -178,20 +172,16 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(dec.io.isSync) { when(dec.io.isSync) {
printf("[Compute] start sync\n") printf("[Compute] start sync\n")
}.elsewhen(dec.io.isLoadUop) { }.elsewhen(dec.io.isLoadUop) {
printf("[Compute] start load uop\n") printf("[Compute] start load uop\n")
} }.elsewhen(dec.io.isLoadAcc) {
.elsewhen(dec.io.isLoadAcc) { printf("[Compute] start load acc\n")
printf("[Compute] start load acc\n") }.elsewhen(dec.io.isGemm) {
} printf("[Compute] start gemm\n")
.elsewhen(dec.io.isGemm) { }.elsewhen(dec.io.isAlu) {
printf("[Compute] start gemm\n") printf("[Compute] start alu\n")
} }.elsewhen(dec.io.isFinish) {
.elsewhen(dec.io.isAlu) { printf("[Compute] start finish\n")
printf("[Compute] start alu\n") }
}
.elsewhen(dec.io.isFinish) {
printf("[Compute] start finish\n")
}
} }
// done // done
when(state === sSync) { when(state === sSync) {
...@@ -202,17 +192,14 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -202,17 +192,14 @@ class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(dec.io.isLoadUop) { when(dec.io.isLoadUop) {
printf("[Compute] done load uop\n") printf("[Compute] done load uop\n")
}.elsewhen(dec.io.isLoadAcc) { }.elsewhen(dec.io.isLoadAcc) {
printf("[Compute] done load acc\n") printf("[Compute] done load acc\n")
} }.elsewhen(dec.io.isGemm) {
.elsewhen(dec.io.isGemm) { printf("[Compute] done gemm\n")
printf("[Compute] done gemm\n") }.elsewhen(dec.io.isAlu) {
} printf("[Compute] done alu\n")
.elsewhen(dec.io.isAlu) { }.elsewhen(dec.io.isFinish) {
printf("[Compute] done alu\n") printf("[Compute] done finish\n")
} }
.elsewhen(dec.io.isFinish) {
printf("[Compute] done finish\n")
}
} }
} }
} }
......
...@@ -22,28 +22,27 @@ package vta.core ...@@ -22,28 +22,27 @@ package vta.core
import vta.util.config._ import vta.util.config._
/** CoreConfig. /** CoreConfig.
* *
* This is one supported configuration for VTA. This file will * This is one supported configuration for VTA. This file will
* be eventually filled out with class configurations that can be * be eventually filled out with class configurations that can be
* mixed/matched with Shell configurations for different backends. * mixed/matched with Shell configurations for different backends.
*/ */
class CoreConfig class CoreConfig extends Config((site, here, up) => {
extends Config((site, here, up) => { case CoreKey =>
case CoreKey => CoreParams(
CoreParams( batch = 1,
batch = 1, blockOut = 16,
blockOut = 16, blockIn = 16,
blockIn = 16, inpBits = 8,
inpBits = 8, wgtBits = 8,
wgtBits = 8, uopBits = 32,
uopBits = 32, accBits = 32,
accBits = 32, outBits = 8,
outBits = 8, uopMemDepth = 2048,
uopMemDepth = 2048, inpMemDepth = 2048,
inpMemDepth = 2048, wgtMemDepth = 1024,
wgtMemDepth = 1024, accMemDepth = 2048,
accMemDepth = 2048, outMemDepth = 2048,
outMemDepth = 2048, instQueueEntries = 512
instQueueEntries = 512 )
) })
})
...@@ -41,23 +41,23 @@ case class CoreParams( ...@@ -41,23 +41,23 @@ case class CoreParams(
instQueueEntries: Int = 32 instQueueEntries: Int = 32
) { ) {
require(uopBits % 8 == 0, require(uopBits % 8 == 0,
s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n") s"\n\n[VTA] [CoreParams] uopBits must be byte aligned\n\n")
} }
case object CoreKey extends Field[CoreParams] case object CoreKey extends Field[CoreParams]
/** Core. /** Core.
* *
* The core defines the current VTA architecture by connecting memory and * The core defines the current VTA architecture by connecting memory and
* compute modules together such as load/store and compute. Most of the * compute modules together such as load/store and compute. Most of the
* connections in the core are bulk (<>), and we should try to keep it this * connections in the core are bulk (<>), and we should try to keep it this
* way, because it is easier to understand what is going on. * way, because it is easier to understand what is going on.
* *
* Also, the core must be instantiated by a shell using the * Also, the core must be instantiated by a shell using the
* VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces. * VTA Control Register (VCR) and the VTA Memory Engine (VME) interfaces.
* More info about these interfaces and modules can be found in the shell * More info about these interfaces and modules can be found in the shell
* directory. * directory.
*/ */
class Core(implicit p: Parameters) extends Module { class Core(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val vcr = new VCRClient val vcr = new VCRClient
......
...@@ -25,16 +25,16 @@ import chisel3.util._ ...@@ -25,16 +25,16 @@ import chisel3.util._
import ISA._ import ISA._
/** MemDecode. /** MemDecode.
* *
* Decode memory instructions with a Bundle. This is similar to an union, * Decode memory instructions with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. These are the instructions * therefore order matters when declaring fields. These are the instructions
* decoded with this bundle: * decoded with this bundle:
* - LUOP * - LUOP
* - LWGT * - LWGT
* - LINP * - LINP
* - LACC * - LACC
* - SOUT * - SOUT
*/ */
class MemDecode extends Bundle { class MemDecode extends Bundle {
val xpad_1 = UInt(M_PAD_BITS.W) val xpad_1 = UInt(M_PAD_BITS.W)
val xpad_0 = UInt(M_PAD_BITS.W) val xpad_0 = UInt(M_PAD_BITS.W)
...@@ -55,10 +55,10 @@ class MemDecode extends Bundle { ...@@ -55,10 +55,10 @@ class MemDecode extends Bundle {
} }
/** GemmDecode. /** GemmDecode.
* *
* Decode GEMM instruction with a Bundle. This is similar to an union, * Decode GEMM instruction with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. * therefore order matters when declaring fields.
*/ */
class GemmDecode extends Bundle { class GemmDecode extends Bundle {
val wgt_1 = UInt(C_WIDX_BITS.W) val wgt_1 = UInt(C_WIDX_BITS.W)
val wgt_0 = UInt(C_WIDX_BITS.W) val wgt_0 = UInt(C_WIDX_BITS.W)
...@@ -80,15 +80,15 @@ class GemmDecode extends Bundle { ...@@ -80,15 +80,15 @@ class GemmDecode extends Bundle {
} }
/** AluDecode. /** AluDecode.
* *
* Decode ALU instructions with a Bundle. This is similar to an union, * Decode ALU instructions with a Bundle. This is similar to an union,
* therefore order matters when declaring fields. These are the instructions * therefore order matters when declaring fields. These are the instructions
* decoded with this bundle: * decoded with this bundle:
* - VMIN * - VMIN
* - VMAX * - VMAX
* - VADD * - VADD
* - VSHX * - VSHX
*/ */
class AluDecode extends Bundle { class AluDecode extends Bundle {
val empty_1 = Bool() val empty_1 = Bool()
val alu_imm = UInt(C_ALU_IMM_BITS.W) val alu_imm = UInt(C_ALU_IMM_BITS.W)
...@@ -112,9 +112,9 @@ class AluDecode extends Bundle { ...@@ -112,9 +112,9 @@ class AluDecode extends Bundle {
} }
/** UopDecode. /** UopDecode.
* *
* Decode micro-ops (uops). * Decode micro-ops (uops).
*/ */
class UopDecode extends Bundle { class UopDecode extends Bundle {
val u2 = UInt(10.W) val u2 = UInt(10.W)
val u1 = UInt(11.W) val u1 = UInt(11.W)
...@@ -122,9 +122,9 @@ class UopDecode extends Bundle { ...@@ -122,9 +122,9 @@ class UopDecode extends Bundle {
} }
/** FetchDecode. /** FetchDecode.
* *
* Partial decoding for dispatching instructions to Load, Compute, and Store. * Partial decoding for dispatching instructions to Load, Compute, and Store.
*/ */
class FetchDecode extends Module { class FetchDecode extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W)) val inst = Input(UInt(INST_BITS.W))
...@@ -159,9 +159,9 @@ class FetchDecode extends Module { ...@@ -159,9 +159,9 @@ class FetchDecode extends Module {
} }
/** LoadDecode. /** LoadDecode.
* *
* Decode dependencies, type and sync for Load module. * Decode dependencies, type and sync for Load module.
*/ */
class LoadDecode extends Module { class LoadDecode extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W)) val inst = Input(UInt(INST_BITS.W))
...@@ -180,9 +180,9 @@ class LoadDecode extends Module { ...@@ -180,9 +180,9 @@ class LoadDecode extends Module {
} }
/** ComputeDecode. /** ComputeDecode.
* *
* Decode dependencies, type and sync for Compute module. * Decode dependencies, type and sync for Compute module.
*/ */
class ComputeDecode extends Module { class ComputeDecode extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W)) val inst = Input(UInt(INST_BITS.W))
...@@ -211,9 +211,9 @@ class ComputeDecode extends Module { ...@@ -211,9 +211,9 @@ class ComputeDecode extends Module {
} }
/** StoreDecode. /** StoreDecode.
* *
* Decode dependencies, type and sync for Store module. * Decode dependencies, type and sync for Store module.
*/ */
class StoreDecode extends Module { class StoreDecode extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val inst = Input(UInt(INST_BITS.W)) val inst = Input(UInt(INST_BITS.W))
......
...@@ -25,21 +25,20 @@ import vta.util.config._ ...@@ -25,21 +25,20 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** EventCounters. /** EventCounters.
* *
* This unit contains all the event counting logic. One common event tracked in * This unit contains all the event counting logic. One common event tracked in
* hardware is the number of clock cycles taken to achieve certain task. We * hardware is the number of clock cycles taken to achieve certain task. We
* can count the total number of clock cycles spent in a VTA run by checking * can count the total number of clock cycles spent in a VTA run by checking
* launch and finish signals. * launch and finish signals.
* *
* The event counter value is passed to the VCR module via the ecnt port, so * The event counter value is passed to the VCR module via the ecnt port, so
* they can be accessed by the host. The number of event counters (nECnt) is * they can be accessed by the host. The number of event counters (nECnt) is
* defined in the Shell VCR module as a parameter, see VCRParams. * defined in the Shell VCR module as a parameter, see VCRParams.
* *
* If one would like to add an event counter, then the value of nECnt must be * If one would like to add an event counter, then the value of nECnt must be
* changed in VCRParams together with the corresponding counting logic here. * changed in VCRParams together with the corresponding counting logic here.
*/ */
class EventCounters(debug: Boolean = false)(implicit p: Parameters) class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module {
extends Module {
val vp = p(ShellKey).vcrParams val vp = p(ShellKey).vcrParams
val io = IO(new Bundle { val io = IO(new Bundle {
val launch = Input(Bool()) val launch = Input(Bool())
......
...@@ -25,20 +25,20 @@ import vta.util.config._ ...@@ -25,20 +25,20 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** Fetch. /** Fetch.
* *
* The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the * The fetch unit reads instructions (tasks) from memory (i.e. DRAM), using the
* VTA Memory Engine (VME), and push them into an instruction queue called * VTA Memory Engine (VME), and push them into an instruction queue called
* inst_q. Once the instruction queue is full, instructions are dispatched to * inst_q. Once the instruction queue is full, instructions are dispatched to
* the Load, Compute and Store module queues based on the instruction opcode. * the Load, Compute and Store module queues based on the instruction opcode.
* After draining the queue, the fetch unit checks if there are more instructions * After draining the queue, the fetch unit checks if there are more instructions
* via the ins_count register which is written by the host. * via the ins_count register which is written by the host.
* *
* Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB) * Additionally, instructions are read into two chunks (see sReadLSB and sReadMSB)
* because we are using a DRAM payload of 8-bytes or half of a VTA instruction. * because we are using a DRAM payload of 8-bytes or half of a VTA instruction.
* This should be configurable for larger payloads, i.e. 64-bytes, which can load * This should be configurable for larger payloads, i.e. 64-bytes, which can load
* more than one instruction at the time. Finally, the instruction queue is * more than one instruction at the time. Finally, the instruction queue is
* sized (entries_q), depending on the maximum burst allowed in the memory. * sized (entries_q), depending on the maximum burst allowed in the memory.
*/ */
class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val vp = p(ShellKey).vcrParams val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
...@@ -112,17 +112,16 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -112,17 +112,16 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(xrem === 0.U) { when(xrem === 0.U) {
state := sIdle state := sIdle
}.elsewhen(xrem < xmax) { }.elsewhen(xrem < xmax) {
state := sReadCmd state := sReadCmd
rlen := xrem rlen := xrem
ilen := xrem >> 1.U ilen := xrem >> 1.U
xrem := 0.U xrem := 0.U
} }.otherwise {
.otherwise { state := sReadCmd
state := sReadCmd rlen := xmax - 1.U
rlen := xmax - 1.U ilen := (xmax >> 1.U) - 1.U
ilen := (xmax >> 1.U) - 1.U xrem := xrem - xmax
xrem := xrem - xmax }
}
} }
} }
} }
...@@ -165,12 +164,12 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -165,12 +164,12 @@ class Fetch(debug: Boolean = false)(implicit p: Parameters) extends Module {
val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt val deq_sel = Cat(dec.io.isCompute, dec.io.isStore, dec.io.isLoad).asUInt
val deq_ready = val deq_ready =
MuxLookup(deq_sel, MuxLookup(deq_sel,
false.B, // default false.B, // default
Array( Array(
"h_01".U -> io.inst.ld.ready, "h_01".U -> io.inst.ld.ready,
"h_02".U -> io.inst.st.ready, "h_02".U -> io.inst.st.ready,
"h_04".U -> io.inst.co.ready "h_04".U -> io.inst.co.ready
)) ))
// dequeue instruction // dequeue instruction
inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain inst_q.io.deq.ready := deq_ready & inst_q.io.deq.valid & state === sDrain
......
...@@ -24,9 +24,9 @@ import chisel3.util._ ...@@ -24,9 +24,9 @@ import chisel3.util._
import scala.collection.mutable.HashMap import scala.collection.mutable.HashMap
/** ISAConstants. /** ISAConstants.
* *
* These constants are used for decoding (parsing) fields on instructions. * These constants are used for decoding (parsing) fields on instructions.
*/ */
trait ISAConstants { trait ISAConstants {
val INST_BITS = 128 val INST_BITS = 128
...@@ -70,13 +70,13 @@ trait ISAConstants { ...@@ -70,13 +70,13 @@ trait ISAConstants {
} }
/** ISA. /** ISA.
* *
* This is the VTA task ISA * This is the VTA task ISA
* *
* TODO: Add VXOR to clear accumulator * TODO: Add VXOR to clear accumulator
* TODO: Use ISA object for decoding as well * TODO: Use ISA object for decoding as well
* TODO: Eventually deprecate ISAConstants * TODO: Eventually deprecate ISAConstants
*/ */
object ISA { object ISA {
private val xLen = 128 private val xLen = 128
private val depBits = 4 private val depBits = 4
...@@ -86,19 +86,19 @@ object ISA { ...@@ -86,19 +86,19 @@ object ISA {
private val taskId: HashMap[String, String] = private val taskId: HashMap[String, String] =
HashMap(("load", "000"), HashMap(("load", "000"),
("store", "001"), ("store", "001"),
("gemm", "010"), ("gemm", "010"),
("finish", "011"), ("finish", "011"),
("alu", "100")) ("alu", "100"))
private val memId: HashMap[String, String] = private val memId: HashMap[String, String] =
HashMap(("uop", "00"), ("wgt", "01"), ("inp", "10"), ("acc", "11")) HashMap(("uop", "00"), ("wgt", "01"), ("inp", "10"), ("acc", "11"))
private val aluId: HashMap[String, String] = private val aluId: HashMap[String, String] =
HashMap(("minpool", "00"), HashMap(("minpool", "00"),
("maxpool", "01"), ("maxpool", "01"),
("add", "10"), ("add", "10"),
("shift", "11")) ("shift", "11"))
private def dontCare(bits: Int): String = "?" * bits private def dontCare(bits: Int): String = "?" * bits
......
...@@ -25,12 +25,12 @@ import vta.util.config._ ...@@ -25,12 +25,12 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** Load. /** Load.
* *
* Load inputs and weights from memory (DRAM) into scratchpads (SRAMs). * Load inputs and weights from memory (DRAM) into scratchpads (SRAMs).
* This module instantiate the TensorLoad unit which is in charge of * This module instantiate the TensorLoad unit which is in charge of
* loading 1D and 2D tensors to scratchpads, so it can be used by * loading 1D and 2D tensors to scratchpads, so it can be used by
* other modules such as Compute. * other modules such as Compute.
*/ */
class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -110,11 +110,10 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -110,11 +110,10 @@ class Load(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(dec.io.isSync) { when(dec.io.isSync) {
printf("[Load] start sync\n") printf("[Load] start sync\n")
}.elsewhen(dec.io.isInput) { }.elsewhen(dec.io.isInput) {
printf("[Load] start input\n") printf("[Load] start input\n")
} }.elsewhen(dec.io.isWeight) {
.elsewhen(dec.io.isWeight) { printf("[Load] start weight\n")
printf("[Load] start weight\n") }
}
} }
// done // done
when(state === sSync) { when(state === sSync) {
......
...@@ -25,11 +25,11 @@ import vta.util.config._ ...@@ -25,11 +25,11 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** UopMaster. /** UopMaster.
* *
* Uop interface used by a master module, i.e. TensorAlu or TensorGemm, * Uop interface used by a master module, i.e. TensorAlu or TensorGemm,
* to request a micro-op (uop) from the uop-scratchpad. The index (idx) is * to request a micro-op (uop) from the uop-scratchpad. The index (idx) is
* used as an address to find the uop in the uop-scratchpad. * used as an address to find the uop in the uop-scratchpad.
*/ */
class UopMaster(implicit p: Parameters) extends Bundle { class UopMaster(implicit p: Parameters) extends Bundle {
val addrBits = log2Ceil(p(CoreKey).uopMemDepth) val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
val idx = ValidIO(UInt(addrBits.W)) val idx = ValidIO(UInt(addrBits.W))
...@@ -38,11 +38,11 @@ class UopMaster(implicit p: Parameters) extends Bundle { ...@@ -38,11 +38,11 @@ class UopMaster(implicit p: Parameters) extends Bundle {
} }
/** UopClient. /** UopClient.
* *
* Uop interface used by a client module, i.e. LoadUop, to receive * Uop interface used by a client module, i.e. LoadUop, to receive
* a request from a master module, i.e. TensorAlu or TensorGemm. * a request from a master module, i.e. TensorAlu or TensorGemm.
* The index (idx) is used as an address to find the uop in the uop-scratchpad. * The index (idx) is used as an address to find the uop in the uop-scratchpad.
*/ */
class UopClient(implicit p: Parameters) extends Bundle { class UopClient(implicit p: Parameters) extends Bundle {
val addrBits = log2Ceil(p(CoreKey).uopMemDepth) val addrBits = log2Ceil(p(CoreKey).uopMemDepth)
val idx = Flipped(ValidIO(UInt(addrBits.W))) val idx = Flipped(ValidIO(UInt(addrBits.W)))
...@@ -51,12 +51,12 @@ class UopClient(implicit p: Parameters) extends Bundle { ...@@ -51,12 +51,12 @@ class UopClient(implicit p: Parameters) extends Bundle {
} }
/** LoadUop. /** LoadUop.
* *
* Load micro-ops (uops) from memory, i.e. DRAM, and store them in the * Load micro-ops (uops) from memory, i.e. DRAM, and store them in the
* uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in * uop-scratchpad. Currently, micro-ops are 32-bit wide and loaded in
* group of 2 given the fact that the DRAM payload is 8-bytes. This module * group of 2 given the fact that the DRAM payload is 8-bytes. This module
* should be modified later on to support different DRAM sizes efficiently. * should be modified later on to support different DRAM sizes efficiently.
*/ */
class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -113,15 +113,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -113,15 +113,14 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(xrem === 0.U) { when(xrem === 0.U) {
state := sIdle state := sIdle
}.elsewhen(xrem < xmax) { }.elsewhen(xrem < xmax) {
state := sReadCmd state := sReadCmd
xlen := xrem xlen := xrem
xrem := 0.U xrem := 0.U
} }.otherwise {
.otherwise { state := sReadCmd
state := sReadCmd xlen := xmax - 1.U
xlen := xmax - 1.U xrem := xrem - xmax
xrem := xrem - xmax }
}
} }
} }
} }
...@@ -166,19 +165,18 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -166,19 +165,18 @@ class LoadUop(debug: Boolean = false)(implicit p: Parameters) extends Module {
when(sizeIsEven) { when(sizeIsEven) {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
}.elsewhen(io.vme_rd.cmd.fire()) { }.elsewhen(io.vme_rd.cmd.fire()) {
when(dec.xsize === 1.U) { when(dec.xsize === 1.U) {
wmask := "b_01".U.asTypeOf(wmask) wmask := "b_01".U.asTypeOf(wmask)
}.otherwise { }.otherwise {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
}
} }
.elsewhen(io.vme_rd.data.fire()) { }.elsewhen(io.vme_rd.data.fire()) {
when((xcnt === xlen - 1.U) && (xrem === 0.U)) { when((xcnt === xlen - 1.U) && (xrem === 0.U)) {
wmask := "b_01".U.asTypeOf(wmask) wmask := "b_01".U.asTypeOf(wmask)
}.otherwise { }.otherwise {
wmask := "b_11".U.asTypeOf(wmask) wmask := "b_11".U.asTypeOf(wmask)
}
} }
}
}.otherwise { }.otherwise {
when(io.vme_rd.cmd.fire()) { when(io.vme_rd.cmd.fire()) {
wmask := "b_10".U.asTypeOf(wmask) wmask := "b_10".U.asTypeOf(wmask)
......
...@@ -23,14 +23,13 @@ import chisel3._ ...@@ -23,14 +23,13 @@ import chisel3._
import chisel3.util._ import chisel3.util._
/** Semaphore. /** Semaphore.
* *
* This semaphore is used instead of push/pop fifo, used in the initial * This semaphore is used instead of push/pop fifo, used in the initial
* version of VTA. This semaphore is incremented (spost) or decremented (swait) * version of VTA. This semaphore is incremented (spost) or decremented (swait)
* depending on the push and pop fields on instructions to prevent RAW and WAR * depending on the push and pop fields on instructions to prevent RAW and WAR
* hazards. * hazards.
*/ */
class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) class Semaphore(counterBits: Int = 1, counterInitValue: Int = 1) extends Module {
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val spost = Input(Bool()) val spost = Input(Bool())
val swait = Input(Bool()) val swait = Input(Bool())
......
...@@ -25,11 +25,11 @@ import vta.util.config._ ...@@ -25,11 +25,11 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** Store. /** Store.
* *
* Store results back to memory (DRAM) from scratchpads (SRAMs). * Store results back to memory (DRAM) from scratchpads (SRAMs).
* This module instantiate the TensorStore unit which is in charge * This module instantiate the TensorStore unit which is in charge
* of storing 1D and 2D tensors to main memory. * of storing 1D and 2D tensors to main memory.
*/ */
class Store(debug: Boolean = false)(implicit p: Parameters) extends Module { class Store(debug: Boolean = false)(implicit p: Parameters) extends Module {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
......
...@@ -39,11 +39,8 @@ class Alu(implicit p: Parameters) extends Module { ...@@ -39,11 +39,8 @@ class Alu(implicit p: Parameters) extends Module {
val m = ~ub(width - 1, 0) + 1.U val m = ~ub(width - 1, 0) + 1.U
val n = ub(width - 1, 0) val n = ub(width - 1, 0)
val fop = Seq(Mux(io.a < io.b, io.a, io.b), val fop = Seq(Mux(io.a < io.b, io.a, io.b), Mux(io.a < io.b, io.b, io.a),
Mux(io.a < io.b, io.b, io.a), io.a + io.b, io.a >> n, io.a << m)
io.a + io.b,
io.a >> n,
io.a << m)
val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i)) val opmux = Seq.tabulate(ALU_OP_NUM)(i => ALU_OP(i) -> fop(i))
io.y := MuxLookup(io.opcode, io.a, opmux) io.y := MuxLookup(io.opcode, io.a, opmux)
...@@ -101,12 +98,12 @@ class AluVector(implicit p: Parameters) extends Module { ...@@ -101,12 +98,12 @@ class AluVector(implicit p: Parameters) extends Module {
} }
/** TensorAlu. /** TensorAlu.
* *
* This unit instantiate the ALU vector unit (AluVector) and go over the * This unit instantiate the ALU vector unit (AluVector) and go over the
* micro-ops (uops) which are used to read the source operands (vectors) * micro-ops (uops) which are used to read the source operands (vectors)
* from the acc-scratchpad and then they are written back the same * from the acc-scratchpad and then they are written back the same
* acc-scratchpad. * acc-scratchpad.
*/ */
class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
val aluBits = p(CoreKey).accBits val aluBits = p(CoreKey).accBits
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -200,18 +197,14 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -200,18 +197,14 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
dst_i := 0.U dst_i := 0.U
src_i := 0.U src_i := 0.U
}.elsewhen(state === sReadUop && cnt_i === dec.lp_1) { }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U cnt_i := 0.U
dst_i := dst_o dst_i := dst_o
src_i := src_o src_i := src_o
} }.elsewhen(state === sExe && alu.io.out.data.valid && uop_idx === uop_end - 1.U) {
.elsewhen( cnt_i := cnt_i + 1.U
state === sExe && dst_i := dst_i + dec.dst_1
alu.io.out.data.valid && src_i := src_i + dec.src_1
uop_idx === uop_end - 1.U) { }
cnt_i := cnt_i + 1.U
dst_i := dst_i + dec.dst_1
src_i := src_i + dec.src_1
}
when(state === sComputeIdx && io.uop.data.valid) { when(state === sComputeIdx && io.uop.data.valid) {
uop_dst := io.uop.data.bits.u0 + dst_i uop_dst := io.uop.data.bits.u0 + dst_i
...@@ -232,7 +225,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -232,7 +225,7 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
tensorImm.data.bits.foreach { b => tensorImm.data.bits.foreach { b =>
b.foreach { c => b.foreach { c =>
c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1), c := Mux(dec.alu_imm(C_ALU_IMM_BITS - 1),
Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm) Cat(-1.S((aluBits - C_ALU_IMM_BITS).W), dec.alu_imm), dec.alu_imm)
} }
} }
...@@ -244,11 +237,11 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module { ...@@ -244,11 +237,11 @@ class TensorAlu(debug: Boolean = false)(implicit p: Parameters) extends Module {
alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB alu.io.acc_a.data.valid := io.acc.rd.data.valid & state === sReadTensorB
alu.io.acc_a.data.bits <> io.acc.rd.data.bits alu.io.acc_a.data.bits <> io.acc.rd.data.bits
alu.io.acc_b.data.valid := Mux(dec.alu_use_imm, alu.io.acc_b.data.valid := Mux(dec.alu_use_imm,
tensorImm.data.valid, tensorImm.data.valid,
io.acc.rd.data.valid & state === sExe) io.acc.rd.data.valid & state === sExe)
alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm, alu.io.acc_b.data.bits <> Mux(dec.alu_use_imm,
tensorImm.data.bits, tensorImm.data.bits,
io.acc.rd.data.bits) io.acc.rd.data.bits)
// acc_o // acc_o
io.acc.wr.valid := alu.io.acc_y.data.valid io.acc.wr.valid := alu.io.acc_y.data.valid
......
...@@ -47,9 +47,9 @@ class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module { ...@@ -47,9 +47,9 @@ class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
} }
/** PipeAdder /** PipeAdder
* *
* This unit loads input bits into register and performs addition in the next cycle * This unit loads input bits into register and performs addition in the next cycle
*/ */
class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
val outBits = Math.max(aBits, bBits) + 1 val outBits = Math.max(aBits, bBits) + 1
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -65,10 +65,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module { ...@@ -65,10 +65,10 @@ class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
} }
/** Adder /** Adder
* *
* This unit wires input bits to an adder directly. * This unit wires input bits to an adder directly.
* The output comes out of combinational logic without waiting for another cycle. * The output comes out of combinational logic without waiting for another cycle.
*/ */
class Adder(aBits: Int = 8, bBits: Int = 8) extends Module { class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
val outBits = Math.max(aBits, bBits) + 1 val outBits = Math.max(aBits, bBits) + 1
val io = IO(new Bundle { val io = IO(new Bundle {
...@@ -86,8 +86,7 @@ class Adder(aBits: Int = 8, bBits: Int = 8) extends Module { ...@@ -86,8 +86,7 @@ class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
} }
/** Pipelined DotProduct based on MAC and PipeAdder */ /** Pipelined DotProduct based on MAC and PipeAdder */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16) extends Module {
extends Module {
val errorMsg = val errorMsg =
s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n" s"\n\n[VTA] [DotProduct] size must be greater than 4 and a power of 2\n\n"
require(size >= 2 && isPow2(size), errorMsg) require(size >= 2 && isPow2(size), errorMsg)
...@@ -175,16 +174,15 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module { ...@@ -175,16 +174,15 @@ class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
} }
/** TensorGemm. /** TensorGemm.
* *
* This unit instantiate the MatrixVectorMultiplication and go over the * This unit instantiate the MatrixVectorMultiplication and go over the
* micro-ops (uops) which are used to read inputs, weights and biases, * micro-ops (uops) which are used to read inputs, weights and biases,
* and writes results back to the acc and out scratchpads. * and writes results back to the acc and out scratchpads.
* *
* Also, the TensorGemm uses the reset field in the Gemm instruction to * Also, the TensorGemm uses the reset field in the Gemm instruction to
* clear or zero-out the acc-scratchpad locations based on the micro-ops. * clear or zero-out the acc-scratchpad locations based on the micro-ops.
*/ */
class TensorGemm(debug: Boolean = false)(implicit p: Parameters) class TensorGemm(debug: Boolean = false)(implicit p: Parameters) extends Module {
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val start = Input(Bool()) val start = Input(Bool())
val done = Output(Bool()) val done = Output(Bool())
...@@ -268,11 +266,10 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) ...@@ -268,11 +266,10 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit
inflight := inflight inflight := inflight
}.elsewhen(state === sReadTensor) { // issue a tensor }.elsewhen(state === sReadTensor) { // issue a tensor
inflight := inflight + 1.U inflight := inflight + 1.U
} }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor inflight := inflight - 1.U
inflight := inflight - 1.U }
}
} }
when( when(
...@@ -305,17 +302,16 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) ...@@ -305,17 +302,16 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
inp_i := 0.U inp_i := 0.U
wgt_i := 0.U wgt_i := 0.U
}.elsewhen(state === sReadUop && cnt_i === dec.lp_1) { }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
cnt_i := 0.U cnt_i := 0.U
acc_i := acc_o acc_i := acc_o
inp_i := inp_o inp_i := inp_o
wgt_i := wgt_o wgt_i := wgt_o
} }.elsewhen(state === sExe && uop_idx === uop_end - 1.U) {
.elsewhen(state === sExe && uop_idx === uop_end - 1.U) { cnt_i := cnt_i + 1.U
cnt_i := cnt_i + 1.U acc_i := acc_i + dec.acc_1
acc_i := acc_i + dec.acc_1 inp_i := inp_i + dec.inp_1
inp_i := inp_i + dec.inp_1 wgt_i := wgt_i + dec.wgt_1
wgt_i := wgt_i + dec.wgt_1 }
}
when(state === sComputeIdx && io.uop.data.valid) { when(state === sComputeIdx && io.uop.data.valid) {
uop_acc := io.uop.data.bits.u0 + acc_i uop_acc := io.uop.data.bits.u0 + acc_i
...@@ -351,9 +347,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) ...@@ -351,9 +347,8 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
mvc.io.acc_i.data <> io.acc.rd.data mvc.io.acc_i.data <> io.acc.rd.data
// acc_o // acc_o
io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset, io.acc.wr.valid := mvc.io.acc_o.data.valid &
true.B, Mux(dec.reset, true.B, wrpipe.io.deq.valid)
wrpipe.io.deq.valid)
io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits) io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
io.acc.wr.bits.data <> mvc.io.acc_o.data.bits io.acc.wr.bits.data <> mvc.io.acc_o.data.bits
...@@ -371,10 +366,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters) ...@@ -371,10 +366,7 @@ class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
} }
when(state === sReadTensor && ~dec.reset) { when(state === sReadTensor && ~dec.reset) {
printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n", uop_acc, uop_inp, uop_wgt)
uop_acc,
uop_inp,
uop_wgt)
} }
io.inp.rd.data.bits.zipWithIndex.foreach { io.inp.rd.data.bits.zipWithIndex.foreach {
......
...@@ -25,13 +25,13 @@ import vta.util.config._ ...@@ -25,13 +25,13 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** TensorLoad. /** TensorLoad.
* *
* Load 1D and 2D tensors from main memory (DRAM) to input/weight * Load 1D and 2D tensors from main memory (DRAM) to input/weight
* scratchpads (SRAM). Also, there is support for zero padding, while * scratchpads (SRAM). Also, there is support for zero padding, while
* doing the load. Zero-padding works on the y and x axis, and it is * doing the load. Zero-padding works on the y and x axis, and it is
* managed by TensorPadCtrl. The TensorDataCtrl is in charge of * managed by TensorPadCtrl. The TensorDataCtrl is in charge of
* handling the way tensors are stored on the scratchpads. * handling the way tensors are stored on the scratchpads.
*/ */
class TensorLoad(tensorType: String = "none", debug: Boolean = false)( class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters) implicit p: Parameters)
extends Module { extends Module {
...@@ -71,11 +71,10 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( ...@@ -71,11 +71,10 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
when(dec.ypad_0 =/= 0.U) { when(dec.ypad_0 =/= 0.U) {
state := sYPad0 state := sYPad0
}.elsewhen(dec.xpad_0 =/= 0.U) { }.elsewhen(dec.xpad_0 =/= 0.U) {
state := sXPad0 state := sXPad0
} }.otherwise {
.otherwise { state := sReadCmd
state := sReadCmd }
}
} }
} }
is(sYPad0) { is(sYPad0) {
...@@ -213,13 +212,12 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( ...@@ -213,13 +212,12 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
waddr_cur := dec.sram_offset waddr_cur := dec.sram_offset
waddr_nxt := dec.sram_offset waddr_nxt := dec.sram_offset
}.elsewhen((io.vme_rd.data }.elsewhen((io.vme_rd.data
.fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) { .fire() || isZeroPad) && set === (tp.tensorLength - 1).U && tag === (tp.numMemBlock - 1).U) {
waddr_cur := waddr_cur + 1.U waddr_cur := waddr_cur + 1.U
} }.elsewhen(dataCtrl.io.stride) {
.elsewhen(dataCtrl.io.stride) { waddr_cur := waddr_nxt + dec.xsize
waddr_cur := waddr_nxt + dec.xsize waddr_nxt := waddr_nxt + dec.xsize
waddr_nxt := waddr_nxt + dec.xsize }
}
val tensorFile = Seq.fill(tp.tensorLength) { val tensorFile = Seq.fill(tp.tensorLength) {
SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W))) SyncReadMem(tp.memDepth, Vec(tp.numMemBlock, UInt(tp.memBlockBits.W)))
...@@ -241,8 +239,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( ...@@ -241,8 +239,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i)) val tdata = io.tensor.wr.bits.data(i).asUInt.asTypeOf(wdata(i))
val muxWen = val muxWen =
Mux(state === sIdle, Mux(state === sIdle,
io.tensor.wr.valid, io.tensor.wr.valid,
(io.vme_rd.data.fire() | isZeroPad) & set === i.U) (io.vme_rd.data.fire() | isZeroPad) & set === i.U)
val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur) val muxWaddr = Mux(state === sIdle, io.tensor.wr.bits.idx, waddr_cur)
val muxWdata = Mux(state === sIdle, tdata, wdata(i)) val muxWdata = Mux(state === sIdle, tdata, wdata(i))
val muxWmask = Mux(state === sIdle, no_mask, wmask(i)) val muxWmask = Mux(state === sIdle, no_mask, wmask(i))
...@@ -274,8 +272,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( ...@@ -274,8 +272,8 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
if (tensorType == "inp") { if (tensorType == "inp") {
when(io.vme_rd.cmd.fire()) { when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [inp] cmd addr:%x len:%x\n", printf("[TensorLoad] [inp] cmd addr:%x len:%x\n",
dataCtrl.io.addr, dataCtrl.io.addr,
dataCtrl.io.len) dataCtrl.io.len)
} }
when(state === sYPad0) { when(state === sYPad0) {
printf("[TensorLoad] [inp] sYPad0\n") printf("[TensorLoad] [inp] sYPad0\n")
...@@ -292,14 +290,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)( ...@@ -292,14 +290,14 @@ class TensorLoad(tensorType: String = "none", debug: Boolean = false)(
} else if (tensorType == "wgt") { } else if (tensorType == "wgt") {
when(io.vme_rd.cmd.fire()) { when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n", printf("[TensorLoad] [wgt] cmd addr:%x len:%x\n",
dataCtrl.io.addr, dataCtrl.io.addr,
dataCtrl.io.len) dataCtrl.io.len)
} }
} else if (tensorType == "acc") { } else if (tensorType == "acc") {
when(io.vme_rd.cmd.fire()) { when(io.vme_rd.cmd.fire()) {
printf("[TensorLoad] [acc] cmd addr:%x len:%x\n", printf("[TensorLoad] [acc] cmd addr:%x len:%x\n",
dataCtrl.io.addr, dataCtrl.io.addr,
dataCtrl.io.len) dataCtrl.io.len)
} }
} }
} }
......
...@@ -25,9 +25,9 @@ import vta.util.config._ ...@@ -25,9 +25,9 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** TensorStore. /** TensorStore.
* *
* Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM). * Store 1D and 2D tensors from out-scratchpad (SRAM) to main memory (DRAM).
*/ */
class TensorStore(tensorType: String = "none", debug: Boolean = false)( class TensorStore(tensorType: String = "none", debug: Boolean = false)(
implicit p: Parameters) implicit p: Parameters)
extends Module { extends Module {
...@@ -112,15 +112,14 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)( ...@@ -112,15 +112,14 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
} }
} }
}.elsewhen(xrem < xmax) { }.elsewhen(xrem < xmax) {
state := sWriteCmd state := sWriteCmd
xlen := xrem xlen := xrem
xrem := 0.U xrem := 0.U
} }.otherwise {
.otherwise { state := sWriteCmd
state := sWriteCmd xlen := xmax - 1.U
xlen := xmax - 1.U xrem := xrem - xmax
xrem := xrem - xmax }
}
} }
} }
} }
...@@ -176,13 +175,12 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)( ...@@ -176,13 +175,12 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
raddr_cur := dec.sram_offset raddr_cur := dec.sram_offset
raddr_nxt := dec.sram_offset raddr_nxt := dec.sram_offset
}.elsewhen(io.vme_wr.data }.elsewhen(io.vme_wr.data
.fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) { .fire() && set === (tensorLength - 1).U && tag === (numMemBlock - 1).U) {
raddr_cur := raddr_cur + 1.U raddr_cur := raddr_cur + 1.U
} }.elsewhen(stride) {
.elsewhen(stride) { raddr_cur := raddr_nxt + dec.xsize
raddr_cur := raddr_nxt + dec.xsize raddr_nxt := raddr_nxt + dec.xsize
raddr_nxt := raddr_nxt + dec.xsize }
}
val tread = Seq.tabulate(tensorLength) { i => val tread = Seq.tabulate(tensorLength) { i =>
i.U -> i.U ->
...@@ -199,14 +197,11 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)( ...@@ -199,14 +197,11 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil( waddr_nxt := io.baddr | (maskOffset & (dec.dram_offset << log2Ceil(
elemBytes))) elemBytes)))
}.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) { }.elsewhen(state === sWriteAck && io.vme_wr.ack && xrem =/= 0.U) {
waddr_cur := waddr_cur + xmax_bytes waddr_cur := waddr_cur + xmax_bytes
} }.elsewhen(stride) {
.elsewhen(stride) { waddr_cur := waddr_nxt + (dec.xstride << log2Ceil(tensorLength * tensorWidth))
waddr_cur := waddr_nxt + (dec.xstride << log2Ceil( waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(tensorLength * tensorWidth))
tensorLength * tensorWidth)) }
waddr_nxt := waddr_nxt + (dec.xstride << log2Ceil(
tensorLength * tensorWidth))
}
io.vme_wr.cmd.valid := state === sWriteCmd io.vme_wr.cmd.valid := state === sWriteCmd
io.vme_wr.cmd.bits.addr := waddr_cur io.vme_wr.cmd.bits.addr := waddr_cur
...@@ -231,12 +226,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)( ...@@ -231,12 +226,7 @@ class TensorStore(tensorType: String = "none", debug: Boolean = false)(
if (debug) { if (debug) {
when(io.vme_wr.cmd.fire()) { when(io.vme_wr.cmd.fire()) {
printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n", printf("[TensorStore] ysize:%x ycnt:%x raddr:%x waddr:%x len:%x rem:%x\n",
ysize, ysize, ycnt, raddr_cur, waddr_cur, xlen, xrem)
ycnt,
raddr_cur,
waddr_cur,
xlen,
xrem)
} }
when(io.vme_wr.data.fire()) { when(io.vme_wr.data.fire()) {
printf("[TensorStore] data:%x\n", io.vme_wr.data.bits) printf("[TensorStore] data:%x\n", io.vme_wr.data.bits)
......
...@@ -25,19 +25,18 @@ import vta.util.config._ ...@@ -25,19 +25,18 @@ import vta.util.config._
import vta.shell._ import vta.shell._
/** TensorParams. /** TensorParams.
* *
* This Bundle derives parameters for each tensorType, including inputs (inp), * This Bundle derives parameters for each tensorType, including inputs (inp),
* weights (wgt), biases (acc), and outputs (out). This is used to avoid * weights (wgt), biases (acc), and outputs (out). This is used to avoid
* doing the same boring calculations over and over again. * doing the same boring calculations over and over again.
*/ */
class TensorParams(tensorType: String = "none")(implicit p: Parameters) class TensorParams(tensorType: String = "none")(implicit p: Parameters) extends Bundle {
extends Bundle {
val errorMsg = val errorMsg =
s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n" s"\n\n[VTA] [TensorParams] only inp, wgt, acc, and out supported\n\n"
require(tensorType == "inp" || tensorType == "wgt" require(tensorType == "inp" || tensorType == "wgt"
|| tensorType == "acc" || tensorType == "out", || tensorType == "acc" || tensorType == "out",
errorMsg) errorMsg)
val (tensorLength, tensorWidth, tensorElemBits) = val (tensorLength, tensorWidth, tensorElemBits) =
if (tensorType == "inp") if (tensorType == "inp")
...@@ -66,14 +65,14 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters) ...@@ -66,14 +65,14 @@ class TensorParams(tensorType: String = "none")(implicit p: Parameters)
} }
/** TensorMaster. /** TensorMaster.
* *
* This interface issue read and write tensor-requests to scratchpads. For example, * This interface issue read and write tensor-requests to scratchpads. For example,
* The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt), * The TensorGemm unit uses this interface for managing the inputs (inp), weights (wgt),
* biases (acc), and outputs (out). * biases (acc), and outputs (out).
* *
*/ */
class TensorMaster(tensorType: String = "none")(implicit p: Parameters) class TensorMaster(tensorType: String = "none")
extends TensorParams(tensorType) { (implicit p: Parameters) extends TensorParams(tensorType) {
val rd = new Bundle { val rd = new Bundle {
val idx = ValidIO(UInt(memAddrBits.W)) val idx = ValidIO(UInt(memAddrBits.W))
val data = Flipped( val data = Flipped(
...@@ -101,13 +100,13 @@ class TensorMaster(tensorType: String = "none")(implicit p: Parameters) ...@@ -101,13 +100,13 @@ class TensorMaster(tensorType: String = "none")(implicit p: Parameters)
} }
/** TensorClient. /** TensorClient.
* *
* This interface receives read and write tensor-requests to scratchpads. For example, * This interface receives read and write tensor-requests to scratchpads. For example,
* The TensorLoad unit uses this interface for receiving read and write requests from * The TensorLoad unit uses this interface for receiving read and write requests from
* the TensorGemm unit. * the TensorGemm unit.
*/ */
class TensorClient(tensorType: String = "none")(implicit p: Parameters) class TensorClient(tensorType: String = "none")
extends TensorParams(tensorType) { (implicit p: Parameters) extends TensorParams(tensorType) {
val rd = new Bundle { val rd = new Bundle {
val idx = Flipped(ValidIO(UInt(memAddrBits.W))) val idx = Flipped(ValidIO(UInt(memAddrBits.W)))
val data = ValidIO( val data = ValidIO(
...@@ -130,13 +129,13 @@ class TensorClient(tensorType: String = "none")(implicit p: Parameters) ...@@ -130,13 +129,13 @@ class TensorClient(tensorType: String = "none")(implicit p: Parameters)
} }
/** TensorMasterData. /** TensorMasterData.
* *
* This interface is only used for datapath only purposes and the direction convention * This interface is only used for datapath only purposes and the direction convention
* is based on the TensorMaster interface, which means this is an input. This interface * is based on the TensorMaster interface, which means this is an input. This interface
* is used on datapath only module such MatrixVectorCore or AluVector. * is used on datapath only module such MatrixVectorCore or AluVector.
*/ */
class TensorMasterData(tensorType: String = "none")(implicit p: Parameters) class TensorMasterData(tensorType: String = "none")
extends TensorParams(tensorType) { (implicit p: Parameters) extends TensorParams(tensorType) {
val data = Flipped( val data = Flipped(
ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))) ValidIO(Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))))
override def cloneType = override def cloneType =
...@@ -144,13 +143,13 @@ class TensorMasterData(tensorType: String = "none")(implicit p: Parameters) ...@@ -144,13 +143,13 @@ class TensorMasterData(tensorType: String = "none")(implicit p: Parameters)
} }
/** TensorClientData. /** TensorClientData.
* *
* This interface is only used for datapath only purposes and the direction convention * This interface is only used for datapath only purposes and the direction convention
* is based on the TensorClient interface, which means this is an output. This interface * is based on the TensorClient interface, which means this is an output. This interface
* is used on datapath only module such MatrixVectorCore or AluVector. * is used on datapath only module such MatrixVectorCore or AluVector.
*/ */
class TensorClientData(tensorType: String = "none")(implicit p: Parameters) class TensorClientData(tensorType: String = "none")
extends TensorParams(tensorType) { (implicit p: Parameters) extends TensorParams(tensorType) {
val data = ValidIO( val data = ValidIO(
Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W)))) Vec(tensorLength, Vec(tensorWidth, UInt(tensorElemBits.W))))
override def cloneType = override def cloneType =
...@@ -158,13 +157,12 @@ class TensorClientData(tensorType: String = "none")(implicit p: Parameters) ...@@ -158,13 +157,12 @@ class TensorClientData(tensorType: String = "none")(implicit p: Parameters)
} }
/** TensorPadCtrl. Zero-padding controller for TensorLoad. */ /** TensorPadCtrl. Zero-padding controller for TensorLoad. */
class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) extends Module {
extends Module {
val errorMsg = val errorMsg =
s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n" s"\n\n\n[VTA-ERROR] only YPad0, YPad1, XPad0, or XPad1 supported\n\n\n"
require(padType == "YPad0" || padType == "YPad1" require(padType == "YPad0" || padType == "YPad1"
|| padType == "XPad0" || padType == "XPad1", || padType == "XPad0" || padType == "XPad1",
errorMsg) errorMsg)
val io = IO(new Bundle { val io = IO(new Bundle {
val start = Input(Bool()) val start = Input(Bool())
...@@ -233,9 +231,7 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1) ...@@ -233,9 +231,7 @@ class TensorPadCtrl(padType: String = "none", sizeFactor: Int = 1)
/** TensorDataCtrl. Data controller for TensorLoad. */ /** TensorDataCtrl. Data controller for TensorLoad. */
class TensorDataCtrl(tensorType: String = "none", class TensorDataCtrl(tensorType: String = "none",
sizeFactor: Int = 1, sizeFactor: Int = 1, strideFactor: Int = 1)(implicit p: Parameters) extends Module {
strideFactor: Int = 1)(implicit p: Parameters)
extends Module {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val io = IO(new Bundle { val io = IO(new Bundle {
val start = Input(Bool()) val start = Input(Bool())
......
...@@ -32,9 +32,9 @@ trait VTAHostDPIParams { ...@@ -32,9 +32,9 @@ trait VTAHostDPIParams {
} }
/** Host master interface. /** Host master interface.
* *
* This interface is tipically used by the Host * This interface is tipically used by the Host
*/ */
class VTAHostDPIMaster extends Bundle with VTAHostDPIParams { class VTAHostDPIMaster extends Bundle with VTAHostDPIParams {
val req = new Bundle { val req = new Bundle {
val valid = Output(Bool()) val valid = Output(Bool())
...@@ -47,9 +47,9 @@ class VTAHostDPIMaster extends Bundle with VTAHostDPIParams { ...@@ -47,9 +47,9 @@ class VTAHostDPIMaster extends Bundle with VTAHostDPIParams {
} }
/** Host client interface. /** Host client interface.
* *
* This interface is tipically used by the Accelerator * This interface is tipically used by the Accelerator
*/ */
class VTAHostDPIClient extends Bundle with VTAHostDPIParams { class VTAHostDPIClient extends Bundle with VTAHostDPIParams {
val req = new Bundle { val req = new Bundle {
val valid = Input(Bool()) val valid = Input(Bool())
...@@ -62,9 +62,9 @@ class VTAHostDPIClient extends Bundle with VTAHostDPIParams { ...@@ -62,9 +62,9 @@ class VTAHostDPIClient extends Bundle with VTAHostDPIParams {
} }
/** Host DPI module. /** Host DPI module.
* *
* Wrapper for Host Verilog DPI module. * Wrapper for Host Verilog DPI module.
*/ */
class VTAHostDPI extends BlackBox with HasBlackBoxResource { class VTAHostDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle { val io = IO(new Bundle {
val clock = Input(Clock()) val clock = Input(Clock())
...@@ -75,11 +75,10 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource { ...@@ -75,11 +75,10 @@ class VTAHostDPI extends BlackBox with HasBlackBoxResource {
} }
/** Host DPI to AXI Converter. /** Host DPI to AXI Converter.
* *
* Convert Host DPI to AXI for VTAShell * Convert Host DPI to AXI for VTAShell
*/ */
class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) class VTAHostDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val dpi = new VTAHostDPIClient val dpi = new VTAHostDPIClient
val axi = new AXILiteMaster(p(ShellKey).hostParams) val axi = new AXILiteMaster(p(ShellKey).hostParams)
......
...@@ -33,9 +33,9 @@ trait VTAMemDPIParams { ...@@ -33,9 +33,9 @@ trait VTAMemDPIParams {
} }
/** Memory master interface. /** Memory master interface.
* *
* This interface is tipically used by the Accelerator * This interface is tipically used by the Accelerator
*/ */
class VTAMemDPIMaster extends Bundle with VTAMemDPIParams { class VTAMemDPIMaster extends Bundle with VTAMemDPIParams {
val req = new Bundle { val req = new Bundle {
val valid = Output(Bool()) val valid = Output(Bool())
...@@ -48,9 +48,9 @@ class VTAMemDPIMaster extends Bundle with VTAMemDPIParams { ...@@ -48,9 +48,9 @@ class VTAMemDPIMaster extends Bundle with VTAMemDPIParams {
} }
/** Memory client interface. /** Memory client interface.
* *
* This interface is tipically used by the Host * This interface is tipically used by the Host
*/ */
class VTAMemDPIClient extends Bundle with VTAMemDPIParams { class VTAMemDPIClient extends Bundle with VTAMemDPIParams {
val req = new Bundle { val req = new Bundle {
val valid = Input(Bool()) val valid = Input(Bool())
...@@ -63,9 +63,9 @@ class VTAMemDPIClient extends Bundle with VTAMemDPIParams { ...@@ -63,9 +63,9 @@ class VTAMemDPIClient extends Bundle with VTAMemDPIParams {
} }
/** Memory DPI module. /** Memory DPI module.
* *
* Wrapper for Memory Verilog DPI module. * Wrapper for Memory Verilog DPI module.
*/ */
class VTAMemDPI extends BlackBox with HasBlackBoxResource { class VTAMemDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle { val io = IO(new Bundle {
val clock = Input(Clock()) val clock = Input(Clock())
...@@ -75,8 +75,7 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource { ...@@ -75,8 +75,7 @@ class VTAMemDPI extends BlackBox with HasBlackBoxResource {
setResource("/verilog/VTAMemDPI.v") setResource("/verilog/VTAMemDPI.v")
} }
class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) extends Module {
extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val dpi = new VTAMemDPIMaster val dpi = new VTAMemDPIMaster
val axi = new AXIClient(p(ShellKey).memParams) val axi = new AXIClient(p(ShellKey).memParams)
...@@ -173,13 +172,13 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters) ...@@ -173,13 +172,13 @@ class VTAMemDPIToAXI(debug: Boolean = false)(implicit p: Parameters)
} }
when(io.axi.r.fire()) { when(io.axi.r.fire()) {
printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n", printf("[VTAMemDPIToAXI] [R] last:%x data:%x\n",
io.axi.r.bits.last, io.axi.r.bits.last,
io.axi.r.bits.data) io.axi.r.bits.data)
} }
when(io.axi.w.fire()) { when(io.axi.w.fire()) {
printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n", printf("[VTAMemDPIToAXI] [W] last:%x data:%x\n",
io.axi.w.bits.last, io.axi.w.bits.last,
io.axi.w.bits.data) io.axi.w.bits.data)
} }
} }
} }
...@@ -26,9 +26,9 @@ import vta.interface.axi._ ...@@ -26,9 +26,9 @@ import vta.interface.axi._
import vta.shell._ import vta.shell._
/** Sim DPI module. /** Sim DPI module.
* *
* Wrapper for Sim Verilog DPI module. * Wrapper for Sim Verilog DPI module.
*/ */
class VTASimDPI extends BlackBox with HasBlackBoxResource { class VTASimDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle { val io = IO(new Bundle {
val clock = Input(Clock()) val clock = Input(Clock())
......
...@@ -55,7 +55,7 @@ case class AXIParams( ...@@ -55,7 +55,7 @@ case class AXIParams(
} }
abstract class AXIBase(params: AXIParams) abstract class AXIBase(params: AXIParams)
extends GenericParameterizedBundle(params) extends GenericParameterizedBundle(params)
// AXILite // AXILite
......
...@@ -25,59 +25,56 @@ import vta.util.config._ ...@@ -25,59 +25,56 @@ import vta.util.config._
import vta.interface.axi._ import vta.interface.axi._
/** PynqConfig. Shell configuration for Pynq */ /** PynqConfig. Shell configuration for Pynq */
class PynqConfig class PynqConfig extends Config((site, here, up) => {
extends Config((site, here, up) => { case ShellKey =>
case ShellKey => ShellParams(
ShellParams( hostParams = AXIParams(coherent = false,
hostParams = AXIParams(coherent = false, addrBits = 16,
addrBits = 16, dataBits = 32,
dataBits = 32, lenBits = 8,
lenBits = 8, userBits = 1),
userBits = 1), memParams = AXIParams(coherent = true,
memParams = AXIParams(coherent = true, addrBits = 32,
addrBits = 32, dataBits = 64,
dataBits = 64, lenBits = 8,
lenBits = 8, userBits = 1),
userBits = 1), vcrParams = VCRParams(),
vcrParams = VCRParams(), vmeParams = VMEParams()
vmeParams = VMEParams() )
) })
})
/** F1Config. Shell configuration for F1 */ /** F1Config. Shell configuration for F1 */
class F1Config class F1Config extends Config((site, here, up) => {
extends Config((site, here, up) => { case ShellKey =>
case ShellKey => ShellParams(
ShellParams( hostParams = AXIParams(coherent = false,
hostParams = AXIParams(coherent = false, addrBits = 16,
addrBits = 16, dataBits = 32,
dataBits = 32, lenBits = 8,
lenBits = 8, userBits = 1),
userBits = 1), memParams = AXIParams(coherent = false,
memParams = AXIParams(coherent = false, addrBits = 64,
addrBits = 64, dataBits = 64,
dataBits = 64, lenBits = 8,
lenBits = 8, userBits = 1),
userBits = 1), vcrParams = VCRParams(),
vcrParams = VCRParams(), vmeParams = VMEParams()
vmeParams = VMEParams() )
) })
})
/** De10Config. Shell configuration for De10 */ /** De10Config. Shell configuration for De10 */
class De10Config class De10Config extends Config((site, here, up) => {
extends Config((site, here, up) => { case ShellKey =>
case ShellKey => ShellParams(
ShellParams( hostParams =
hostParams = AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4),
AXIParams(addrBits = 16, dataBits = 32, idBits = 13, lenBits = 4), memParams = AXIParams(
memParams = AXIParams( addrBits = 32,
addrBits = 32, dataBits = 64,
dataBits = 64, userBits = 5,
userBits = 5, lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4
lenBits = 4, // limit to 16 beats, instead of 256 beats in AXI4 coherent = true),
coherent = true), vcrParams = VCRParams(),
vcrParams = VCRParams(), vmeParams = VMEParams()
vmeParams = VMEParams() )
) })
})
...@@ -25,10 +25,10 @@ import vta.interface.axi._ ...@@ -25,10 +25,10 @@ import vta.interface.axi._
import vta.core._ import vta.core._
/** IntelShell. /** IntelShell.
* *
* The IntelShell is based on a VME, VCR and core. This creates a complete VTA * The IntelShell is based on a VME, VCR and core. This creates a complete VTA
* system that can be used for simulation or real hardware. * system that can be used for simulation or real hardware.
*/ */
class IntelShell(implicit p: Parameters) extends Module { class IntelShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val host = new AXIClient(p(ShellKey).hostParams) val host = new AXIClient(p(ShellKey).hostParams)
......
...@@ -27,11 +27,11 @@ import vta.shell._ ...@@ -27,11 +27,11 @@ import vta.shell._
import vta.dpi._ import vta.dpi._
/** VTAHost. /** VTAHost.
* *
* This module translate the DPI protocol into AXI. This is a simulation only * This module translate the DPI protocol into AXI. This is a simulation only
* module and used to test host-to-VTA communication. This module should be updated * module and used to test host-to-VTA communication. This module should be updated
* for testing hosts using a different bus protocol, other than AXI. * for testing hosts using a different bus protocol, other than AXI.
*/ */
class VTAHost(implicit p: Parameters) extends Module { class VTAHost(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val axi = new AXILiteMaster(p(ShellKey).hostParams) val axi = new AXILiteMaster(p(ShellKey).hostParams)
...@@ -45,11 +45,11 @@ class VTAHost(implicit p: Parameters) extends Module { ...@@ -45,11 +45,11 @@ class VTAHost(implicit p: Parameters) extends Module {
} }
/** VTAMem. /** VTAMem.
* *
* This module translate the DPI protocol into AXI. This is a simulation only * This module translate the DPI protocol into AXI. This is a simulation only
* module and used to test VTA-to-memory communication. This module should be updated * module and used to test VTA-to-memory communication. This module should be updated
* for testing memories using a different bus protocol, other than AXI. * for testing memories using a different bus protocol, other than AXI.
*/ */
class VTAMem(implicit p: Parameters) extends Module { class VTAMem(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val axi = new AXIClient(p(ShellKey).memParams) val axi = new AXIClient(p(ShellKey).memParams)
...@@ -63,12 +63,12 @@ class VTAMem(implicit p: Parameters) extends Module { ...@@ -63,12 +63,12 @@ class VTAMem(implicit p: Parameters) extends Module {
} }
/** VTASim. /** VTASim.
* *
* This module is used to handle hardware simulation thread, such as halting * This module is used to handle hardware simulation thread, such as halting
* or terminating the simulation thread. The sim_wait port is used to halt * or terminating the simulation thread. The sim_wait port is used to halt
* the simulation thread when it is asserted and resume it when it is * the simulation thread when it is asserted and resume it when it is
* de-asserted. * de-asserted.
*/ */
class VTASim(implicit p: Parameters) extends MultiIOModule { class VTASim(implicit p: Parameters) extends MultiIOModule {
val sim_wait = IO(Output(Bool())) val sim_wait = IO(Output(Bool()))
val sim = Module(new VTASimDPI) val sim = Module(new VTASimDPI)
...@@ -78,11 +78,11 @@ class VTASim(implicit p: Parameters) extends MultiIOModule { ...@@ -78,11 +78,11 @@ class VTASim(implicit p: Parameters) extends MultiIOModule {
} }
/** SimShell. /** SimShell.
* *
* The simulation shell instantiate the sim, host and memory DPI modules that * The simulation shell instantiate the sim, host and memory DPI modules that
* are connected to the VTAShell. An extra clock, sim_clock, is used to eval * are connected to the VTAShell. An extra clock, sim_clock, is used to eval
* the VTASim DPI function when the main simulation clock is on halt state. * the VTASim DPI function when the main simulation clock is on halt state.
*/ */
class SimShell(implicit p: Parameters) extends MultiIOModule { class SimShell(implicit p: Parameters) extends MultiIOModule {
val mem = IO(new AXIClient(p(ShellKey).memParams)) val mem = IO(new AXIClient(p(ShellKey).memParams))
val host = IO(new AXILiteMaster(p(ShellKey).hostParams)) val host = IO(new AXILiteMaster(p(ShellKey).hostParams))
......
...@@ -26,9 +26,9 @@ import vta.util.genericbundle._ ...@@ -26,9 +26,9 @@ import vta.util.genericbundle._
import vta.interface.axi._ import vta.interface.axi._
/** VCR parameters. /** VCR parameters.
* *
* These parameters are used on VCR interfaces and modules. * These parameters are used on VCR interfaces and modules.
*/ */
case class VCRParams() { case class VCRParams() {
val nCtrl = 1 val nCtrl = 1
val nECnt = 1 val nECnt = 1
...@@ -38,14 +38,13 @@ case class VCRParams() { ...@@ -38,14 +38,13 @@ case class VCRParams() {
} }
/** VCRBase. Parametrize base class. */ /** VCRBase. Parametrize base class. */
abstract class VCRBase(implicit p: Parameters) abstract class VCRBase(implicit p: Parameters) extends GenericParameterizedBundle(p)
extends GenericParameterizedBundle(p)
/** VCRMaster. /** VCRMaster.
* *
* This is the master interface used by VCR in the VTAShell to control * This is the master interface used by VCR in the VTAShell to control
* the Core unit. * the Core unit.
*/ */
class VCRMaster(implicit p: Parameters) extends VCRBase { class VCRMaster(implicit p: Parameters) extends VCRBase {
val vp = p(ShellKey).vcrParams val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
...@@ -57,10 +56,10 @@ class VCRMaster(implicit p: Parameters) extends VCRBase { ...@@ -57,10 +56,10 @@ class VCRMaster(implicit p: Parameters) extends VCRBase {
} }
/** VCRClient. /** VCRClient.
* *
* This is the client interface used by the Core module to communicate * This is the client interface used by the Core module to communicate
* to the VCR in the VTAShell. * to the VCR in the VTAShell.
*/ */
class VCRClient(implicit p: Parameters) extends VCRBase { class VCRClient(implicit p: Parameters) extends VCRBase {
val vp = p(ShellKey).vcrParams val vp = p(ShellKey).vcrParams
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
...@@ -72,12 +71,12 @@ class VCRClient(implicit p: Parameters) extends VCRBase { ...@@ -72,12 +71,12 @@ class VCRClient(implicit p: Parameters) extends VCRBase {
} }
/** VTA Control Registers (VCR). /** VTA Control Registers (VCR).
* *
* This unit provides control registers (32 and 64 bits) to be used by a control' * This unit provides control registers (32 and 64 bits) to be used by a control'
* unit, typically a host processor. These registers are read-only by the core * unit, typically a host processor. These registers are read-only by the core
* at the moment but this will likely change once we add support to general purpose * at the moment but this will likely change once we add support to general purpose
* registers that could be used as event counters by the Core unit. * registers that could be used as event counters by the Core unit.
*/ */
class VCR(implicit p: Parameters) extends Module { class VCR(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams) val host = new AXILiteClient(p(ShellKey).hostParams)
......
...@@ -26,27 +26,26 @@ import vta.util.genericbundle._ ...@@ -26,27 +26,26 @@ import vta.util.genericbundle._
import vta.interface.axi._ import vta.interface.axi._
/** VME parameters. /** VME parameters.
* *
* These parameters are used on VME interfaces and modules. * These parameters are used on VME interfaces and modules.
*/ */
case class VMEParams() { case class VMEParams() {
val nReadClients: Int = 5 val nReadClients: Int = 5
val nWriteClients: Int = 1 val nWriteClients: Int = 1
require(nReadClients > 0, require(nReadClients > 0,
s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n") s"\n\n[VTA] [VMEParams] nReadClients must be larger than 0\n\n")
require( require(
nWriteClients == 1, nWriteClients == 1,
s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n") s"\n\n[VTA] [VMEParams] nWriteClients must be 1, only one-write-client support atm\n\n")
} }
/** VMEBase. Parametrize base class. */ /** VMEBase. Parametrize base class. */
abstract class VMEBase(implicit p: Parameters) abstract class VMEBase(implicit p: Parameters) extends GenericParameterizedBundle(p)
extends GenericParameterizedBundle(p)
/** VMECmd. /** VMECmd.
* *
* This interface is used for creating write and read requests to memory. * This interface is used for creating write and read requests to memory.
*/ */
class VMECmd(implicit p: Parameters) extends VMEBase { class VMECmd(implicit p: Parameters) extends VMEBase {
val addrBits = p(ShellKey).memParams.addrBits val addrBits = p(ShellKey).memParams.addrBits
val lenBits = p(ShellKey).memParams.lenBits val lenBits = p(ShellKey).memParams.lenBits
...@@ -55,10 +54,10 @@ class VMECmd(implicit p: Parameters) extends VMEBase { ...@@ -55,10 +54,10 @@ class VMECmd(implicit p: Parameters) extends VMEBase {
} }
/** VMEReadMaster. /** VMEReadMaster.
* *
* This interface is used by modules inside the core to generate read requests * This interface is used by modules inside the core to generate read requests
* and receive responses from VME. * and receive responses from VME.
*/ */
class VMEReadMaster(implicit p: Parameters) extends Bundle { class VMEReadMaster(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits val dataBits = p(ShellKey).memParams.dataBits
val cmd = Decoupled(new VMECmd) val cmd = Decoupled(new VMECmd)
...@@ -68,10 +67,10 @@ class VMEReadMaster(implicit p: Parameters) extends Bundle { ...@@ -68,10 +67,10 @@ class VMEReadMaster(implicit p: Parameters) extends Bundle {
} }
/** VMEReadClient. /** VMEReadClient.
* *
* This interface is used by the VME to receive read requests and generate * This interface is used by the VME to receive read requests and generate
* responses to modules inside the core. * responses to modules inside the core.
*/ */
class VMEReadClient(implicit p: Parameters) extends Bundle { class VMEReadClient(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits val dataBits = p(ShellKey).memParams.dataBits
val cmd = Flipped(Decoupled(new VMECmd)) val cmd = Flipped(Decoupled(new VMECmd))
...@@ -81,10 +80,10 @@ class VMEReadClient(implicit p: Parameters) extends Bundle { ...@@ -81,10 +80,10 @@ class VMEReadClient(implicit p: Parameters) extends Bundle {
} }
/** VMEWriteMaster. /** VMEWriteMaster.
* *
* This interface is used by modules inside the core to generate write requests * This interface is used by modules inside the core to generate write requests
* to the VME. * to the VME.
*/ */
class VMEWriteMaster(implicit p: Parameters) extends Bundle { class VMEWriteMaster(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits val dataBits = p(ShellKey).memParams.dataBits
val cmd = Decoupled(new VMECmd) val cmd = Decoupled(new VMECmd)
...@@ -95,10 +94,10 @@ class VMEWriteMaster(implicit p: Parameters) extends Bundle { ...@@ -95,10 +94,10 @@ class VMEWriteMaster(implicit p: Parameters) extends Bundle {
} }
/** VMEWriteClient. /** VMEWriteClient.
* *
* This interface is used by the VME to handle write requests from modules inside * This interface is used by the VME to handle write requests from modules inside
* the core. * the core.
*/ */
class VMEWriteClient(implicit p: Parameters) extends Bundle { class VMEWriteClient(implicit p: Parameters) extends Bundle {
val dataBits = p(ShellKey).memParams.dataBits val dataBits = p(ShellKey).memParams.dataBits
val cmd = Flipped(Decoupled(new VMECmd)) val cmd = Flipped(Decoupled(new VMECmd))
...@@ -109,10 +108,10 @@ class VMEWriteClient(implicit p: Parameters) extends Bundle { ...@@ -109,10 +108,10 @@ class VMEWriteClient(implicit p: Parameters) extends Bundle {
} }
/** VMEMaster. /** VMEMaster.
* *
* Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster * Pack nRd number of VMEReadMaster interfaces and nWr number of VMEWriteMaster
* interfaces. * interfaces.
*/ */
class VMEMaster(implicit p: Parameters) extends Bundle { class VMEMaster(implicit p: Parameters) extends Bundle {
val nRd = p(ShellKey).vmeParams.nReadClients val nRd = p(ShellKey).vmeParams.nReadClients
val nWr = p(ShellKey).vmeParams.nWriteClients val nWr = p(ShellKey).vmeParams.nWriteClients
...@@ -121,10 +120,10 @@ class VMEMaster(implicit p: Parameters) extends Bundle { ...@@ -121,10 +120,10 @@ class VMEMaster(implicit p: Parameters) extends Bundle {
} }
/** VMEClient. /** VMEClient.
* *
* Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient * Pack nRd number of VMEReadClient interfaces and nWr number of VMEWriteClient
* interfaces. * interfaces.
*/ */
class VMEClient(implicit p: Parameters) extends Bundle { class VMEClient(implicit p: Parameters) extends Bundle {
val nRd = p(ShellKey).vmeParams.nReadClients val nRd = p(ShellKey).vmeParams.nReadClients
val nWr = p(ShellKey).vmeParams.nWriteClients val nWr = p(ShellKey).vmeParams.nWriteClients
...@@ -133,10 +132,10 @@ class VMEClient(implicit p: Parameters) extends Bundle { ...@@ -133,10 +132,10 @@ class VMEClient(implicit p: Parameters) extends Bundle {
} }
/** VTA Memory Engine (VME). /** VTA Memory Engine (VME).
* *
* This unit multiplexes the memory controller interface for the Core. Currently, * This unit multiplexes the memory controller interface for the Core. Currently,
* it supports single-writer and multiple-reader mode and it is also based on AXI. * it supports single-writer and multiple-reader mode and it is also based on AXI.
*/ */
class VME(implicit p: Parameters) extends Module { class VME(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val mem = new AXIMaster(p(ShellKey).memParams) val mem = new AXIMaster(p(ShellKey).memParams)
......
...@@ -35,10 +35,10 @@ case class ShellParams( ...@@ -35,10 +35,10 @@ case class ShellParams(
case object ShellKey extends Field[ShellParams] case object ShellKey extends Field[ShellParams]
/** VTAShell. /** VTAShell.
* *
* The VTAShell is based on a VME, VCR and core. This creates a complete VTA * The VTAShell is based on a VME, VCR and core. This creates a complete VTA
* system that can be used for simulation or real hardware. * system that can be used for simulation or real hardware.
*/ */
class VTAShell(implicit p: Parameters) extends Module { class VTAShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle { val io = IO(new Bundle {
val host = new AXILiteClient(p(ShellKey).hostParams) val host = new AXILiteClient(p(ShellKey).hostParams)
......
...@@ -25,10 +25,10 @@ import vta.util.config._ ...@@ -25,10 +25,10 @@ import vta.util.config._
import vta.interface.axi._ import vta.interface.axi._
/** XilinxShell. /** XilinxShell.
* *
* This is a wrapper shell mostly used to match Xilinx convention naming, * This is a wrapper shell mostly used to match Xilinx convention naming,
* therefore we can pack VTA as an IP for IPI based flows. * therefore we can pack VTA as an IP for IPI based flows.
*/ */
class XilinxShell(implicit p: Parameters) extends RawModule { class XilinxShell(implicit p: Parameters) extends RawModule {
val hp = p(ShellKey).hostParams val hp = p(ShellKey).hostParams
......
...@@ -46,7 +46,7 @@ abstract class Parameters extends View { ...@@ -46,7 +46,7 @@ abstract class Parameters extends View {
new ChainParameters(this, x) new ChainParameters(this, x)
final def alter( final def alter(
f: (View, View, View) => PartialFunction[Any, Any]): Parameters = f: (View, View, View) => PartialFunction[Any, Any]): Parameters =
Parameters(f) ++ this Parameters(f) ++ this
final def alterPartial(f: PartialFunction[Any, Any]): Parameters = final def alterPartial(f: PartialFunction[Any, Any]): Parameters =
...@@ -56,8 +56,8 @@ abstract class Parameters extends View { ...@@ -56,8 +56,8 @@ abstract class Parameters extends View {
new MapParameters(m) ++ this new MapParameters(m) ++ this
protected[config] def chain[T](site: View, protected[config] def chain[T](site: View,
tail: View, tail: View,
pname: Field[T]): Option[T] pname: Field[T]): Option[T]
protected[config] def find[T](pname: Field[T], site: View) = protected[config] def find[T](pname: Field[T], site: View) =
chain(site, new TerminalView, pname) chain(site, new TerminalView, pname)
} }
......
...@@ -23,8 +23,8 @@ package vta.util.genericbundle ...@@ -23,8 +23,8 @@ package vta.util.genericbundle
import chisel3._ import chisel3._
abstract class GenericParameterizedBundle[+T <: Object](val params: T) abstract class GenericParameterizedBundle[+T <: Object]
extends Bundle { (val params: T) extends Bundle {
override def cloneType = { override def cloneType = {
try { try {
this.getClass.getConstructors.head this.getClass.getConstructors.head
......
...@@ -26,11 +26,11 @@ import vta.core._ ...@@ -26,11 +26,11 @@ import vta.core._
import vta.test._ import vta.test._
/** VTA. /** VTA.
* *
* This file contains all the configurations supported by VTA. * This file contains all the configurations supported by VTA.
* These configurations are built in a mix/match form based on core * These configurations are built in a mix/match form based on core
* and shell configurations. * and shell configurations.
*/ */
class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig) class DefaultPynqConfig extends Config(new CoreConfig ++ new PynqConfig)
class DefaultF1Config extends Config(new CoreConfig ++ new F1Config) class DefaultF1Config extends Config(new CoreConfig ++ new F1Config)
class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config) class DefaultDe10Config extends Config(new CoreConfig ++ new De10Config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment