Commit 7bf2ff23 by Luis Vega Committed by Thierry Moreau

[VTA] add support to event counters (#3347)

* add support to event counters in VTA

* fix comment

* fix event-counter interface parameter

* no longer needed

* add sim back

* add docs to event counters

* fix docs

* add more details about event counting

* make dpi-module docs more accurate
parent c3cc3bf5
...@@ -54,16 +54,13 @@ class Device { ...@@ -54,16 +54,13 @@ class Device {
private: private:
void Launch(uint32_t c, uint32_t length, void* inp, void* out) { void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles_); dpi_->Launch(wait_cycles_);
// set counter to zero
dpi_->WriteReg(0x04, 0);
dpi_->WriteReg(0x08, c); dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, length); dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(inp, false)); dpi_->WriteReg(0x10, get_half_addr(inp, false));
dpi_->WriteReg(0x14, get_half_addr(inp, true)); dpi_->WriteReg(0x14, get_half_addr(inp, true));
dpi_->WriteReg(0x18, get_half_addr(out, false)); dpi_->WriteReg(0x18, get_half_addr(out, false));
dpi_->WriteReg(0x1c, get_half_addr(out, true)); dpi_->WriteReg(0x1c, get_half_addr(out, true));
// launch dpi_->WriteReg(0x00, 0x1); // launch
dpi_->WriteReg(0x00, 0x1);
} }
uint32_t WaitForCompletion() { uint32_t WaitForCompletion() {
......
...@@ -64,6 +64,7 @@ class Core(implicit p: Parameters) extends Module { ...@@ -64,6 +64,7 @@ class Core(implicit p: Parameters) extends Module {
val load = Module(new Load) val load = Module(new Load)
val compute = Module(new Compute) val compute = Module(new Compute)
val store = Module(new Store) val store = Module(new Store)
val ecounters = Module(new EventCounters)
// Read(rd) and write(wr) from/to memory (i.e. DRAM) // Read(rd) and write(wr) from/to memory (i.e. DRAM)
io.vme.rd(0) <> fetch.io.vme_rd io.vme.rd(0) <> fetch.io.vme_rd
...@@ -103,6 +104,11 @@ class Core(implicit p: Parameters) extends Module { ...@@ -103,6 +104,11 @@ class Core(implicit p: Parameters) extends Module {
store.io.out_baddr := io.vcr.ptrs(5) store.io.out_baddr := io.vcr.ptrs(5)
store.io.out <> compute.io.out store.io.out <> compute.io.out
// Event counters
ecounters.io.launch := io.vcr.launch
ecounters.io.finish := compute.io.finish
io.vcr.ecnt <> ecounters.io.ecnt
// Finish instruction is executed and asserts the VCR finish flag // Finish instruction is executed and asserts the VCR finish flag
val finish = RegNext(compute.io.finish) val finish = RegNext(compute.io.finish)
io.vcr.finish := finish io.vcr.finish := finish
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.core
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._
/** EventCounters.
  *
  * Holds the event-counting logic. The single event tracked here is the number
  * of clock cycles in a run: the counter advances on every cycle in which
  * launch is asserted and finish is not, and is cleared on every other cycle.
  *
  * Counters are exposed through the ecnt port as valid/bits pairs. Entry 0
  * carries the cycle count and is flagged valid while finish is asserted. The
  * number of ecnt entries (nECnt) and their width (regBits) come from the VCR
  * parameters stored under ShellKey.
  *
  * Only ecnt(0) is driven in this module, so raising nECnt in the VCR
  * parameters requires adding the matching counting logic here as well.
  *
  * @param debug not referenced by the logic in this module
  */
class EventCounters(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val vp = p(ShellKey).vcrParams
  val io = IO(new Bundle {
    val launch = Input(Bool())
    val finish = Input(Bool())
    val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
  })

  // Cycle counter: increments while launched and not finished, zero otherwise.
  val cycle_cnt = RegInit(0.U(vp.regBits.W))
  cycle_cnt := Mux(io.launch && !io.finish, cycle_cnt + 1.U, 0.U)

  // Publish the current register value; it is marked valid only while finish
  // is asserted.
  io.ecnt(0).bits := cycle_cnt
  io.ecnt(0).valid := io.finish
}
...@@ -23,8 +23,6 @@ import chisel3._ ...@@ -23,8 +23,6 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import vta.util.config._ import vta.util.config._
import vta.util.genericbundle._ import vta.util.genericbundle._
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.LinkedHashMap
import vta.interface.axi._ import vta.interface.axi._
/** VCR parameters. /** VCR parameters.
...@@ -33,14 +31,11 @@ import vta.interface.axi._ ...@@ -33,14 +31,11 @@ import vta.interface.axi._
*/ */
case class VCRParams() case class VCRParams()
{ {
val nValsReg: Int = 1 val nCtrl = 1
val nPtrsReg: Int = 6 val nECnt = 1
val regBits: Int = 32 val nVals = 1
val nCtrlReg: Int = 4 val nPtrs = 6
val ctrlBaseAddr: Int = 0 val regBits = 32
require (nValsReg > 0)
require (nPtrsReg > 0)
} }
/** VCRBase. Parametrize base class. */ /** VCRBase. Parametrize base class. */
...@@ -57,9 +52,9 @@ class VCRMaster(implicit p: Parameters) extends VCRBase { ...@@ -57,9 +52,9 @@ class VCRMaster(implicit p: Parameters) extends VCRBase {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val launch = Output(Bool()) val launch = Output(Bool())
val finish = Input(Bool()) val finish = Input(Bool())
val irq = Output(Bool()) val ecnt = Vec(vp.nECnt, Flipped(ValidIO(UInt(vp.regBits.W))))
val ptrs = Output(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) val vals = Output(Vec(vp.nVals, UInt(vp.regBits.W)))
val vals = Output(Vec(vp.nValsReg, UInt(vp.regBits.W))) val ptrs = Output(Vec(vp.nPtrs, UInt(mp.addrBits.W)))
} }
/** VCRClient. /** VCRClient.
...@@ -72,9 +67,9 @@ class VCRClient(implicit p: Parameters) extends VCRBase { ...@@ -72,9 +67,9 @@ class VCRClient(implicit p: Parameters) extends VCRBase {
val mp = p(ShellKey).memParams val mp = p(ShellKey).memParams
val launch = Input(Bool()) val launch = Input(Bool())
val finish = Output(Bool()) val finish = Output(Bool())
val irq = Input(Bool()) val ecnt = Vec(vp.nECnt, ValidIO(UInt(vp.regBits.W)))
val ptrs = Input(Vec(vp.nPtrsReg, UInt(mp.addrBits.W))) val vals = Input(Vec(vp.nVals, UInt(vp.regBits.W)))
val vals = Input(Vec(vp.nValsReg, UInt(vp.regBits.W))) val ptrs = Input(Vec(vp.nPtrs, UInt(mp.addrBits.W)))
} }
/** VTA Control Registers (VCR). /** VTA Control Registers (VCR).
...@@ -97,10 +92,23 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -97,10 +92,23 @@ class VCR(implicit p: Parameters) extends Module {
// Write control (AW, W, B) // Write control (AW, W, B)
val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address val waddr = RegInit("h_ffff".U(hp.addrBits.W)) // init with invalid address
val wdata = io.host.w.bits.data val wdata = io.host.w.bits.data
val wstrb = io.host.w.bits.strb
val wmask = Cat(Fill(8, wstrb(3)), Fill(8, wstrb(2)), Fill(8, wstrb(1)), Fill(8, wstrb(0)))
val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3) val sWriteAddress :: sWriteData :: sWriteResponse :: Nil = Enum(3)
val wstate = RegInit(sWriteAddress) val wstate = RegInit(sWriteAddress)
// read control (AR, R)
val sReadAddress :: sReadData :: Nil = Enum(2)
val rstate = RegInit(sReadAddress)
val rdata = RegInit(0.U(vp.regBits.W))
// registers
val nTotal = vp.nCtrl + vp.nECnt + vp.nVals + (2*vp.nPtrs)
val reg = Seq.fill(nTotal)(RegInit(0.U(vp.regBits.W)))
val addr = Seq.tabulate(nTotal)(_ * 4)
val reg_map = (addr zip reg) map { case (a, r) => a.U -> r }
val eo = vp.nCtrl
val vo = eo + vp.nECnt
val po = vo + vp.nVals
switch (wstate) { switch (wstate) {
is (sWriteAddress) { is (sWriteAddress) {
when (io.host.aw.valid) { when (io.host.aw.valid) {
...@@ -124,11 +132,8 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -124,11 +132,8 @@ class VCR(implicit p: Parameters) extends Module {
io.host.aw.ready := wstate === sWriteAddress io.host.aw.ready := wstate === sWriteAddress
io.host.w.ready := wstate === sWriteData io.host.w.ready := wstate === sWriteData
io.host.b.valid := wstate === sWriteResponse io.host.b.valid := wstate === sWriteResponse
io.host.b.bits.resp := "h_0".U io.host.b.bits.resp := 0.U
// read control (AR, R)
val sReadAddress :: sReadData :: Nil = Enum(2)
val rstate = RegInit(sReadAddress)
switch (rstate) { switch (rstate) {
is (sReadAddress) { is (sReadAddress) {
...@@ -145,98 +150,40 @@ class VCR(implicit p: Parameters) extends Module { ...@@ -145,98 +150,40 @@ class VCR(implicit p: Parameters) extends Module {
io.host.ar.ready := rstate === sReadAddress io.host.ar.ready := rstate === sReadAddress
io.host.r.valid := rstate === sReadData io.host.r.valid := rstate === sReadData
io.host.r.bits.data := rdata
io.host.r.bits.resp := 0.U
val nPtrsReg = vp.nPtrsReg
val nValsReg = vp.nValsReg
val regBits = vp.regBits
val ptrsBits = mp.addrBits
val nCtrlReg = vp.nCtrlReg
val rStride = regBits/8
val pStride = ptrsBits/8
val ctrlBaseAddr = vp.ctrlBaseAddr
val valsBaseAddr = ctrlBaseAddr + nCtrlReg*rStride
val ptrsBaseAddr = valsBaseAddr + nValsReg*rStride
val ctrlAddr = Seq.tabulate(nCtrlReg)(i => i*rStride + ctrlBaseAddr)
val valsAddr = Seq.tabulate(nValsReg)(i => i*rStride + valsBaseAddr)
val ptrsAddr = new ListBuffer[Int]()
for (i <- 0 until nPtrsReg) {
ptrsAddr += i*pStride + ptrsBaseAddr
if (ptrsBits == 64) {
ptrsAddr += i*pStride + rStride + ptrsBaseAddr
}
}
// AP register
val c0 = RegInit(VecInit(Seq.fill(regBits)(false.B)))
// ap start
when (io.host.w.fire() && waddr === ctrlAddr(0).asUInt && wstrb(0) && wdata(0)) {
c0(0) := true.B
} .elsewhen (io.vcr.finish) {
c0(0) := false.B
}
// ap done = finish
when (io.vcr.finish) { when (io.vcr.finish) {
c0(1) := true.B reg(0) := "b_10".U
} .elsewhen (io.host.ar.fire() && io.host.ar.bits.addr === ctrlAddr(0).asUInt) { } .elsewhen (io.host.w.fire() && addr(0).U === waddr) {
c0(1) := false.B reg(0) := wdata
} }
val c1 = 0.U for (i <- 0 until vp.nECnt) {
val c2 = 0.U when (io.vcr.ecnt(i).valid) {
val c3 = 0.U reg(eo + i) := io.vcr.ecnt(i).bits
} .elsewhen (io.host.w.fire() && addr(eo + i).U === waddr) {
val ctrlRegList = List(c0, c1, c2, c3) reg(eo + i) := wdata
io.vcr.launch := c0(0)
// interrupts not supported atm
io.vcr.irq := false.B
// Write pointer and value registers
val pvAddr = valsAddr ++ ptrsAddr
val pvNumReg = if (ptrsBits == 64) nValsReg + nPtrsReg*2 else nValsReg + nPtrsReg
val pvReg = RegInit(VecInit(Seq.fill(pvNumReg)(0.U(regBits.W))))
val pvRegList = new ListBuffer[UInt]()
for (i <- 0 until pvNumReg) {
when (io.host.w.fire() && (waddr === pvAddr(i).U)) {
pvReg(i) := (wdata & wmask) | (pvReg(i) & ~wmask)
} }
pvRegList += pvReg(i)
}
for (i <- 0 until nValsReg) {
io.vcr.vals(i) := pvReg(i)
} }
for (i <- 0 until nPtrsReg) { for (i <- 0 until (vp.nVals + (2*vp.nPtrs))) {
if (ptrsBits == 64) { when (io.host.w.fire() && addr(vo + i).U === waddr) {
io.vcr.ptrs(i) := Cat(pvReg(nValsReg + i*2 + 1), pvReg(nValsReg + i*2)) reg(vo + i) := wdata
} else {
io.vcr.ptrs(i) := pvReg(nValsReg + i)
} }
} }
// Read pointer and value registers when (io.host.ar.fire()) {
val mapAddr = ctrlAddr ++ valsAddr ++ ptrsAddr rdata := MuxLookup(io.host.ar.bits.addr, 0.U, reg_map)
val mapRegList = ctrlRegList ++ pvRegList }
val rdata = RegInit(0.U(regBits.W))
val rmap = LinkedHashMap[Int,UInt]()
val totalReg = mapRegList.length
for (i <- 0 until totalReg) { rmap += mapAddr(i) -> mapRegList(i).asUInt }
val decodeAddr = rmap map { case (k, _) => k -> (io.host.ar.bits.addr === k.asUInt) } io.vcr.launch := reg(0)(0)
when (io.host.ar.fire()) { for (i <- 0 until vp.nVals) {
rdata := Mux1H(for ((k, v) <- rmap) yield decodeAddr(k) -> v) io.vcr.vals(i) := reg(vo + i)
} }
io.host.r.bits.resp := 0.U for (i <- 0 until vp.nPtrs) {
io.host.r.bits.data := rdata io.vcr.ptrs(i) := Cat(reg(po + 2*i + 1), reg(po + 2*i))
}
} }
...@@ -35,7 +35,7 @@ namespace dpi { ...@@ -35,7 +35,7 @@ namespace dpi {
class DPIModuleNode : public tvm::runtime::ModuleNode { class DPIModuleNode : public tvm::runtime::ModuleNode {
public: public:
/*! /*!
* \brief Launch accelerator until it finishes or reach max_cycles * \brief Launch hardware simulation until accelerator finishes or reach max_cycles
* \param max_cycles The maximum of cycles to wait * \param max_cycles The maximum of cycles to wait
*/ */
virtual void Launch(uint64_t max_cycles) = 0; virtual void Launch(uint64_t max_cycles) = 0;
...@@ -53,7 +53,7 @@ class DPIModuleNode : public tvm::runtime::ModuleNode { ...@@ -53,7 +53,7 @@ class DPIModuleNode : public tvm::runtime::ModuleNode {
*/ */
virtual uint32_t ReadReg(int addr) = 0; virtual uint32_t ReadReg(int addr) = 0;
/*! \brief Kill or Exit() the accelerator */ /*! \brief Finish hardware simulation */
virtual void Finish() = 0; virtual void Finish() = 0;
static tvm::runtime::Module Load(std::string dll_name); static tvm::runtime::Module Load(std::string dll_name);
......
...@@ -74,5 +74,14 @@ def tsim_init(hw_lib): ...@@ -74,5 +74,14 @@ def tsim_init(hw_lib):
m = tvm.module.load(lib, "vta-tsim") m = tvm.module.load(lib, "vta-tsim")
f(m) f(m)
def tsim_cycles():
    """Return the clock-cycle count reported by tsim.

    The value is obtained by calling the global function registered
    under the name "tvm.vta.tsim.cycles".

    Returns
    -------
    stats : int
        tsim clock cycles
    """
    cycles_func = tvm.get_global_func("tvm.vta.tsim.cycles")
    return cycles_func()
LIBS = _load_lib() LIBS = _load_lib()
...@@ -28,6 +28,17 @@ namespace tsim { ...@@ -28,6 +28,17 @@ namespace tsim {
using vta::dpi::DPIModuleNode; using vta::dpi::DPIModuleNode;
using tvm::runtime::Module; using tvm::runtime::Module;
/*! \brief Holds tsim profiling data; Global() hands out one shared instance. */
class Profiler {
 public:
  /*!
   * \brief Access the shared profiler.
   * \return Pointer to a function-local static instance, so every call
   *         yields the same object.
   */
  static Profiler* Global() {
    static Profiler instance;
    return &instance;
  }

  /*! \brief cycle counter */
  uint64_t cycle_count = 0;
};
class DPILoader { class DPILoader {
public: public:
void Init(Module module) { void Init(Module module) {
...@@ -50,6 +61,7 @@ class Device { ...@@ -50,6 +61,7 @@ class Device {
public: public:
Device() { Device() {
dpi_ = DPILoader::Global(); dpi_ = DPILoader::Global();
prof_ = Profiler::Global();
} }
int Run(vta_phy_addr_t insn_phy_addr, int Run(vta_phy_addr_t insn_phy_addr,
...@@ -89,19 +101,21 @@ class Device { ...@@ -89,19 +101,21 @@ class Device {
uint32_t wait_cycles) { uint32_t wait_cycles) {
// launch simulation thread // launch simulation thread
dev_->Launch(wait_cycles); dev_->Launch(wait_cycles);
dev_->WriteReg(0x10, insn_count); // set counter to zero
dev_->WriteReg(0x14, insn_phy_addr); dev_->WriteReg(0x04, 0);
dev_->WriteReg(0x18, insn_phy_addr >> 32); dev_->WriteReg(0x08, insn_count);
dev_->WriteReg(0x0c, insn_phy_addr);
dev_->WriteReg(0x10, insn_phy_addr >> 32);
dev_->WriteReg(0x14, 0);
dev_->WriteReg(0x18, uop_phy_addr >> 32);
dev_->WriteReg(0x1c, 0); dev_->WriteReg(0x1c, 0);
dev_->WriteReg(0x20, uop_phy_addr >> 32); dev_->WriteReg(0x20, inp_phy_addr >> 32);
dev_->WriteReg(0x24, 0); dev_->WriteReg(0x24, 0);
dev_->WriteReg(0x28, inp_phy_addr >> 32); dev_->WriteReg(0x28, wgt_phy_addr >> 32);
dev_->WriteReg(0x2c, 0); dev_->WriteReg(0x2c, 0);
dev_->WriteReg(0x30, wgt_phy_addr >> 32); dev_->WriteReg(0x30, acc_phy_addr >> 32);
dev_->WriteReg(0x34, 0); dev_->WriteReg(0x34, 0);
dev_->WriteReg(0x38, acc_phy_addr >> 32); dev_->WriteReg(0x38, out_phy_addr >> 32);
dev_->WriteReg(0x3c, 0);
dev_->WriteReg(0x40, out_phy_addr >> 32);
// start // start
dev_->WriteReg(0x00, 0x1); dev_->WriteReg(0x00, 0x1);
} }
...@@ -113,9 +127,14 @@ class Device { ...@@ -113,9 +127,14 @@ class Device {
val &= 0x2; val &= 0x2;
if (val == 0x2) break; // finish if (val == 0x2) break; // finish
} }
prof_->cycle_count = dev_->ReadReg(0x04);
} }
// Profiler
Profiler* prof_;
// DPI loader
DPILoader* dpi_; DPILoader* dpi_;
// DPI Module
DPIModuleNode* dev_; DPIModuleNode* dev_;
}; };
...@@ -128,6 +147,11 @@ TVM_REGISTER_GLOBAL("tvm.vta.tsim.init") ...@@ -128,6 +147,11 @@ TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
DPILoader::Global()->Init(m); DPILoader::Global()->Init(m);
}); });
// Expose the profiler's cycle count to the frontend under the global name
// "tvm.vta.tsim.cycles". The counter is a uint64_t, so it is returned as
// int64_t: narrowing to int cannot represent counts above INT_MAX and would
// misreport long-running simulations.
TVM_REGISTER_GLOBAL("tvm.vta.tsim.cycles")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = static_cast<int64_t>(Profiler::Global()->cycle_count);
  });
} // namespace tsim } // namespace tsim
} // namespace vta } // namespace vta
......
...@@ -73,8 +73,12 @@ def test_save_load_out(): ...@@ -73,8 +73,12 @@ def test_save_load_out():
simulator.tsim_init("libvta_hw") simulator.tsim_init("libvta_hw")
f(x_nd, y_nd) f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("Load/store test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -135,8 +139,12 @@ def test_padded_load(): ...@@ -135,8 +139,12 @@ def test_padded_load():
simulator.tsim_init("libvta_hw") simulator.tsim_init("libvta_hw")
f(x_nd, y_nd) f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("Padded load test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -180,7 +188,7 @@ def test_gemm(): ...@@ -180,7 +188,7 @@ def test_gemm():
if not remote: if not remote:
return return
def verify(s): def verify(s, name=None):
mod = vta.build(s, [x, w, y], "ext_dev", env.target_host) mod = vta.build(s, [x, w, y], "ext_dev", env.target_host)
temp = util.tempdir() temp = util.tempdir()
mod.save(temp.relpath("gemm.o")) mod.save(temp.relpath("gemm.o"))
...@@ -217,6 +225,9 @@ def test_gemm(): ...@@ -217,6 +225,9 @@ def test_gemm():
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("GEMM schedule:{} test took {} clock cycles".format(name, simulator.tsim_cycles()))
def test_schedule1(): def test_schedule1():
# default schedule with no smt # default schedule with no smt
s = tvm.create_schedule(y.op) s = tvm.create_schedule(y.op)
...@@ -245,7 +256,7 @@ def test_gemm(): ...@@ -245,7 +256,7 @@ def test_gemm():
s[y_gem].op.axis[3], s[y_gem].op.axis[3],
ki) ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm) s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
verify(s) verify(s, name="default")
def test_smt(): def test_smt():
# test smt schedule # test smt schedule
...@@ -279,7 +290,7 @@ def test_gemm(): ...@@ -279,7 +290,7 @@ def test_gemm():
s[w_buf].compute_at(s[y_gem], ko) s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy) s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y].pragma(abo2, env.dma_copy) s[y].pragma(abo2, env.dma_copy)
verify(s) verify(s, name="smt")
test_schedule1() test_schedule1()
test_smt() test_smt()
...@@ -288,7 +299,7 @@ def test_gemm(): ...@@ -288,7 +299,7 @@ def test_gemm():
def test_alu(): def test_alu():
def _run(env, remote): def _run(env, remote):
def check_alu(tvm_op, np_op=None, use_imm=False): def check_alu(tvm_op, np_op=None, use_imm=False, test_name=None):
"""Test ALU""" """Test ALU"""
m = 8 m = 8
n = 8 n = 8
...@@ -371,14 +382,18 @@ def test_alu(): ...@@ -371,14 +382,18 @@ def test_alu():
else: else:
b_nd = tvm.nd.array(b_np, ctx) b_nd = tvm.nd.array(b_np, ctx)
f(a_nd, b_nd, res_nd) f(a_nd, b_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True) if env.TARGET == "tsim":
check_alu(tvm.max, np.maximum, use_imm=True) print("ALU {} imm:{} test took {} clock cycles".format(test_name, use_imm, simulator.tsim_cycles()))
check_alu(tvm.max, np.maximum)
check_alu(lambda x, y: x + y, use_imm=True) check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL")
check_alu(lambda x, y: x + y) check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX")
check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True) check_alu(tvm.max, np.maximum, test_name="MAX")
check_alu(lambda x, y: x + y, use_imm=True, test_name="ADD")
check_alu(lambda x, y: x + y, test_name="ADD")
check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True, test_name="SHR")
vta.testing.run(_run) vta.testing.run(_run)
...@@ -440,8 +455,12 @@ def test_relu(): ...@@ -440,8 +455,12 @@ def test_relu():
simulator.tsim_init("libvta_hw") simulator.tsim_init("libvta_hw")
f(a_nd, res_nd) f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("Relu test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -503,8 +522,12 @@ def test_shift_and_scale(): ...@@ -503,8 +522,12 @@ def test_shift_and_scale():
simulator.tsim_init("libvta_hw") simulator.tsim_init("libvta_hw")
f(a_nd, res_nd) f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("Shift/scale test took {} clock cycles".format(simulator.tsim_cycles()))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -521,17 +544,10 @@ def test_runtime_array(): ...@@ -521,17 +544,10 @@ def test_runtime_array():
if __name__ == "__main__": if __name__ == "__main__":
print("Array test")
test_runtime_array() test_runtime_array()
print("Load/store test")
test_save_load_out() test_save_load_out()
print("Padded load test")
test_padded_load() test_padded_load()
print("GEMM test")
test_gemm() test_gemm()
print("ALU test")
test_alu() test_alu()
print("Relu test")
test_relu() test_relu()
print("Shift and scale")
test_shift_and_scale() test_shift_and_scale()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment