Commit 7bf2ff23 by Luis Vega Committed by Thierry Moreau

[VTA] add support to event counters (#3347)

* add support to event counters in VTA

* fix comment

* fix event-counter interface parameter

* no longer needed

* add sim back

* add docs to event counters

* fix docs

* add more details about event counting

* make dpi-module docs more accurate
parent c3cc3bf5
......@@ -54,16 +54,13 @@ class Device {
private:
void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles_);
// set counter to zero
dpi_->WriteReg(0x04, 0);
dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(inp, false));
dpi_->WriteReg(0x14, get_half_addr(inp, true));
dpi_->WriteReg(0x18, get_half_addr(out, false));
dpi_->WriteReg(0x1c, get_half_addr(out, true));
// launch
dpi_->WriteReg(0x00, 0x1);
dpi_->WriteReg(0x00, 0x1); // launch
}
uint32_t WaitForCompletion() {
......
......@@ -64,6 +64,7 @@ class Core(implicit p: Parameters) extends Module {
val load = Module(new Load)
val compute = Module(new Compute)
val store = Module(new Store)
val ecounters = Module(new EventCounters)
// Read(rd) and write(wr) from/to memory (i.e. DRAM)
io.vme.rd(0) <> fetch.io.vme_rd
......@@ -103,6 +104,11 @@ class Core(implicit p: Parameters) extends Module {
store.io.out_baddr := io.vcr.ptrs(5)
store.io.out <> compute.io.out
// Event counters
ecounters.io.launch := io.vcr.launch
ecounters.io.finish := compute.io.finish
io.vcr.ecnt <> ecounters.io.ecnt
// Finish instruction is executed and asserts the VCR finish flag
val finish = RegNext(compute.io.finish)
io.vcr.finish := finish
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** EventCounters.
*
* This unit contains all the event counting logic. One common event tracked in
* hardware is the number of clock cycles taken to achieve certain task. We
* can count the total number of clock cycles spent in a VTA run by checking
* launch and finish signals.
*
* The event counter value is passed to the VCR module via the ecnt port, so
* they can be accessed by the host. The number of event counters (nECnt) is
* defined in the Shell VCR module as a parameter, see VCRParams.
*
* If one would like to add an event counter, then the value of nECnt must be
* changed in VCRParams together with the corresponding counting logic here.
*/
class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module {
val vp = p(ShellKey).vcrParams
val io = IO(new Bundle{
val launch = Input(Bool())
val finish = Input(Bool())
val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
})
val cycle_cnt = RegInit(0.U(vp.regBits.W))
when (io.launch && !io.finish) {
cycle_cnt := cycle_cnt + 1.U
} .otherwise {
cycle_cnt := 0.U
}
io.ecnt(0).valid := io.finish
io.ecnt(0).bits := cycle_cnt
}
......@@ -23,8 +23,6 @@ import chisel3._
import chisel3.util._
import vta.util.config._
import vta.util.genericbundle._
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.LinkedHashMap
import vta.interface.axi._
/** VCR parameters.
......@@ -33,14 +31,11 @@ import vta.interface.axi._
*/
case class VCRParams()
{
val nValsReg: Int = 1
val nPtrsReg: Int = 6
val regBits: Int = 32
val nCtrlReg: Int = 4
val ctrlBaseAddr: Int = 0
require (nValsReg > 0)
require (nPtrsReg > 0)
val nCtrl = 1
val nECnt = 1
val nVals = 1
val nPtrs = 6
val regBits = 32
}
/** VCRBase. Parametrize base class. */
......@@ -57,9 +52,9 @@ class VCRMaster(implicit p: Parameters) extends VCRBase {
val mp = p(ShellKey).memParams
val launch = Output(Bool())
val finish = Input(Bool())
val irq = Output(Bool())
val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W)))
val ecnt = Vec(vp.nECnt, Flipped(ValidIO(UInt(vp.regBits.W))))
val vals = Output(Vec(vp.nVals, UInt(vp.regBits.W)))
val ptrs = Output(Vec(vp.nPtrs, UInt(mp.addrBits.W)))
}
/** VCRClient.
......@@ -72,9 +67,9 @@ class VCRClient(implicit p: Parameters) extends VCRBase {
val mp = p(ShellKey).memParams
val launch = Input(Bool())
val finish = Output(Bool())
val irq = Input(Bool())
val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W)))
val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W)))
val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
val vals = Input(Vec(vp.nVals, UInt(vp.regBits.W)))
val ptrs = Input(Vec(vp.nPtrs, UInt(mp.addrBits.W)))
}
/** VTA Control Registers (VCR).
......@@ -97,10 +92,23 @@ class VCR(implicit p: Parameters) extends Module {
// Write control (AW, W, B)
val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address
val wdata = io.host.w.bits.data
val wstrb = io.host.w.bits.strb
val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0)))
val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
val wstate = RegInit(sWriteAddress)
// read control (AR, R)
val sReadAddress :: sReadData :: Nil = Enum(2)
val rstate = RegInit(sReadAddress)
val rdata = RegInit(0.U(vp.regBits.W))
// registers
val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + (2*vp.nPtrs)
val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = vp.nCtrl
val vo = eo + vp.nECnt
val po = vo + vp.nVals
switch (wstate) {
is (sWriteAddress) {
when (io.host.aw.valid) {
......@@ -124,11 +132,8 @@ class VCR(implicit p: Parameters) extends Module {
io.host.aw.ready := wstate === sWriteAddress
io.host.w.ready := wstate === sWriteData
io.host.b.valid := wstate === sWriteResponse
io.host.b.bits.resp := "h_0".U
io.host.b.bits.resp := 0.U
// read control (AR, R)
val sReadAddress :: sReadData :: Nil = Enum(2)
val rstate = RegInit(sReadAddress)
switch (rstate) {
is (sReadAddress) {
......@@ -145,98 +150,40 @@ class VCR(implicit p: Parameters) extends Module {
io.host.ar.ready := rstate === sReadAddress
io.host.r.valid := rstate === sReadData
io.host.r.bits.data := rdata
io.host.r.bits.resp := 0.U
val nPtrsReg = vp.nPtrsReg
val nValsReg = vp.nValsReg
val regBits = vp.regBits
val ptrsBits = mp.addrBits
val nCtrlReg = vp.nCtrlReg
val rStride = regBits/8
val pStride = ptrsBits/8
val ctrlBaseAddr = vp.ctrlBaseAddr
val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride
val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride
val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr)
val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr)
val ptrsAddr = new ListBuffer[Int]()
for (i <- 0 until nPtrsReg) {
ptrsAddr += i*pStride + ptrsBaseAddr
if (ptrsBits == 64) {
ptrsAddr += i*pStride + rStride + ptrsBaseAddr
}
}
// AP register
val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B)))
// ap start
when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) {
c0(0) := true.B
} .elsewhen (io.vcr.finish) {
c0(0) := false.B
}
// ap done = finish
when (io.vcr.finish) {
c0(1) := true.B
} .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) {
c0(1) := false.B
reg(0) := "b_10".U
} .elsewhen (io.host.w.fire() && addr(0).U === waddr) {
reg(0) := wdata
}
val c1 = 0.U
val c2 = 0.U
val c3 = 0.U
val ctrlRegList = List(c0, c1, c2, c3)
io.vcr.launch := c0(0)
// interrupts not supported atm
io.vcr.irq := false.B
// Write pointer and value registers
val pvAddr = valsAddr ++ ptrsAddr
val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg
val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W))))
val pvRegList = new ListBuffer[UInt]()
for (i <- 0 until pvNumReg) {
when (io.host.w.fire() && (waddr === pvAddr(i).U)) {
pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask)
for (i <- 0 until vp.nECnt) {
when (io.vcr.ecnt(i).valid) {
reg(eo + i) := io.vcr.ecnt(i).bits
} .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) {
reg(eo + i) := wdata
}
pvRegList += pvReg(i)
}
for (i <- 0 until nValsReg) {
io.vcr.vals(i) := pvReg(i)
}
for (i <- 0 until nPtrsReg) {
if (ptrsBits == 64) {
io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2))
} else {
io.vcr.ptrs(i) := pvReg(nValsReg + i)
for (i <- 0 until (vp.nVals + (2*vp.nPtrs))) {
when (io.host.w.fire() && addr(vo + i).U === waddr) {
reg(vo + i) := wdata
}
}
// Read pointer and value registers
val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr
val mapRegList = ctrlRegList ++ pvRegList
val rdata = RegInit(0.U(regBits.W))
val rmap = LinkedHashMap[Int,UInt]()
val totalReg = mapRegList.length
for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt }
when (io.host.ar.fire()) {
rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map)
}
val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) }
io.vcr.launch := reg(0)(0)
when (io.host.ar.fire()) {
rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v)
for (i <- 0 until vp.nVals) {
io.vcr.vals(i) := reg(vo + i)
}
io.host.r.bits.resp := 0.U
io.host.r.bits.data := rdata
for (i <- 0 until vp.nPtrs) {
io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
}
}
......@@ -35,7 +35,7 @@ namespace dpi {
class DPIModuleNode : public tvm::runtime::ModuleNode {
public:
/*!
* \brief Launch accelerator until it finishes or reach max_cycles
* \brief Launch hardware simulation until accelerator finishes or reach max_cycles
* \param max_cycles The maximum of cycles to wait
*/
virtual void Launch(uint64_t max_cycles) = 0;
......@@ -53,7 +53,7 @@ class DPIModuleNode : public tvm::runtime::ModuleNode {
*/
virtual uint32_t ReadReg(int addr) = 0;
/*! \brief Kill or Exit() the accelerator */
/*! \brief Finish hardware simulation */
virtual void Finish() = 0;
static tvm::runtime::Module Load(std::string dll_name);
......
......@@ -74,5 +74,14 @@ def tsim_init(hw_lib):
m = tvm.module.load(lib, "vta-tsim")
f(m)
def tsim_cycles():
"""Get tsim clock cycles
Returns
-------
stats : int
tsim clock cycles
"""
return tvm.get_global_func("tvm.vta.tsim.cycles")()
LIBS = _load_lib()
......@@ -28,6 +28,17 @@ namespace tsim {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class Profiler {
public:
/*! \brief cycle counter */
uint64_t cycle_count{0};
static Profiler* Global() {
static Profiler inst;
return &inst;
}
};
class DPILoader {
public:
void Init(Module module) {
......@@ -50,6 +61,7 @@ class Device {
public:
Device() {
dpi_ = DPILoader::Global();
prof_ = Profiler::Global();
}
int Run(vta_phy_addr_t insn_phy_addr,
......@@ -89,19 +101,21 @@ class Device {
uint32_t wait_cycles) {
// launch simulation thread
dev_->Launch(wait_cycles);
dev_->WriteReg(0x10, insn_count);
dev_->WriteReg(0x14, insn_phy_addr);
dev_->WriteReg(0x18, insn_phy_addr >> 32);
// set counter to zero
dev_->WriteReg(0x04, 0);
dev_->WriteReg(0x08, insn_count);
dev_->WriteReg(0x0c, insn_phy_addr);
dev_->WriteReg(0x10, insn_phy_addr >> 32);
dev_->WriteReg(0x14, 0);
dev_->WriteReg(0x18, uop_phy_addr >> 32);
dev_->WriteReg(0x1c, 0);
dev_->WriteReg(0x20, uop_phy_addr >> 32);
dev_->WriteReg(0x20, inp_phy_addr >> 32);
dev_->WriteReg(0x24, 0);
dev_->WriteReg(0x28, inp_phy_addr >> 32);
dev_->WriteReg(0x28, wgt_phy_addr >> 32);
dev_->WriteReg(0x2c, 0);
dev_->WriteReg(0x30, wgt_phy_addr >> 32);
dev_->WriteReg(0x30, acc_phy_addr >> 32);
dev_->WriteReg(0x34, 0);
dev_->WriteReg(0x38, acc_phy_addr >> 32);
dev_->WriteReg(0x3c, 0);
dev_->WriteReg(0x40, out_phy_addr >> 32);
dev_->WriteReg(0x38, out_phy_addr >> 32);
// start
dev_->WriteReg(0x00, 0x1);
}
......@@ -113,9 +127,14 @@ class Device {
val &= 0x2;
if (val == 0x2) break; // finish
}
prof_->cycle_count = dev_->ReadReg(0x04);
}
// Profiler
Profiler* prof_;
// DPI loader
DPILoader* dpi_;
// DPI Module
DPIModuleNode* dev_;
};
......@@ -128,6 +147,11 @@ TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.tsim.cycles")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int>(Profiler::Global()->cycle_count);
});
} // namespace tsim
} // namespace vta
......
......@@ -73,8 +73,12 @@ def test_save_load_out():
simulator.tsim_init("libvta_hw")
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("Load/store test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run)
......@@ -135,8 +139,12 @@ def test_padded_load():
simulator.tsim_init("libvta_hw")
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("Padded load test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run)
......@@ -180,7 +188,7 @@ def test_gemm():
if not remote:
return
def verify(s):
def verify(s, name=None):
mod = vta.build(s, [x, w, y], "ext_dev", env.target_host)
temp = util.tempdir()
mod.save(temp.relpath("gemm.o"))
......@@ -217,6 +225,9 @@ def test_gemm():
np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("GEMM schedule:{} test took {} clock cycles".format(name, simulator.tsim_cycles()))
def test_schedule1():
# default schedule with no smt
s = tvm.create_schedule(y.op)
......@@ -245,7 +256,7 @@ def test_gemm():
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
verify(s)
verify(s, name="default")
def test_smt():
# test smt schedule
......@@ -279,7 +290,7 @@ def test_gemm():
s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y].pragma(abo2, env.dma_copy)
verify(s)
verify(s, name="smt")
test_schedule1()
test_smt()
......@@ -288,7 +299,7 @@ def test_gemm():
def test_alu():
def _run(env, remote):
def check_alu(tvm_op, np_op=None, use_imm=False):
def check_alu(tvm_op, np_op=None, use_imm=False, test_name=None):
"""Test ALU"""
m = 8
n = 8
......@@ -371,14 +382,18 @@ def test_alu():
else:
b_nd = tvm.nd.array(b_np, ctx)
f(a_nd, b_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
check_alu(tvm.max, np.maximum, use_imm=True)
check_alu(tvm.max, np.maximum)
check_alu(lambda x, y: x + y, use_imm=True)
check_alu(lambda x, y: x + y)
check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
if env.TARGET == "tsim":
print("ALU {} imm:{} test took {} clock cycles".format(test_name, use_imm, simulator.tsim_cycles()))
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL")
check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX")
check_alu(tvm.max, np.maximum, test_name="MAX")
check_alu(lambda x, y: x + y, use_imm=True, test_name="ADD")
check_alu(lambda x, y: x + y, test_name="ADD")
check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True, test_name="SHR")
vta.testing.run(_run)
......@@ -440,8 +455,12 @@ def test_relu():
simulator.tsim_init("libvta_hw")
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("Relu test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run)
......@@ -503,8 +522,12 @@ def test_shift_and_scale():
simulator.tsim_init("libvta_hw")
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("Shift/scale test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run)
......@@ -521,17 +544,10 @@ def test_runtime_array():
if __name__ == "__main__":
print("Array test")
test_runtime_array()
print("Load/store test")
test_save_load_out()
print("Padded load test")
test_padded_load()
print("GEMM test")
test_gemm()
print("ALU test")
test_alu()
print("Relu test")
test_relu()
print("Shift and scale")
test_shift_and_scale()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment