Commit a31dd162 by Luis Vega Committed by Thierry Moreau

[VTA] TSIM improvements and fixes (#3505)

* add tsim init function

* add sim device

* test wait and resume

* launch simulation thread from DPILoader

* add VTASimDPI module to handle all simulation related stuff

* test tsim init

* move exit to simdpi module

* update vta driver

* add chisel DPI module

* get back simshell

* update vta to support dpi sim

* update unittests

* add tsim to integration-conv2d test

* run resnet on tsim

* remove max-cycles

* match tsim counters with sim counters

* use env in simulator to switch between sim and tsim

* update unittest

* rollback conv2d test

* update resnet

* add stats to matrix multiply

* add stats

* print stats after assert

* update other tests

* add stats to gemm

* add return and remove unused libs

* add missing arg

* return lib

* update comments for linter

* add more comments to VTASimDPI module

* remove trailing spaces

* remove trailing spaces
parent 2a7aebe5
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
package test package test
import chisel3._ import chisel3._
import chisel3.experimental.MultiIOModule
import vta.dpi._ import vta.dpi._
import accel._ import accel._
...@@ -28,19 +29,23 @@ import accel._ ...@@ -28,19 +29,23 @@ import accel._
* Instantiate Host and Memory DPI modules. * Instantiate Host and Memory DPI modules.
* *
*/ */
class VTASimShell extends Module { class VTASimShell extends MultiIOModule {
val io = IO(new Bundle { val host = IO(new VTAHostDPIMaster)
val host = new VTAHostDPIMaster val mem = IO(new VTAMemDPIClient)
val mem = new VTAMemDPIClient val sim_clock = IO(Input(Clock()))
}) val sim_wait = IO(Output(Bool()))
val host = Module(new VTAHostDPI) val mod_sim = Module(new VTASimDPI)
val mem = Module(new VTAMemDPI) val mod_host = Module(new VTAHostDPI)
mem.io.dpi <> io.mem val mod_mem = Module(new VTAMemDPI)
mem.io.reset := reset mod_mem.io.clock := clock
mem.io.clock := clock mod_mem.io.reset := reset
io.host <> host.io.dpi mod_mem.io.dpi <> mem
host.io.reset := reset mod_host.io.clock := clock
host.io.clock := clock mod_host.io.reset := reset
host <> mod_host.io.dpi
mod_sim.io.clock := sim_clock
mod_sim.io.reset := reset
sim_wait := mod_sim.io.dpi_wait
} }
/** Test accelerator. /** Test accelerator.
...@@ -48,12 +53,15 @@ class VTASimShell extends Module { ...@@ -48,12 +53,15 @@ class VTASimShell extends Module {
* Instantiate and connect the simulation-shell and the accelerator. * Instantiate and connect the simulation-shell and the accelerator.
* *
*/ */
class TestAccel extends Module { class TestAccel extends MultiIOModule {
val io = IO(new Bundle {}) val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell) val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel) val vta_accel = Module(new Accel)
vta_accel.io.host <> sim_shell.io.host sim_shell.sim_clock := sim_clock
sim_shell.io.mem <> vta_accel.io.mem sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_accel.io.mem
vta_accel.io.host <> sim_shell.host
} }
/** Generate TestAccel as top module */ /** Generate TestAccel as top module */
......
...@@ -25,7 +25,9 @@ ...@@ -25,7 +25,9 @@
module TestAccel module TestAccel
( (
input clock, input clock,
input reset input reset,
input sim_clock,
output sim_wait
); );
localparam HOST_ADDR_BITS = 8; localparam HOST_ADDR_BITS = 8;
...@@ -53,6 +55,14 @@ module TestAccel ...@@ -53,6 +55,14 @@ module TestAccel
logic [MEM_DATA_BITS-1:0] mem_rd_bits; logic [MEM_DATA_BITS-1:0] mem_rd_bits;
logic mem_rd_ready; logic mem_rd_ready;
VTASimDPI sim
(
.clock (sim_clock),
.reset (reset),
.dpi_wait (sim_wait)
);
VTAHostDPI host VTAHostDPI host
( (
.clock (clock), .clock (clock),
...@@ -114,4 +124,5 @@ module TestAccel ...@@ -114,4 +124,5 @@ module TestAccel
.mem_rd_bits (mem_rd_bits), .mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready) .mem_rd_ready (mem_rd_ready)
); );
endmodule endmodule
...@@ -20,7 +20,22 @@ import ctypes ...@@ -20,7 +20,22 @@ import ctypes
import os.path as osp import os.path as osp
from sys import platform from sys import platform
def driver(hw_backend): def get_ext():
return ".dylib" if platform == "darwin" else ".so"
def load_dll(dll):
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_sw():
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)
def init(hw_backend):
"""Init hardware and software shared library for accelerator """Init hardware and software shared library for accelerator
Parameters Parameters
...@@ -29,23 +44,15 @@ def driver(hw_backend): ...@@ -29,23 +44,15 @@ def driver(hw_backend):
Hardware backend can be verilog or chisel Hardware backend can be verilog or chisel
""" """
_ext = ".dylib" if platform == "darwin" else ".so" cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
_hw_libname = "libhw" + _ext hw_libname = "libhw" + get_ext()
_sw_libname = "libsw" + _ext
_cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
if hw_backend in ("verilog", "chisel"): if hw_backend in ("verilog", "chisel"):
_hw_lib = osp.join(_cur_path, "..", "hardware", hw_backend, "build", _hw_libname) hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
_sw_lib = osp.join(_cur_path, "..", "build", _sw_libname) m = tvm.module.load(hw_lib, "vta-tsim")
load_sw()
def load_dll(dll): f = tvm.get_global_func("tvm.vta.tsim.init")
try: f(m)
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError: def load_module():
return [] load_sw()
return tvm.get_global_func("tvm.vta.driver")
def run(a, b, c):
load_dll(_sw_lib)
f = tvm.get_global_func("tvm.vta.driver")
m = tvm.module.load(_hw_lib, "vta-tsim")
return f(m, a, b, c)
return run
...@@ -35,25 +35,56 @@ uint32_t get_half_addr(void *p, bool upper) { ...@@ -35,25 +35,56 @@ uint32_t get_half_addr(void *p, bool upper) {
using vta::dpi::DPIModuleNode; using vta::dpi::DPIModuleNode;
using tvm::runtime::Module; using tvm::runtime::Module;
class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}
DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}
static DPILoader* Global() {
static DPILoader inst;
return &inst;
}
// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
class Device { class Device {
public: public:
Device(Module module) Device() {
: module_(module) { loader_ = DPILoader::Global();
dpi_ = static_cast<DPIModuleNode*>(
module.operator->());
} }
uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) { uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) {
uint32_t cycles; uint32_t cycles;
this->Init();
this->Launch(c, length, inp, out); this->Launch(c, length, inp, out);
cycles = this->WaitForCompletion(); cycles = this->WaitForCompletion();
dpi_->Finish();
return cycles; return cycles;
} }
private: private:
void Init() {
dpi_ = loader_->Get();
dpi_->SimResume();
}
void Launch(uint32_t c, uint32_t length, void* inp, void* out) { void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles_);
dpi_->WriteReg(0x08, c); dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, length); dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(inp, false)); dpi_->WriteReg(0x10, get_half_addr(inp, false));
...@@ -70,24 +101,33 @@ class Device { ...@@ -70,24 +101,33 @@ class Device {
if (val == 2) break; // finish if (val == 2) break; // finish
} }
val = dpi_->ReadReg(0x04); val = dpi_->ReadReg(0x04);
dpi_->SimWait();
return val; return val;
} }
// wait cycles
uint32_t wait_cycles_{100000000}; uint32_t wait_cycles_{100000000};
DPIModuleNode* dpi_; // DPI loader
Module module_; DPILoader* loader_{nullptr};
// DPI Module
DPIModuleNode* dpi_{nullptr};
}; };
using tvm::runtime::TVMRetValue; using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs; using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.driver") TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module dev_mod = args[0]; DLTensor* A = args[0];
DLTensor* A = args[1]; DLTensor* B = args[1];
DLTensor* B = args[2]; Device dev_;
Device dev_(dev_mod); uint32_t cycles = dev_.Run(static_cast<int>(args[2]), A->shape[0], A->data, B->data);
uint32_t cycles = dev_.Run(static_cast<int>(args[3]), A->shape[0], A->data, B->data);
*rv = static_cast<int>(cycles); *rv = static_cast<int>(cycles);
}); });
......
...@@ -26,12 +26,13 @@ def test_accel(): ...@@ -26,12 +26,13 @@ def test_accel():
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = tsim.driver("chisel") f = tsim.load_module()
cycles = f(a, b, c) cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c) msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg) print("[PASS] " + msg)
if __name__ == "__main__": if __name__ == "__main__":
tsim.init("chisel")
for i in range(10): for i in range(10):
test_accel() test_accel()
...@@ -26,12 +26,13 @@ def test_accel(): ...@@ -26,12 +26,13 @@ def test_accel():
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx) a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx) b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = tsim.driver("verilog") f = tsim.load_module()
cycles = f(a, b, c) cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c) msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg) print("[PASS] " + msg)
if __name__ == "__main__": if __name__ == "__main__":
tsim.init("verilog")
for i in range(10): for i in range(10):
test_accel() test_accel()
...@@ -35,7 +35,6 @@ module VTAHostDPI # ...@@ -35,7 +35,6 @@ module VTAHostDPI #
import "DPI-C" function void VTAHostDPI import "DPI-C" function void VTAHostDPI
( (
output byte unsigned exit,
output byte unsigned req_valid, output byte unsigned req_valid,
output byte unsigned req_opcode, output byte unsigned req_opcode,
output byte unsigned req_addr, output byte unsigned req_addr,
...@@ -50,7 +49,6 @@ module VTAHostDPI # ...@@ -50,7 +49,6 @@ module VTAHostDPI #
typedef logic [31:0] dpi32_t; typedef logic [31:0] dpi32_t;
dpi1_t __reset; dpi1_t __reset;
dpi8_t __exit;
dpi8_t __req_valid; dpi8_t __req_valid;
dpi8_t __req_opcode; dpi8_t __req_opcode;
dpi8_t __req_addr; dpi8_t __req_addr;
...@@ -80,7 +78,6 @@ module VTAHostDPI # ...@@ -80,7 +78,6 @@ module VTAHostDPI #
// evaluate DPI function // evaluate DPI function
always_ff @(posedge clock) begin always_ff @(posedge clock) begin
if (reset | __reset) begin if (reset | __reset) begin
__exit = 0;
__req_valid = 0; __req_valid = 0;
__req_opcode = 0; __req_opcode = 0;
__req_addr = 0; __req_addr = 0;
...@@ -88,7 +85,6 @@ module VTAHostDPI # ...@@ -88,7 +85,6 @@ module VTAHostDPI #
end end
else begin else begin
VTAHostDPI( VTAHostDPI(
__exit,
__req_valid, __req_valid,
__req_opcode, __req_opcode,
__req_addr, __req_addr,
...@@ -99,21 +95,4 @@ module VTAHostDPI # ...@@ -99,21 +95,4 @@ module VTAHostDPI #
end end
end end
logic [63:0] cycles;
always_ff @(posedge clock) begin
if (reset | __reset) begin
cycles <= 'd0;
end
else begin
cycles <= cycles + 1'b1;
end
end
always_ff @(posedge clock) begin
if (__exit == 'd1) begin
$finish;
end
end
endmodule endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTASimDPI
(
input clock,
input reset,
output logic dpi_wait
);
import "DPI-C" function void VTASimDPI
(
output byte unsigned sim_wait,
output byte unsigned sim_exit
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
dpi1_t __reset;
dpi8_t __wait;
dpi8_t __exit;
// reset
always_ff @(posedge clock) begin
__reset <= reset;
end
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__wait = 0;
__exit = 0;
end
else begin
VTASimDPI(
__wait,
__exit);
end
end
logic wait_reg;
always_ff @(posedge clock) begin
if (reset | __reset) begin
wait_reg <= 1'b0;
end else if (__wait == 1) begin
wait_reg <= 1'b1;
end else begin
wait_reg <= 1'b0;
end
end
assign dpi_wait = wait_reg;
always_ff @(posedge clock) begin
if (__exit == 1) begin
$finish;
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.dpi
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.interface.axi._
import vta.shell._
/** Sim DPI module.
*
* Wrapper for Sim Verilog DPI module.
*/
class VTASimDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle {
val clock = Input(Clock())
val reset = Input(Bool())
val dpi_wait = Output(Bool())
})
setResource("/verilog/VTASimDPI.v")
}
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
package vta.shell package vta.shell
import chisel3._ import chisel3._
import chisel3.experimental.MultiIOModule
import vta.util.config._ import vta.util.config._
import vta.interface.axi._ import vta.interface.axi._
import vta.shell._ import vta.shell._
...@@ -61,18 +62,37 @@ class VTAMem(implicit p: Parameters) extends Module { ...@@ -61,18 +62,37 @@ class VTAMem(implicit p: Parameters) extends Module {
mem_axi.io.axi <> io.axi mem_axi.io.axi <> io.axi
} }
/** VTASim.
*
* This module is used to handle hardware simulation thread, such as halting
* or terminating the simulation thread. The sim_wait port is used to halt
* the simulation thread when it is asserted and resume it when it is
* de-asserted.
*/
class VTASim(implicit p: Parameters) extends MultiIOModule {
val sim_wait = IO(Output(Bool()))
val sim = Module(new VTASimDPI)
sim.io.reset := reset
sim.io.clock := clock
sim_wait := sim.io.dpi_wait
}
/** SimShell. /** SimShell.
* *
* The simulation shell instantiate a host and memory simulation modules and it is * The simulation shell instantiate the sim, host and memory DPI modules that
* intended to be connected to the VTAShell. * are connected to the VTAShell. An extra clock, sim_clock, is used to eval
* the VTASim DPI function when the main simulation clock is on halt state.
*/ */
class SimShell(implicit p: Parameters) extends Module { class SimShell(implicit p: Parameters) extends MultiIOModule {
val io = IO(new Bundle { val mem = IO(new AXIClient(p(ShellKey).memParams))
val mem = new AXIClient(p(ShellKey).memParams) val host = IO(new AXILiteMaster(p(ShellKey).hostParams))
val host = new AXILiteMaster(p(ShellKey).hostParams) val sim_clock = IO(Input(Clock()))
}) val sim_wait = IO(Output(Bool()))
val host = Module(new VTAHost) val mod_sim = Module(new VTASim)
val mem = Module(new VTAMem) val mod_host = Module(new VTAHost)
io.mem <> mem.io.axi val mod_mem = Module(new VTAMem)
io.host <> host.io.axi mem <> mod_mem.io.axi
host <> mod_host.io.axi
mod_sim.reset := reset
mod_sim.clock := sim_clock
sim_wait := mod_sim.sim_wait
} }
...@@ -20,14 +20,18 @@ ...@@ -20,14 +20,18 @@
package vta.test package vta.test
import chisel3._ import chisel3._
import chisel3.experimental.MultiIOModule
import vta.util.config._ import vta.util.config._
import vta.shell._ import vta.shell._
/** Test. This generates a testbench file for simulation */ /** Test. This generates a testbench file for simulation */
class Test(implicit p: Parameters) extends Module { class Test(implicit p: Parameters) extends MultiIOModule {
val io = IO(new Bundle {}) val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new SimShell) val sim_shell = Module(new SimShell)
val vta_shell = Module(new VTAShell) val vta_shell = Module(new VTAShell)
vta_shell.io.host <> sim_shell.io.host sim_shell.sim_clock := sim_clock
sim_shell.io.mem <> vta_shell.io.mem sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_shell.io.mem
vta_shell.io.host <> sim_shell.host
} }
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
* under the License. * under the License.
*/ */
#include <chrono>
#include <thread>
#include <vta/dpi/tsim.h> #include <vta/dpi/tsim.h>
#if VM_TRACE #if VM_TRACE
...@@ -29,11 +31,17 @@ ...@@ -29,11 +31,17 @@
#endif #endif
static VTAContextHandle _ctx = nullptr; static VTAContextHandle _ctx = nullptr;
static VTAMemDPIFunc _mem_dpi = nullptr; static VTASimDPIFunc _sim_dpi = nullptr;
static VTAHostDPIFunc _host_dpi = nullptr; static VTAHostDPIFunc _host_dpi = nullptr;
static VTAMemDPIFunc _mem_dpi = nullptr;
void VTASimDPI(dpi8_t* wait,
dpi8_t* exit) {
assert(_sim_dpi != nullptr);
(*_sim_dpi)(_ctx, wait, exit);
}
void VTAHostDPI(dpi8_t* exit, void VTAHostDPI(dpi8_t* req_valid,
dpi8_t* req_valid,
dpi8_t* req_opcode, dpi8_t* req_opcode,
dpi8_t* req_addr, dpi8_t* req_addr,
dpi32_t* req_value, dpi32_t* req_value,
...@@ -41,7 +49,7 @@ void VTAHostDPI(dpi8_t* exit, ...@@ -41,7 +49,7 @@ void VTAHostDPI(dpi8_t* exit,
dpi8_t resp_valid, dpi8_t resp_valid,
dpi32_t resp_value) { dpi32_t resp_value) {
assert(_host_dpi != nullptr); assert(_host_dpi != nullptr);
(*_host_dpi)(_ctx, exit, req_valid, req_opcode, (*_host_dpi)(_ctx, req_valid, req_opcode,
req_addr, req_value, req_deq, req_addr, req_value, req_deq,
resp_valid, resp_value); resp_valid, resp_value);
} }
...@@ -63,9 +71,11 @@ void VTAMemDPI(dpi8_t req_valid, ...@@ -63,9 +71,11 @@ void VTAMemDPI(dpi8_t req_valid,
} }
void VTADPIInit(VTAContextHandle handle, void VTADPIInit(VTAContextHandle handle,
VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi, VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi) { VTAMemDPIFunc mem_dpi) {
_ctx = handle; _ctx = handle;
_sim_dpi = sim_dpi;
_host_dpi = host_dpi; _host_dpi = host_dpi;
_mem_dpi = mem_dpi; _mem_dpi = mem_dpi;
} }
...@@ -77,7 +87,7 @@ void vl_finish(const char* filename, int linenum, const char* hier) { ...@@ -77,7 +87,7 @@ void vl_finish(const char* filename, int linenum, const char* hier) {
Verilated::gotFinish(true); Verilated::gotFinish(true);
} }
int VTADPISim(uint64_t max_cycles) { int VTADPISim() {
uint64_t trace_count = 0; uint64_t trace_count = 0;
Verilated::flushCall(); Verilated::flushCall();
Verilated::gotFinish(false); Verilated::gotFinish(false);
...@@ -115,13 +125,15 @@ int VTADPISim(uint64_t max_cycles) { ...@@ -115,13 +125,15 @@ int VTADPISim(uint64_t max_cycles) {
top->reset = 0; top->reset = 0;
// start simulation // start simulation
while (!Verilated::gotFinish() && trace_count < max_cycles) { while (!Verilated::gotFinish()) {
top->sim_clock = 0;
top->clock = 0; top->clock = 0;
top->eval(); top->eval();
#if VM_TRACE #if VM_TRACE
if (trace_count >= start) if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2)); tfp->dump(static_cast<vluint64_t>(trace_count * 2));
#endif #endif
top->sim_clock = 1;
top->clock = 1; top->clock = 1;
top->eval(); top->eval();
#if VM_TRACE #if VM_TRACE
...@@ -129,6 +141,14 @@ int VTADPISim(uint64_t max_cycles) { ...@@ -129,6 +141,14 @@ int VTADPISim(uint64_t max_cycles) {
tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1)); tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
#endif #endif
trace_count++; trace_count++;
while (top->sim_wait) {
top->clock = 0;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
top->sim_clock = 0;
top->eval();
top->sim_clock = 1;
top->eval();
}
} }
#if VM_TRACE #if VM_TRACE
......
...@@ -34,11 +34,17 @@ namespace dpi { ...@@ -34,11 +34,17 @@ namespace dpi {
*/ */
class DPIModuleNode : public tvm::runtime::ModuleNode { class DPIModuleNode : public tvm::runtime::ModuleNode {
public: public:
/*! /*! \brief Launch hardware simulation */
* \brief Launch hardware simulation until accelerator finishes or reach max_cycles virtual void SimLaunch() = 0;
* \param max_cycles The maximum of cycles to wait
*/ /*! \brief Halt hardware simulation */
virtual void Launch(uint64_t max_cycles) = 0; virtual void SimWait() = 0;
/*! \brief Resume hardware simulation */
virtual void SimResume() = 0;
/*! \brief Finish hardware simulation */
virtual void SimFinish() = 0;
/*! /*!
* \brief Write an accelerator register * \brief Write an accelerator register
...@@ -53,13 +59,9 @@ class DPIModuleNode : public tvm::runtime::ModuleNode { ...@@ -53,13 +59,9 @@ class DPIModuleNode : public tvm::runtime::ModuleNode {
*/ */
virtual uint32_t ReadReg(int addr) = 0; virtual uint32_t ReadReg(int addr) = 0;
/*! \brief Finish hardware simulation */
virtual void Finish() = 0;
static tvm::runtime::Module Load(std::string dll_name); static tvm::runtime::Module Load(std::string dll_name);
}; };
} // namespace dpi } // namespace dpi
} // namespace vta } // namespace vta
#endif // VTA_DPI_MODULE_H_ #endif // VTA_DPI_MODULE_H_
...@@ -36,9 +36,13 @@ typedef unsigned long long dpi64_t; // NOLINT(*) ...@@ -36,9 +36,13 @@ typedef unsigned long long dpi64_t; // NOLINT(*)
/*! \brief the context handle */ /*! \brief the context handle */
typedef void* VTAContextHandle; typedef void* VTAContextHandle;
typedef void (*VTASimDPIFunc)(
VTAContextHandle self,
dpi8_t* wait,
dpi8_t* exit);
/*! /*!
* \brief Host DPI callback function that is invoked in VTAHostDPI.v every clock cycle * \brief Host DPI callback function that is invoked in VTAHostDPI.v every clock cycle
* \param exit Host kill simulation
* \param req_valid Host has a valid request for read or write a register in Accel * \param req_valid Host has a valid request for read or write a register in Accel
* \param req_opcode Host request type, opcode=0 for read and opcode=1 for write * \param req_opcode Host request type, opcode=0 for read and opcode=1 for write
* \param req_addr Host request register address * \param req_addr Host request register address
...@@ -50,7 +54,6 @@ typedef void* VTAContextHandle; ...@@ -50,7 +54,6 @@ typedef void* VTAContextHandle;
*/ */
typedef void (*VTAHostDPIFunc)( typedef void (*VTAHostDPIFunc)(
VTAContextHandle self, VTAContextHandle self,
dpi8_t* exit,
dpi8_t* req_valid, dpi8_t* req_valid,
dpi8_t* req_opcode, dpi8_t* req_opcode,
dpi8_t* req_addr, dpi8_t* req_addr,
...@@ -84,28 +87,28 @@ typedef void (*VTAMemDPIFunc)( ...@@ -84,28 +87,28 @@ typedef void (*VTAMemDPIFunc)(
/*! \brief The type of VTADPIInit function pointer */ /*! \brief The type of VTADPIInit function pointer */
typedef void (*VTADPIInitFunc)(VTAContextHandle handle, typedef void (*VTADPIInitFunc)(VTAContextHandle handle,
VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi, VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi); VTAMemDPIFunc mem_dpi);
/*! \brief The type of VTADPISim function pointer */ /*! \brief The type of VTADPISim function pointer */
typedef int (*VTADPISimFunc)(uint64_t max_cycles); typedef int (*VTADPISimFunc)();
/*! /*!
* \brief Set Host and Memory DPI functions * \brief Set Host and Memory DPI functions
* \param handle DPI Context handle * \param handle DPI Context handle
* \param sim_dpi Sim DPI function
* \param host_dpi Host DPI function * \param host_dpi Host DPI function
* \param mem_dpi Memory DPI function * \param mem_dpi Memory DPI function
*/ */
TVM_DLL void VTADPIInit(VTAContextHandle handle, TVM_DLL void VTADPIInit(VTAContextHandle handle,
VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi, VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi); VTAMemDPIFunc mem_dpi);
/*! /*! \brief VTA hardware simulation thread */
* \brief Instantiate VTA design and generate clock/reset TVM_DLL int VTADPISim();
* \param max_cycles The maximum number of simulation cycles
*/
TVM_DLL int VTADPISim(uint64_t max_cycles);
#ifdef __cplusplus #ifdef __cplusplus
} }
......
...@@ -42,7 +42,7 @@ def server_start(): ...@@ -42,7 +42,7 @@ def server_start():
curr_path = os.path.dirname( curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__))) os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../../")) proj_root = os.path.abspath(os.path.join(curr_path, "../../../../"))
dll_path = find_libvta()[0] dll_path = find_libvta("libvta")[0]
cfg_path = os.path.abspath(os.path.join(proj_root, "build/vta_config.json")) cfg_path = os.path.abspath(os.path.join(proj_root, "build/vta_config.json"))
runtime_dll = [] runtime_dll = []
_load_module = tvm.get_global_func("tvm.rpc.server.load_module") _load_module = tvm.get_global_func("tvm.rpc.server.load_module")
......
...@@ -19,21 +19,48 @@ from __future__ import absolute_import ...@@ -19,21 +19,48 @@ from __future__ import absolute_import
import sys import sys
import os import os
def _get_lib_name(): def _get_lib_name(lib_name):
"""Get lib name with extension
Returns
-------
lib_name_ext : str
Name of VTA shared library with extension
Parameters
------------
lib_name : str
Name of VTA shared library
"""
if sys.platform.startswith('win32'): if sys.platform.startswith('win32'):
return "vta.dll" return lib_name + ".dll"
if sys.platform.startswith('darwin'): if sys.platform.startswith('darwin'):
return "libvta.dylib" return lib_name + ".dylib"
return "libvta.so" return lib_name + ".so"
def find_libvta(lib_vta, optional=False):
"""Find VTA library
Returns
-------
lib_found : str
Library path
Parameters
------------
lib_vta : str
Name of VTA shared library
def find_libvta(optional=False): optional : bool
"""Find VTA library""" Enable error check
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
lib_search = [curr_path] lib_search = [curr_path]
lib_search += [os.path.join(curr_path, "..", "..", "build",)]
lib_search += [os.path.join(curr_path, "..", "..", "..", "build",)] lib_search += [os.path.join(curr_path, "..", "..", "..", "build",)]
lib_search += [os.path.join(curr_path, "..", "..", "..", "build", "Release")] lib_search += [os.path.join(curr_path, "..", "..", "..", "build", "Release")]
lib_name = _get_lib_name() lib_name = _get_lib_name(lib_vta)
lib_path = [os.path.join(x, lib_name) for x in lib_search] lib_path = [os.path.join(x, lib_name) for x in lib_search]
lib_found = [x for x in lib_path if os.path.exists(x)] lib_found = [x for x in lib_path if os.path.exists(x)]
if not lib_found and not optional: if not lib_found and not optional:
......
...@@ -17,22 +17,34 @@ ...@@ -17,22 +17,34 @@
"""Utilities to start simulator.""" """Utilities to start simulator."""
import ctypes import ctypes
import json import json
import sys
import os
import tvm import tvm
from ..environment import get_env
from ..libinfo import find_libvta from ..libinfo import find_libvta
def _load_lib():
"""Load local library, assuming they are simulator.""" def _load_sw():
lib_path = find_libvta(optional=True) """Load software library, assuming they are simulator."""
if not lib_path: lib_sw = find_libvta("libvta", optional=True)
if not lib_sw:
return [] return []
try: try:
return [ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)] return [ctypes.CDLL(lib_sw[0], ctypes.RTLD_GLOBAL)]
except OSError: except OSError:
return [] return []
def _load_all():
"""Load hardware library for tsim."""
lib = _load_sw()
env = get_env()
if env.TARGET == "tsim":
lib = find_libvta("libvta_hw", optional=True)
f = tvm.get_global_func("vta.tsim.init")
m = tvm.module.load(lib[0], "vta-tsim")
f(m)
return lib
def enabled(): def enabled():
"""Check if simulator is enabled.""" """Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True) f = tvm.get_global_func("vta.simulator.profiler_clear", True)
...@@ -40,49 +52,31 @@ def enabled(): ...@@ -40,49 +52,31 @@ def enabled():
def clear_stats(): def clear_stats():
"""Clear profiler statistics""" """Clear profiler statistics."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True) env = get_env()
if env.TARGET == "sim":
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
else:
f = tvm.get_global_func("vta.tsim.profiler_clear", True)
if f: if f:
f() f()
def stats(): def stats():
"""Clear profiler statistics """Get profiler statistics
Returns Returns
------- -------
stats : dict stats : dict
Current profiler statistics Current profiler statistics
""" """
x = tvm.get_global_func("vta.simulator.profiler_status")() env = get_env()
if env.TARGET == "sim":
x = tvm.get_global_func("vta.simulator.profiler_status")()
else:
x = tvm.get_global_func("vta.tsim.profiler_status")()
return json.loads(x) return json.loads(x)
def tsim_init(hw_lib):
"""Init hardware shared library for TSIM
Parameters
------------
hw_lib : str
Name of hardware shared library
"""
cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
vta_build_path = os.path.join(cur_path, "..", "..", "..", "build")
if not hw_lib.endswith(("dylib", "so")):
hw_lib += ".dylib" if sys.platform == "darwin" else ".so"
lib = os.path.join(vta_build_path, hw_lib)
f = tvm.get_global_func("tvm.vta.tsim.init")
m = tvm.module.load(lib, "vta-tsim")
f(m)
def tsim_cycles():
"""Get tsim clock cycles
Returns
-------
stats : int
tsim clock cycles
"""
return tvm.get_global_func("tvm.vta.tsim.cycles")()
# debug flag to skip execution. # debug flag to skip execution.
DEBUG_SKIP_EXEC = 1 DEBUG_SKIP_EXEC = 1
...@@ -97,4 +91,4 @@ def debug_mode(flag): ...@@ -97,4 +91,4 @@ def debug_mode(flag):
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag) tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)
LIBS = _load_lib() LIBS = _load_all()
...@@ -22,6 +22,7 @@ from tvm import rpc, autotvm ...@@ -22,6 +22,7 @@ from tvm import rpc, autotvm
from ..environment import get_env from ..environment import get_env
from . import simulator from . import simulator
def run(run_func): def run(run_func):
"""Run test function on all available env. """Run test function on all available env.
......
...@@ -86,17 +86,28 @@ class ThreadSafeQueue { ...@@ -86,17 +86,28 @@ class ThreadSafeQueue {
std::condition_variable cond_; std::condition_variable cond_;
}; };
class SimDevice {
public:
void Wait();
void Resume();
void Exit();
bool GetWaitStatus();
bool GetExitStatus();
private:
bool wait_{false};
bool exit_{false};
mutable std::mutex mutex_;
};
class HostDevice { class HostDevice {
public: public:
void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value); void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value);
bool TryPopRequest(HostRequest* r, bool pop); bool TryPopRequest(HostRequest* r, bool pop);
void PushResponse(uint32_t value); void PushResponse(uint32_t value);
void WaitPopResponse(HostResponse* r); void WaitPopResponse(HostResponse* r);
void Exit();
uint8_t GetExitStatus();
private: private:
uint8_t exit_{0};
mutable std::mutex mutex_; mutable std::mutex mutex_;
ThreadSafeQueue<HostRequest> req_; ThreadSafeQueue<HostRequest> req_;
ThreadSafeQueue<HostResponse> resp_; ThreadSafeQueue<HostResponse> resp_;
...@@ -116,6 +127,31 @@ class MemDevice { ...@@ -116,6 +127,31 @@ class MemDevice {
std::mutex mutex_; std::mutex mutex_;
}; };
void SimDevice::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
wait_ = true;
}
void SimDevice::Resume() {
std::unique_lock<std::mutex> lock(mutex_);
wait_ = false;
}
void SimDevice::Exit() {
std::unique_lock<std::mutex> lock(mutex_);
exit_ = true;
}
bool SimDevice::GetWaitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return wait_;
}
bool SimDevice::GetExitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return exit_;
}
void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) { void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) {
HostRequest r; HostRequest r;
r.opcode = opcode; r.opcode = opcode;
...@@ -141,16 +177,6 @@ void HostDevice::WaitPopResponse(HostResponse* r) { ...@@ -141,16 +177,6 @@ void HostDevice::WaitPopResponse(HostResponse* r) {
resp_.WaitPop(r); resp_.WaitPop(r);
} }
void HostDevice::Exit() {
std::unique_lock<std::mutex> lock(mutex_);
exit_ = 1;
}
uint8_t HostDevice::GetExitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return exit_;
}
void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) { void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (opcode == 1) { if (opcode == 1) {
...@@ -212,16 +238,29 @@ class DPIModule final : public DPIModuleNode { ...@@ -212,16 +238,29 @@ class DPIModule final : public DPIModuleNode {
VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>( VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>(
GetSymbol("VTADPIInit")); GetSymbol("VTADPIInit"));
CHECK(finit != nullptr); CHECK(finit != nullptr);
finit(this, VTAHostDPI, VTAMemDPI); finit(this, VTASimDPI, VTAHostDPI, VTAMemDPI);
fvsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim")); ftsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
CHECK(fvsim_ != nullptr); CHECK(ftsim_ != nullptr);
} }
void Launch(uint64_t max_cycles) { void SimLaunch() {
auto frun = [this, max_cycles]() { auto frun = [this]() {
(*fvsim_)(max_cycles); (*ftsim_)();
}; };
vsim_thread_ = std::thread(frun); tsim_thread_ = std::thread(frun);
}
void SimWait() {
sim_device_.Wait();
}
void SimResume() {
sim_device_.Resume();
}
void SimFinish() {
sim_device_.Exit();
tsim_thread_.join();
} }
void WriteReg(int addr, uint32_t value) { void WriteReg(int addr, uint32_t value) {
...@@ -238,19 +277,20 @@ class DPIModule final : public DPIModuleNode { ...@@ -238,19 +277,20 @@ class DPIModule final : public DPIModuleNode {
return value; return value;
} }
void Finish() {
host_device_.Exit();
vsim_thread_.join();
}
protected: protected:
VTADPISimFunc fvsim_; VTADPISimFunc ftsim_;
SimDevice sim_device_;
HostDevice host_device_; HostDevice host_device_;
MemDevice mem_device_; MemDevice mem_device_;
std::thread vsim_thread_; std::thread tsim_thread_;
void HostDPI(dpi8_t* exit, void SimDPI(dpi8_t* wait,
dpi8_t* req_valid, dpi8_t* exit) {
*wait = sim_device_.GetWaitStatus();
*exit = sim_device_.GetExitStatus();
}
void HostDPI(dpi8_t* req_valid,
dpi8_t* req_opcode, dpi8_t* req_opcode,
dpi8_t* req_addr, dpi8_t* req_addr,
dpi32_t* req_value, dpi32_t* req_value,
...@@ -258,7 +298,6 @@ class DPIModule final : public DPIModuleNode { ...@@ -258,7 +298,6 @@ class DPIModule final : public DPIModuleNode {
dpi8_t resp_valid, dpi8_t resp_valid,
dpi32_t resp_value) { dpi32_t resp_value) {
HostRequest* r = new HostRequest; HostRequest* r = new HostRequest;
*exit = host_device_.GetExitStatus();
*req_valid = host_device_.TryPopRequest(r, req_deq); *req_valid = host_device_.TryPopRequest(r, req_deq);
*req_opcode = r->opcode; *req_opcode = r->opcode;
*req_addr = r->addr; *req_addr = r->addr;
...@@ -290,9 +329,16 @@ class DPIModule final : public DPIModuleNode { ...@@ -290,9 +329,16 @@ class DPIModule final : public DPIModuleNode {
} }
} }
static void VTASimDPI(
VTAContextHandle self,
dpi8_t* wait,
dpi8_t* exit) {
static_cast<DPIModule*>(self)->SimDPI(
wait, exit);
}
static void VTAHostDPI( static void VTAHostDPI(
VTAContextHandle self, VTAContextHandle self,
dpi8_t* exit,
dpi8_t* req_valid, dpi8_t* req_valid,
dpi8_t* req_opcode, dpi8_t* req_opcode,
dpi8_t* req_addr, dpi8_t* req_addr,
...@@ -301,7 +347,7 @@ class DPIModule final : public DPIModuleNode { ...@@ -301,7 +347,7 @@ class DPIModule final : public DPIModuleNode {
dpi8_t resp_valid, dpi8_t resp_valid,
dpi32_t resp_value) { dpi32_t resp_value) {
static_cast<DPIModule*>(self)->HostDPI( static_cast<DPIModule*>(self)->HostDPI(
exit, req_valid, req_opcode, req_addr, req_valid, req_opcode, req_addr,
req_value, req_deq, resp_valid, resp_value); req_value, req_deq, resp_valid, resp_value);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -17,32 +17,78 @@ ...@@ -17,32 +17,78 @@
* under the License. * under the License.
*/ */
#include <vta/driver.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <vta/driver.h>
#include <vta/dpi/module.h> #include <vta/dpi/module.h>
namespace vta { namespace vta {
namespace tsim { namespace tsim {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module; using tvm::runtime::Module;
using vta::dpi::DPIModuleNode;
class Profiler { class Profiler {
public: public:
/*! \brief cycle counter */ Profiler() {
uint64_t cycle_count{0}; counters_ = new int[num_counters_];
this->ClearAll();
}
~Profiler() {
delete [] counters_;
}
/*! \brief update one event counter */
void Update(uint32_t idx, uint32_t value) {
counters_[idx] += value;
}
/*! \brief clear one event counter*/
void Clear(uint32_t idx) {
counters_[idx] = 0;
}
/*! \brief clear all event counters */
void ClearAll() {
for (uint32_t i = 0; i < num_counters_; i++) {
counters_[i] = 0;
}
}
/*! \brief return counters as json */
std::string AsJSON() {
std::ostringstream os;
os << "{\n"
<< " \"cycle_count\":" << counters_[0] << "\n"
<<"}\n";
return os.str();
}
static Profiler* Global() { static Profiler* Global() {
static Profiler inst; static Profiler inst;
return &inst; return &inst;
} }
private:
/*! \brief total number of event counters */
uint32_t num_counters_{1};
/*! \brief event counters */
int* counters_{nullptr};
}; };
class DPILoader { class DPILoader {
public: public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) { void Init(Module module) {
mod_ = module; mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
} }
DPIModuleNode* Get() { DPIModuleNode* Get() {
...@@ -54,13 +100,16 @@ class DPILoader { ...@@ -54,13 +100,16 @@ class DPILoader {
return &inst; return &inst;
} }
// TVM module
Module mod_; Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
}; };
class Device { class Device {
public: public:
Device() { Device() {
dpi_ = DPILoader::Global(); loader_ = DPILoader::Global();
prof_ = Profiler::Global(); prof_ = Profiler::Global();
} }
...@@ -82,13 +131,13 @@ class Device { ...@@ -82,13 +131,13 @@ class Device {
insn_count, insn_count,
wait_cycles); wait_cycles);
this->WaitForCompletion(wait_cycles); this->WaitForCompletion(wait_cycles);
dev_->Finish();
return 0; return 0;
} }
private: private:
void Init() { void Init() {
dev_ = dpi_->Get(); dpi_ = loader_->Get();
dpi_->SimResume();
} }
void Launch(vta_phy_addr_t insn_phy_addr, void Launch(vta_phy_addr_t insn_phy_addr,
...@@ -99,57 +148,60 @@ class Device { ...@@ -99,57 +148,60 @@ class Device {
vta_phy_addr_t out_phy_addr, vta_phy_addr_t out_phy_addr,
uint32_t insn_count, uint32_t insn_count,
uint32_t wait_cycles) { uint32_t wait_cycles) {
// launch simulation thread dpi_->WriteReg(0x04, 0);
dev_->Launch(wait_cycles); dpi_->WriteReg(0x08, insn_count);
// set counter to zero dpi_->WriteReg(0x0c, insn_phy_addr);
dev_->WriteReg(0x04, 0); dpi_->WriteReg(0x10, insn_phy_addr >> 32);
dev_->WriteReg(0x08, insn_count); dpi_->WriteReg(0x14, 0);
dev_->WriteReg(0x0c, insn_phy_addr); dpi_->WriteReg(0x18, uop_phy_addr >> 32);
dev_->WriteReg(0x10, insn_phy_addr >> 32); dpi_->WriteReg(0x1c, 0);
dev_->WriteReg(0x14, 0); dpi_->WriteReg(0x20, inp_phy_addr >> 32);
dev_->WriteReg(0x18, uop_phy_addr >> 32); dpi_->WriteReg(0x24, 0);
dev_->WriteReg(0x1c, 0); dpi_->WriteReg(0x28, wgt_phy_addr >> 32);
dev_->WriteReg(0x20, inp_phy_addr >> 32); dpi_->WriteReg(0x2c, 0);
dev_->WriteReg(0x24, 0); dpi_->WriteReg(0x30, acc_phy_addr >> 32);
dev_->WriteReg(0x28, wgt_phy_addr >> 32); dpi_->WriteReg(0x34, 0);
dev_->WriteReg(0x2c, 0); dpi_->WriteReg(0x38, out_phy_addr >> 32);
dev_->WriteReg(0x30, acc_phy_addr >> 32);
dev_->WriteReg(0x34, 0);
dev_->WriteReg(0x38, out_phy_addr >> 32);
// start // start
dev_->WriteReg(0x00, 0x1); dpi_->WriteReg(0x00, 0x1);
} }
void WaitForCompletion(uint32_t wait_cycles) { void WaitForCompletion(uint32_t wait_cycles) {
uint32_t i, val; uint32_t i, val;
for (i = 0; i < wait_cycles; i++) { for (i = 0; i < wait_cycles; i++) {
val = dev_->ReadReg(0x00); val = dpi_->ReadReg(0x00);
val &= 0x2; val &= 0x2;
if (val == 0x2) break; // finish if (val == 0x2) break; // finish
} }
prof_->cycle_count = dev_->ReadReg(0x04); prof_->Update(0, dpi_->ReadReg(0x04));
dpi_->SimWait();
} }
// Profiler // Profiler
Profiler* prof_; Profiler* prof_;
// DPI loader // DPI loader
DPILoader* dpi_; DPILoader* loader_;
// DPI Module // DPI Module
DPIModuleNode* dev_; DPIModuleNode* dpi_;
}; };
using tvm::runtime::TVMRetValue; using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs; using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init") TVM_REGISTER_GLOBAL("vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0]; Module m = args[0];
DPILoader::Global()->Init(m); DPILoader::Global()->Init(m);
}); });
TVM_REGISTER_GLOBAL("tvm.vta.tsim.cycles") TVM_REGISTER_GLOBAL("vta.tsim.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::Global()->ClearAll();
});
TVM_REGISTER_GLOBAL("vta.tsim.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int>(Profiler::Global()->cycle_count); *rv = Profiler::Global()->AsJSON();
}); });
} // namespace tsim } // namespace tsim
......
...@@ -18,6 +18,7 @@ import tvm ...@@ -18,6 +18,7 @@ import tvm
import numpy as np import numpy as np
from tvm.contrib import util from tvm.contrib import util
import vta.testing import vta.testing
from vta.testing import simulator
def test_gemm(): def test_gemm():
...@@ -104,7 +105,14 @@ def test_gemm(): ...@@ -104,7 +105,14 @@ def test_gemm():
res_ref = np.right_shift(res_ref, 8) res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype) res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20) time_f = f.time_evaluator("gemm", ctx, number=20)
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
cost = time_f(data_arr, weight_arr, res_arr) cost = time_f(data_arr, weight_arr, res_arr)
if env.TARGET in ["sim", "tsim"]:
stats = simulator.stats()
print("Execution statistics:")
for k, v in stats.items():
print("\t{:<16}: {:>16}".format(k, v))
res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH, res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel // env.BLOCK_OUT, channel // env.BLOCK_OUT,
env.BATCH, env.BATCH,
......
...@@ -173,14 +173,20 @@ def run_conv2d(env, remote, wl, target, ...@@ -173,14 +173,20 @@ def run_conv2d(env, remote, wl, target,
# In vta sim mode, collect simulator runtime statistics # In vta sim mode, collect simulator runtime statistics
stats = {} stats = {}
cost = None cost = None
if env.TARGET == "sim": if env.TARGET in ["sim", "tsim"]:
# Check if we're in local RPC mode (allows us to rebuild the # Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs) # runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc: if local_rpc:
remote.get_function("vta.simulator.profiler_clear")() if env.TARGET == "sim":
remote.get_function("vta.simulator.profiler_clear")()
else:
remote.get_function("vta.tsim.profiler_clear")()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) if env.TARGET == "sim":
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
else: else:
simulator.clear_stats() simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
...@@ -215,7 +221,7 @@ def test_conv2d(device="vta"): ...@@ -215,7 +221,7 @@ def test_conv2d(device="vta"):
def _run(env, remote): def _run(env, remote):
if device == "vta": if device == "vta":
target = env.target target = env.target
if env.TARGET != "sim": if env.TARGET not in ["sim", "tsim"]:
assert tvm.module.enabled("rpc") assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None) program_fpga(remote, bitstream=None)
reconfig_runtime(remote) reconfig_runtime(remote)
......
...@@ -131,14 +131,20 @@ def run_gemm(env, remote, target, ...@@ -131,14 +131,20 @@ def run_gemm(env, remote, target,
# In vta sim mode, collect simulator runtime statistics # In vta sim mode, collect simulator runtime statistics
stats = {} stats = {}
cost = None cost = None
if env.TARGET == "sim": if env.TARGET in ["sim", "tsim"]:
# Check if we're in local RPC mode (allows us to rebuild the # Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs) # runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc: if local_rpc:
remote.get_function("vta.simulator.profiler_clear")() if env.TARGET == "sim":
remote.get_function("vta.simulator.profiler_clear")()
else:
remote.get_function("vta.tsim.profiler_clear")()
cost = time_f(data_arr, kernel_arr, res_arr) cost = time_f(data_arr, kernel_arr, res_arr)
stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) if env.TARGET == "sim":
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
else: else:
simulator.clear_stats() simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, res_arr) cost = time_f(data_arr, kernel_arr, res_arr)
...@@ -171,7 +177,7 @@ def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128): ...@@ -171,7 +177,7 @@ def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128):
def _run(env, remote): def _run(env, remote):
if device == "vta": if device == "vta":
target = env.target target = env.target
if env.TARGET != "sim": if env.TARGET not in ["sim", "tsim"]:
assert tvm.module.enabled("rpc") assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None) program_fpga(remote, bitstream=None)
reconfig_runtime(remote) reconfig_runtime(remote)
......
...@@ -69,15 +69,18 @@ def test_save_load_out(): ...@@ -69,15 +69,18 @@ def test_save_load_out():
x_nd = tvm.nd.array(x_np, ctx) x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
simulator.tsim_init("libvta_hw") simulator.clear_stats()
f(x_nd, y_nd) f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
print("Load/store test took {} clock cycles".format(simulator.tsim_cycles())) sim_stats = simulator.stats()
print("Save load execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -135,15 +138,18 @@ def test_padded_load(): ...@@ -135,15 +138,18 @@ def test_padded_load():
x_nd = tvm.nd.array(x_np, ctx) x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype) y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
simulator.tsim_init("libvta_hw") simulator.clear_stats()
f(x_nd, y_nd) f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
print("Padded load test took {} clock cycles".format(simulator.tsim_cycles())) sim_stats = simulator.stats()
print("Padded load execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -213,20 +219,18 @@ def test_gemm(): ...@@ -213,20 +219,18 @@ def test_gemm():
y_np = np.right_shift(y_np, 8) y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype) y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
simulator.tsim_init("libvta_hw")
if env.TARGET == "sim":
simulator.clear_stats() simulator.clear_stats()
f(x_nd, w_nd, y_nd)
print(simulator.stats()) f(x_nd, w_nd, y_nd)
else:
f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy()) np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
print("GEMM schedule:{} test took {} clock cycles".format(name, simulator.tsim_cycles())) sim_stats = simulator.stats()
print("GEMM schedule:{} execution statistics:".format(name))
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
def test_schedule1(): def test_schedule1():
# default schedule with no smt # default schedule with no smt
...@@ -374,8 +378,8 @@ def test_alu(): ...@@ -374,8 +378,8 @@ def test_alu():
res_nd = tvm.nd.array( res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
simulator.tsim_init("libvta_hw") simulator.clear_stats()
if use_imm: if use_imm:
f(a_nd, res_nd) f(a_nd, res_nd)
...@@ -385,8 +389,11 @@ def test_alu(): ...@@ -385,8 +389,11 @@ def test_alu():
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
print("ALU {} imm:{} test took {} clock cycles".format(test_name, use_imm, simulator.tsim_cycles())) sim_stats = simulator.stats()
print("ALU {} execution statistics:".format(test_name))
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL") check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL")
check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX") check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX")
...@@ -451,15 +458,18 @@ def test_relu(): ...@@ -451,15 +458,18 @@ def test_relu():
res_nd = tvm.nd.array( res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
simulator.tsim_init("libvta_hw") simulator.clear_stats()
f(a_nd, res_nd) f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
print("Relu test took {} clock cycles".format(simulator.tsim_cycles())) sim_stats = simulator.stats()
print("Relu execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run) vta.testing.run(_run)
...@@ -518,15 +528,18 @@ def test_shift_and_scale(): ...@@ -518,15 +528,18 @@ def test_shift_and_scale():
res_nd = tvm.nd.array( res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx) np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
simulator.tsim_init("libvta_hw") simulator.clear_stats()
f(a_nd, res_nd) f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy()) np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim": if env.TARGET in ["sim", "tsim"]:
print("Shift/scale test took {} clock cycles".format(simulator.tsim_cycles())) sim_stats = simulator.stats()
print("Shift and scale execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run) vta.testing.run(_run)
......
...@@ -89,7 +89,7 @@ stop_pack="nn.global_avg_pool2d" ...@@ -89,7 +89,7 @@ stop_pack="nn.global_avg_pool2d"
# When target is 'pynq', reconfigure FPGA and runtime. # When target is 'pynq', reconfigure FPGA and runtime.
# Otherwise, if target is 'sim', execute locally. # Otherwise, if target is 'sim', execute locally.
if env.TARGET != "sim": if env.TARGET not in ["sim", "tsim"]:
# Get remote from tracker node if environment variable is set. # Get remote from tracker node if environment variable is set.
# To set up the tracker, you'll need to follow the "Auto-tuning # To set up the tracker, you'll need to follow the "Auto-tuning
...@@ -235,7 +235,7 @@ num = 4 # number of times we run module for a single measurement ...@@ -235,7 +235,7 @@ num = 4 # number of times we run module for a single measurement
rep = 3 # number of measurements (we derive std dev from this) rep = 3 # number of measurements (we derive std dev from this)
timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep) timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
if env.TARGET == "sim": if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats() simulator.clear_stats()
timer() timer()
sim_stats = simulator.stats() sim_stats = simulator.stats()
......
...@@ -66,7 +66,7 @@ if env.TARGET == "pynq": ...@@ -66,7 +66,7 @@ if env.TARGET == "pynq":
vta.program_fpga(remote, bitstream=None) vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally. # In simulation mode, host the RPC server locally.
elif env.TARGET == "sim": elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession() remote = rpc.LocalSession()
###################################################################### ######################################################################
...@@ -437,6 +437,10 @@ A_nd = tvm.nd.array(A_packed, ctx) ...@@ -437,6 +437,10 @@ A_nd = tvm.nd.array(A_packed, ctx)
B_nd = tvm.nd.array(B_packed, ctx) B_nd = tvm.nd.array(B_packed, ctx)
C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx) C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx)
# Clear stats
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
# Invoke the module to perform the computation # Invoke the module to perform the computation
f(A_nd, B_nd, C_nd) f(A_nd, B_nd, C_nd)
...@@ -452,8 +456,15 @@ C_ref = np.dot(A_orig.astype(env.acc_dtype), ...@@ -452,8 +456,15 @@ C_ref = np.dot(A_orig.astype(env.acc_dtype),
C_ref = C_ref.reshape( C_ref = C_ref.reshape(
o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3)) o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
np.testing.assert_equal(C_ref, C_nd.asnumpy()) np.testing.assert_equal(C_ref, C_nd.asnumpy())
print("Successful matrix multiply test!")
# Print stats
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
print("Successful matrix multiply test!")
###################################################################### ######################################################################
# Summary # Summary
......
...@@ -70,7 +70,7 @@ if env.TARGET == "pynq": ...@@ -70,7 +70,7 @@ if env.TARGET == "pynq":
vta.program_fpga(remote, bitstream=None) vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally. # In simulation mode, host the RPC server locally.
elif env.TARGET == "sim": elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession() remote = rpc.LocalSession()
###################################################################### ######################################################################
...@@ -412,6 +412,10 @@ data_nd = tvm.nd.array(data_packed, ctx) ...@@ -412,6 +412,10 @@ data_nd = tvm.nd.array(data_packed, ctx)
kernel_nd = tvm.nd.array(kernel_packed, ctx) kernel_nd = tvm.nd.array(kernel_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx) res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
# Clear stats
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
# Invoke the module to perform the computation # Invoke the module to perform the computation
f(data_nd, kernel_nd, res_nd) f(data_nd, kernel_nd, res_nd)
...@@ -430,6 +434,14 @@ res_ref = res_ref.reshape((batch_size // env.BATCH, ...@@ -430,6 +434,14 @@ res_ref = res_ref.reshape((batch_size // env.BATCH,
fout_height, fout_height,
fout_width)).transpose((0, 2, 4, 5, 1, 3)) fout_width)).transpose((0, 2, 4, 5, 1, 3))
tvm.testing.assert_allclose(res_ref, res_nd.asnumpy()) tvm.testing.assert_allclose(res_ref, res_nd.asnumpy())
# Print stats
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
print("Successful 2D convolution test!") print("Successful 2D convolution test!")
###################################################################### ######################################################################
......
...@@ -69,7 +69,7 @@ if env.TARGET == "pynq": ...@@ -69,7 +69,7 @@ if env.TARGET == "pynq":
vta.program_fpga(remote, bitstream=None) vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally. # In simulation mode, host the RPC server locally.
elif env.TARGET == "sim": elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession() remote = rpc.LocalSession()
###################################################################### ######################################################################
...@@ -352,6 +352,10 @@ data_nd = tvm.nd.array(data_packed, ctx) ...@@ -352,6 +352,10 @@ data_nd = tvm.nd.array(data_packed, ctx)
weight_nd = tvm.nd.array(weight_packed, ctx) weight_nd = tvm.nd.array(weight_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx) res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
# Clear stats
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
# Invoke the module to perform the computation # Invoke the module to perform the computation
f(data_nd, weight_nd, res_nd) f(data_nd, weight_nd, res_nd)
...@@ -366,6 +370,14 @@ res_ref = res_ref.reshape(batch_size // env.BATCH, ...@@ -366,6 +370,14 @@ res_ref = res_ref.reshape(batch_size // env.BATCH,
out_channels // env.BLOCK_OUT, out_channels // env.BLOCK_OUT,
env.BLOCK_OUT).transpose((0, 2, 1, 3)) env.BLOCK_OUT).transpose((0, 2, 1, 3))
np.testing.assert_equal(res_ref, res_nd.asnumpy()) np.testing.assert_equal(res_ref, res_nd.asnumpy())
# Print stats
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
print("Successful blocked matrix multiply test!") print("Successful blocked matrix multiply test!")
###################################################################### ######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment