Commit a31dd162 by Luis Vega Committed by Thierry Moreau

[VTA] TSIM improvements and fixes (#3505)

* add tsim init function

* add sim device

* test wait and resume

* launch simulation thread from DPILoader

* add VTASimDPI module to handle all simulation related stuff

* test tsim init

* move exit to simdpi module

* update vta driver

* add chisel DPI module

* get back simshell

* update vta to support dpi sim

* update unittests

* add tsim to integration-conv2d test

* run resnet on tsim

* remove max-cycles

* match tsim counters with sim counters

* use env in simulator to switch between sim and tsim

* update unittest

* rollback conv2d test

* update resnet

* add stats to matrix multiply

* add stats

* print stats after assert

* update other tests

* add stats to gemm

* add return and remove unused libs

* add missing arg

* return lib

* update comments for linter

* add more comments to VTASimDPI module

* remove trailing spaces

* remove trailing spaces
parent 2a7aebe5
......@@ -20,6 +20,7 @@
package test
import chisel3._
import chisel3.experimental.MultiIOModule
import vta.dpi._
import accel._
......@@ -28,19 +29,23 @@ import accel._
* Instantiate Host and Memory DPI modules.
*
*/
class VTASimShell extends Module {
val io = IO(new Bundle {
val host = new VTAHostDPIMaster
val mem = new VTAMemDPIClient
})
val host = Module(new VTAHostDPI)
val mem = Module(new VTAMemDPI)
mem.io.dpi <> io.mem
mem.io.reset := reset
mem.io.clock := clock
io.host <> host.io.dpi
host.io.reset := reset
host.io.clock := clock
class VTASimShell extends MultiIOModule {
val host = IO(new VTAHostDPIMaster)
val mem = IO(new VTAMemDPIClient)
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val mod_sim = Module(new VTASimDPI)
val mod_host = Module(new VTAHostDPI)
val mod_mem = Module(new VTAMemDPI)
mod_mem.io.clock := clock
mod_mem.io.reset := reset
mod_mem.io.dpi <> mem
mod_host.io.clock := clock
mod_host.io.reset := reset
host <> mod_host.io.dpi
mod_sim.io.clock := sim_clock
mod_sim.io.reset := reset
sim_wait := mod_sim.io.dpi_wait
}
/** Test accelerator.
......@@ -48,12 +53,15 @@ class VTASimShell extends Module {
* Instantiate and connect the simulation-shell and the accelerator.
*
*/
class TestAccel extends Module {
val io = IO(new Bundle {})
class TestAccel extends MultiIOModule {
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new VTASimShell)
val vta_accel = Module(new Accel)
vta_accel.io.host <> sim_shell.io.host
sim_shell.io.mem <> vta_accel.io.mem
sim_shell.sim_clock := sim_clock
sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_accel.io.mem
vta_accel.io.host <> sim_shell.host
}
/** Generate TestAccel as top module */
......
......@@ -25,7 +25,9 @@
module TestAccel
(
input clock,
input reset
input reset,
input sim_clock,
output sim_wait
);
localparam HOST_ADDR_BITS = 8;
......@@ -53,6 +55,14 @@ module TestAccel
logic [MEM_DATA_BITS-1:0] mem_rd_bits;
logic mem_rd_ready;
VTASimDPI sim
(
.clock (sim_clock),
.reset (reset),
.dpi_wait (sim_wait)
);
VTAHostDPI host
(
.clock (clock),
......@@ -114,4 +124,5 @@ module TestAccel
.mem_rd_bits (mem_rd_bits),
.mem_rd_ready (mem_rd_ready)
);
endmodule
......@@ -20,7 +20,22 @@ import ctypes
import os.path as osp
from sys import platform
def driver(hw_backend):
def get_ext():
return ".dylib" if platform == "darwin" else ".so"
def load_dll(dll):
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_sw():
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)
def init(hw_backend):
"""Init hardware and software shared library for accelerator
Parameters
......@@ -29,23 +44,15 @@ def driver(hw_backend):
Hardware backend can be verilog or chisel
"""
_ext = ".dylib" if platform == "darwin" else ".so"
_hw_libname = "libhw" + _ext
_sw_libname = "libsw" + _ext
_cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
_hw_lib = osp.join(_cur_path, "..", "hardware", hw_backend, "build", _hw_libname)
_sw_lib = osp.join(_cur_path, "..", "build", _sw_libname)
def load_dll(dll):
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
m = tvm.module.load(hw_lib, "vta-tsim")
load_sw()
f = tvm.get_global_func("tvm.vta.tsim.init")
f(m)
def run(a, b, c):
load_dll(_sw_lib)
f = tvm.get_global_func("tvm.vta.driver")
m = tvm.module.load(_hw_lib, "vta-tsim")
return f(m, a, b, c)
return run
def load_module():
load_sw()
return tvm.get_global_func("tvm.vta.driver")
......@@ -35,25 +35,56 @@ uint32_t get_half_addr(void *p, bool upper) {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}
DPIModuleNode* Get() {
return static_cast<DPIModuleNode*>(mod_.operator->());
}
static DPILoader* Global() {
static DPILoader inst;
return &inst;
}
// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
class Device {
public:
Device(Module module)
: module_(module) {
dpi_ = static_cast<DPIModuleNode*>(
module.operator->());
Device() {
loader_ = DPILoader::Global();
}
uint32_t Run(uint32_t c, uint32_t length, void* inp, void* out) {
uint32_t cycles;
this->Init();
this->Launch(c, length, inp, out);
cycles = this->WaitForCompletion();
dpi_->Finish();
return cycles;
}
private:
void Init() {
dpi_ = loader_->Get();
dpi_->SimResume();
}
void Launch(uint32_t c, uint32_t length, void* inp, void* out) {
dpi_->Launch(wait_cycles_);
dpi_->WriteReg(0x08, c);
dpi_->WriteReg(0x0c, length);
dpi_->WriteReg(0x10, get_half_addr(inp, false));
......@@ -70,24 +101,33 @@ class Device {
if (val == 2) break; // finish
}
val = dpi_->ReadReg(0x04);
dpi_->SimWait();
return val;
}
// wait cycles
uint32_t wait_cycles_{100000000};
DPIModuleNode* dpi_;
Module module_;
// DPI loader
DPILoader* loader_{nullptr};
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.driver")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module dev_mod = args[0];
DLTensor* A = args[1];
DLTensor* B = args[2];
Device dev_(dev_mod);
uint32_t cycles = dev_.Run(static_cast<int>(args[3]), A->shape[0], A->data, B->data);
DLTensor* A = args[0];
DLTensor* B = args[1];
Device dev_;
uint32_t cycles = dev_.Run(static_cast<int>(args[2]), A->shape[0], A->data, B->data);
*rv = static_cast<int>(cycles);
});
......
......@@ -26,12 +26,13 @@ def test_accel():
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = tsim.driver("chisel")
f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
tsim.init("chisel")
for i in range(10):
test_accel()
......@@ -26,12 +26,13 @@ def test_accel():
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.randint(rmax, size=n).astype("uint64"), ctx)
b = tvm.nd.array(np.zeros(n).astype("uint64"), ctx)
f = tsim.driver("verilog")
f = tsim.load_module()
cycles = f(a, b, c)
msg = "cycles:{0:4} n:{1:2} c:{2:2}".format(cycles, n, c)
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + c, err_msg = "[FAIL] " + msg)
print("[PASS] " + msg)
if __name__ == "__main__":
tsim.init("verilog")
for i in range(10):
test_accel()
......@@ -35,7 +35,6 @@ module VTAHostDPI #
import "DPI-C" function void VTAHostDPI
(
output byte unsigned exit,
output byte unsigned req_valid,
output byte unsigned req_opcode,
output byte unsigned req_addr,
......@@ -50,7 +49,6 @@ module VTAHostDPI #
typedef logic [31:0] dpi32_t;
dpi1_t __reset;
dpi8_t __exit;
dpi8_t __req_valid;
dpi8_t __req_opcode;
dpi8_t __req_addr;
......@@ -80,7 +78,6 @@ module VTAHostDPI #
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__exit = 0;
__req_valid = 0;
__req_opcode = 0;
__req_addr = 0;
......@@ -88,7 +85,6 @@ module VTAHostDPI #
end
else begin
VTAHostDPI(
__exit,
__req_valid,
__req_opcode,
__req_addr,
......@@ -99,21 +95,4 @@ module VTAHostDPI #
end
end
logic [63:0] cycles;
always_ff @(posedge clock) begin
if (reset | __reset) begin
cycles <= 'd0;
end
else begin
cycles <= cycles + 1'b1;
end
end
always_ff @(posedge clock) begin
if (__exit == 'd1) begin
$finish;
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module VTASimDPI
(
input clock,
input reset,
output logic dpi_wait
);
import "DPI-C" function void VTASimDPI
(
output byte unsigned sim_wait,
output byte unsigned sim_exit
);
typedef logic dpi1_t;
typedef logic [7:0] dpi8_t;
dpi1_t __reset;
dpi8_t __wait;
dpi8_t __exit;
// reset
always_ff @(posedge clock) begin
__reset <= reset;
end
// evaluate DPI function
always_ff @(posedge clock) begin
if (reset | __reset) begin
__wait = 0;
__exit = 0;
end
else begin
VTASimDPI(
__wait,
__exit);
end
end
logic wait_reg;
always_ff @(posedge clock) begin
if (reset | __reset) begin
wait_reg <= 1'b0;
end else if (__wait == 1) begin
wait_reg <= 1'b1;
end else begin
wait_reg <= 1'b0;
end
end
assign dpi_wait = wait_reg;
always_ff @(posedge clock) begin
if (__exit == 1) begin
$finish;
end
end
endmodule
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package vta.dpi
import chisel3._
import chisel3.util._
import vta.util.config._
import vta.interface.axi._
import vta.shell._
/** Sim DPI module.
*
* Wrapper for Sim Verilog DPI module.
*/
class VTASimDPI extends BlackBox with HasBlackBoxResource {
val io = IO(new Bundle {
val clock = Input(Clock())
val reset = Input(Bool())
val dpi_wait = Output(Bool())
})
setResource("/verilog/VTASimDPI.v")
}
......@@ -20,6 +20,7 @@
package vta.shell
import chisel3._
import chisel3.experimental.MultiIOModule
import vta.util.config._
import vta.interface.axi._
import vta.shell._
......@@ -61,18 +62,37 @@ class VTAMem(implicit p: Parameters) extends Module {
mem_axi.io.axi <> io.axi
}
/** VTASim.
*
* This module is used to handle hardware simulation thread, such as halting
* or terminating the simulation thread. The sim_wait port is used to halt
* the simulation thread when it is asserted and resume it when it is
* de-asserted.
*/
class VTASim(implicit p: Parameters) extends MultiIOModule {
val sim_wait = IO(Output(Bool()))
val sim = Module(new VTASimDPI)
sim.io.reset := reset
sim.io.clock := clock
sim_wait := sim.io.dpi_wait
}
/** SimShell.
*
* The simulation shell instantiate a host and memory simulation modules and it is
* intended to be connected to the VTAShell.
* The simulation shell instantiate the sim, host and memory DPI modules that
* are connected to the VTAShell. An extra clock, sim_clock, is used to eval
* the VTASim DPI function when the main simulation clock is on halt state.
*/
class SimShell(implicit p: Parameters) extends Module {
val io = IO(new Bundle {
val mem = new AXIClient(p(ShellKey).memParams)
val host = new AXILiteMaster(p(ShellKey).hostParams)
})
val host = Module(new VTAHost)
val mem = Module(new VTAMem)
io.mem <> mem.io.axi
io.host <> host.io.axi
class SimShell(implicit p: Parameters) extends MultiIOModule {
val mem = IO(new AXIClient(p(ShellKey).memParams))
val host = IO(new AXILiteMaster(p(ShellKey).hostParams))
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val mod_sim = Module(new VTASim)
val mod_host = Module(new VTAHost)
val mod_mem = Module(new VTAMem)
mem <> mod_mem.io.axi
host <> mod_host.io.axi
mod_sim.reset := reset
mod_sim.clock := sim_clock
sim_wait := mod_sim.sim_wait
}
......@@ -20,14 +20,18 @@
package vta.test
import chisel3._
import chisel3.experimental.MultiIOModule
import vta.util.config._
import vta.shell._
/** Test. This generates a testbench file for simulation */
class Test(implicit p: Parameters) extends Module {
val io = IO(new Bundle {})
class Test(implicit p: Parameters) extends MultiIOModule {
val sim_clock = IO(Input(Clock()))
val sim_wait = IO(Output(Bool()))
val sim_shell = Module(new SimShell)
val vta_shell = Module(new VTAShell)
vta_shell.io.host <> sim_shell.io.host
sim_shell.io.mem <> vta_shell.io.mem
sim_shell.sim_clock := sim_clock
sim_wait := sim_shell.sim_wait
sim_shell.mem <> vta_shell.io.mem
vta_shell.io.host <> sim_shell.host
}
......@@ -17,6 +17,8 @@
* under the License.
*/
#include <chrono>
#include <thread>
#include <vta/dpi/tsim.h>
#if VM_TRACE
......@@ -29,11 +31,17 @@
#endif
static VTAContextHandle _ctx = nullptr;
static VTAMemDPIFunc _mem_dpi = nullptr;
static VTASimDPIFunc _sim_dpi = nullptr;
static VTAHostDPIFunc _host_dpi = nullptr;
static VTAMemDPIFunc _mem_dpi = nullptr;
void VTASimDPI(dpi8_t* wait,
dpi8_t* exit) {
assert(_sim_dpi != nullptr);
(*_sim_dpi)(_ctx, wait, exit);
}
void VTAHostDPI(dpi8_t* exit,
dpi8_t* req_valid,
void VTAHostDPI(dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
......@@ -41,7 +49,7 @@ void VTAHostDPI(dpi8_t* exit,
dpi8_t resp_valid,
dpi32_t resp_value) {
assert(_host_dpi != nullptr);
(*_host_dpi)(_ctx, exit, req_valid, req_opcode,
(*_host_dpi)(_ctx, req_valid, req_opcode,
req_addr, req_value, req_deq,
resp_valid, resp_value);
}
......@@ -63,9 +71,11 @@ void VTAMemDPI(dpi8_t req_valid,
}
void VTADPIInit(VTAContextHandle handle,
VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi) {
_ctx = handle;
_sim_dpi = sim_dpi;
_host_dpi = host_dpi;
_mem_dpi = mem_dpi;
}
......@@ -77,7 +87,7 @@ void vl_finish(const char* filename, int linenum, const char* hier) {
Verilated::gotFinish(true);
}
int VTADPISim(uint64_t max_cycles) {
int VTADPISim() {
uint64_t trace_count = 0;
Verilated::flushCall();
Verilated::gotFinish(false);
......@@ -115,13 +125,15 @@ int VTADPISim(uint64_t max_cycles) {
top->reset = 0;
// start simulation
while (!Verilated::gotFinish() && trace_count < max_cycles) {
while (!Verilated::gotFinish()) {
top->sim_clock = 0;
top->clock = 0;
top->eval();
#if VM_TRACE
if (trace_count >= start)
tfp->dump(static_cast<vluint64_t>(trace_count * 2));
#endif
top->sim_clock = 1;
top->clock = 1;
top->eval();
#if VM_TRACE
......@@ -129,6 +141,14 @@ int VTADPISim(uint64_t max_cycles) {
tfp->dump(static_cast<vluint64_t>(trace_count * 2 + 1));
#endif
trace_count++;
while (top->sim_wait) {
top->clock = 0;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
top->sim_clock = 0;
top->eval();
top->sim_clock = 1;
top->eval();
}
}
#if VM_TRACE
......
......@@ -34,11 +34,17 @@ namespace dpi {
*/
class DPIModuleNode : public tvm::runtime::ModuleNode {
public:
/*!
* \brief Launch hardware simulation until accelerator finishes or reach max_cycles
* \param max_cycles The maximum of cycles to wait
*/
virtual void Launch(uint64_t max_cycles) = 0;
/*! \brief Launch hardware simulation */
virtual void SimLaunch() = 0;
/*! \brief Halt hardware simulation */
virtual void SimWait() = 0;
/*! \brief Resume hardware simulation */
virtual void SimResume() = 0;
/*! \brief Finish hardware simulation */
virtual void SimFinish() = 0;
/*!
* \brief Write an accelerator register
......@@ -53,13 +59,9 @@ class DPIModuleNode : public tvm::runtime::ModuleNode {
*/
virtual uint32_t ReadReg(int addr) = 0;
/*! \brief Finish hardware simulation */
virtual void Finish() = 0;
static tvm::runtime::Module Load(std::string dll_name);
};
} // namespace dpi
} // namespace vta
#endif // VTA_DPI_MODULE_H_
......@@ -36,9 +36,13 @@ typedef unsigned long long dpi64_t; // NOLINT(*)
/*! \brief the context handle */
typedef void* VTAContextHandle;
typedef void (*VTASimDPIFunc)(
VTAContextHandle self,
dpi8_t* wait,
dpi8_t* exit);
/*!
* \brief Host DPI callback function that is invoked in VTAHostDPI.v every clock cycle
* \param exit Host kill simulation
* \param req_valid Host has a valid request for read or write a register in Accel
* \param req_opcode Host request type, opcode=0 for read and opcode=1 for write
* \param req_addr Host request register address
......@@ -50,7 +54,6 @@ typedef void* VTAContextHandle;
*/
typedef void (*VTAHostDPIFunc)(
VTAContextHandle self,
dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
......@@ -84,28 +87,28 @@ typedef void (*VTAMemDPIFunc)(
/*! \brief The type of VTADPIInit function pointer */
typedef void (*VTADPIInitFunc)(VTAContextHandle handle,
VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi);
/*! \brief The type of VTADPISim function pointer */
typedef int (*VTADPISimFunc)(uint64_t max_cycles);
typedef int (*VTADPISimFunc)();
/*!
* \brief Set Host and Memory DPI functions
* \param handle DPI Context handle
* \param sim_dpi Sim DPI function
* \param host_dpi Host DPI function
* \param mem_dpi Memory DPI function
*/
TVM_DLL void VTADPIInit(VTAContextHandle handle,
VTASimDPIFunc sim_dpi,
VTAHostDPIFunc host_dpi,
VTAMemDPIFunc mem_dpi);
/*!
* \brief Instantiate VTA design and generate clock/reset
* \param max_cycles The maximum number of simulation cycles
*/
TVM_DLL int VTADPISim(uint64_t max_cycles);
/*! \brief VTA hardware simulation thread */
TVM_DLL int VTADPISim();
#ifdef __cplusplus
}
......
......@@ -42,7 +42,7 @@ def server_start():
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../../"))
dll_path = find_libvta()[0]
dll_path = find_libvta("libvta")[0]
cfg_path = os.path.abspath(os.path.join(proj_root, "build/vta_config.json"))
runtime_dll = []
_load_module = tvm.get_global_func("tvm.rpc.server.load_module")
......
......@@ -19,21 +19,48 @@ from __future__ import absolute_import
import sys
import os
def _get_lib_name():
def _get_lib_name(lib_name):
"""Get lib name with extension
Returns
-------
lib_name_ext : str
Name of VTA shared library with extension
Parameters
------------
lib_name : str
Name of VTA shared library
"""
if sys.platform.startswith('win32'):
return "vta.dll"
return lib_name + ".dll"
if sys.platform.startswith('darwin'):
return "libvta.dylib"
return "libvta.so"
return lib_name + ".dylib"
return lib_name + ".so"
def find_libvta(lib_vta, optional=False):
"""Find VTA library
Returns
-------
lib_found : str
Library path
Parameters
------------
lib_vta : str
Name of VTA shared library
def find_libvta(optional=False):
"""Find VTA library"""
optional : bool
Enable error check
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
lib_search = [curr_path]
lib_search += [os.path.join(curr_path, "..", "..", "build",)]
lib_search += [os.path.join(curr_path, "..", "..", "..", "build",)]
lib_search += [os.path.join(curr_path, "..", "..", "..", "build", "Release")]
lib_name = _get_lib_name()
lib_name = _get_lib_name(lib_vta)
lib_path = [os.path.join(x, lib_name) for x in lib_search]
lib_found = [x for x in lib_path if os.path.exists(x)]
if not lib_found and not optional:
......
......@@ -17,22 +17,34 @@
"""Utilities to start simulator."""
import ctypes
import json
import sys
import os
import tvm
from ..environment import get_env
from ..libinfo import find_libvta
def _load_lib():
"""Load local library, assuming they are simulator."""
lib_path = find_libvta(optional=True)
if not lib_path:
def _load_sw():
"""Load software library, assuming they are simulator."""
lib_sw = find_libvta("libvta", optional=True)
if not lib_sw:
return []
try:
return [ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)]
return [ctypes.CDLL(lib_sw[0], ctypes.RTLD_GLOBAL)]
except OSError:
return []
def _load_all():
"""Load hardware library for tsim."""
lib = _load_sw()
env = get_env()
if env.TARGET == "tsim":
lib = find_libvta("libvta_hw", optional=True)
f = tvm.get_global_func("vta.tsim.init")
m = tvm.module.load(lib[0], "vta-tsim")
f(m)
return lib
def enabled():
"""Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
......@@ -40,49 +52,31 @@ def enabled():
def clear_stats():
"""Clear profiler statistics"""
"""Clear profiler statistics."""
env = get_env()
if env.TARGET == "sim":
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
else:
f = tvm.get_global_func("vta.tsim.profiler_clear", True)
if f:
f()
def stats():
"""Clear profiler statistics
"""Get profiler statistics
Returns
-------
stats : dict
Current profiler statistics
"""
env = get_env()
if env.TARGET == "sim":
x = tvm.get_global_func("vta.simulator.profiler_status")()
else:
x = tvm.get_global_func("vta.tsim.profiler_status")()
return json.loads(x)
def tsim_init(hw_lib):
"""Init hardware shared library for TSIM
Parameters
------------
hw_lib : str
Name of hardware shared library
"""
cur_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
vta_build_path = os.path.join(cur_path, "..", "..", "..", "build")
if not hw_lib.endswith(("dylib", "so")):
hw_lib += ".dylib" if sys.platform == "darwin" else ".so"
lib = os.path.join(vta_build_path, hw_lib)
f = tvm.get_global_func("tvm.vta.tsim.init")
m = tvm.module.load(lib, "vta-tsim")
f(m)
def tsim_cycles():
"""Get tsim clock cycles
Returns
-------
stats : int
tsim clock cycles
"""
return tvm.get_global_func("tvm.vta.tsim.cycles")()
# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1
......@@ -97,4 +91,4 @@ def debug_mode(flag):
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)
LIBS = _load_lib()
LIBS = _load_all()
......@@ -22,6 +22,7 @@ from tvm import rpc, autotvm
from ..environment import get_env
from . import simulator
def run(run_func):
"""Run test function on all available env.
......
......@@ -86,17 +86,28 @@ class ThreadSafeQueue {
std::condition_variable cond_;
};
class SimDevice {
public:
void Wait();
void Resume();
void Exit();
bool GetWaitStatus();
bool GetExitStatus();
private:
bool wait_{false};
bool exit_{false};
mutable std::mutex mutex_;
};
class HostDevice {
public:
void PushRequest(uint8_t opcode, uint8_t addr, uint32_t value);
bool TryPopRequest(HostRequest* r, bool pop);
void PushResponse(uint32_t value);
void WaitPopResponse(HostResponse* r);
void Exit();
uint8_t GetExitStatus();
private:
uint8_t exit_{0};
mutable std::mutex mutex_;
ThreadSafeQueue<HostRequest> req_;
ThreadSafeQueue<HostResponse> resp_;
......@@ -116,6 +127,31 @@ class MemDevice {
std::mutex mutex_;
};
void SimDevice::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
wait_ = true;
}
void SimDevice::Resume() {
std::unique_lock<std::mutex> lock(mutex_);
wait_ = false;
}
void SimDevice::Exit() {
std::unique_lock<std::mutex> lock(mutex_);
exit_ = true;
}
bool SimDevice::GetWaitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return wait_;
}
bool SimDevice::GetExitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return exit_;
}
void HostDevice::PushRequest(uint8_t opcode, uint8_t addr, uint32_t value) {
HostRequest r;
r.opcode = opcode;
......@@ -141,16 +177,6 @@ void HostDevice::WaitPopResponse(HostResponse* r) {
resp_.WaitPop(r);
}
void HostDevice::Exit() {
std::unique_lock<std::mutex> lock(mutex_);
exit_ = 1;
}
uint8_t HostDevice::GetExitStatus() {
std::unique_lock<std::mutex> lock(mutex_);
return exit_;
}
void MemDevice::SetRequest(uint8_t opcode, uint64_t addr, uint32_t len) {
std::lock_guard<std::mutex> lock(mutex_);
if (opcode == 1) {
......@@ -212,16 +238,29 @@ class DPIModule final : public DPIModuleNode {
VTADPIInitFunc finit = reinterpret_cast<VTADPIInitFunc>(
GetSymbol("VTADPIInit"));
CHECK(finit != nullptr);
finit(this, VTAHostDPI, VTAMemDPI);
fvsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
CHECK(fvsim_ != nullptr);
finit(this, VTASimDPI, VTAHostDPI, VTAMemDPI);
ftsim_ = reinterpret_cast<VTADPISimFunc>(GetSymbol("VTADPISim"));
CHECK(ftsim_ != nullptr);
}
void Launch(uint64_t max_cycles) {
auto frun = [this, max_cycles]() {
(*fvsim_)(max_cycles);
void SimLaunch() {
auto frun = [this]() {
(*ftsim_)();
};
vsim_thread_ = std::thread(frun);
tsim_thread_ = std::thread(frun);
}
void SimWait() {
sim_device_.Wait();
}
void SimResume() {
sim_device_.Resume();
}
void SimFinish() {
sim_device_.Exit();
tsim_thread_.join();
}
void WriteReg(int addr, uint32_t value) {
......@@ -238,19 +277,20 @@ class DPIModule final : public DPIModuleNode {
return value;
}
void Finish() {
host_device_.Exit();
vsim_thread_.join();
}
protected:
VTADPISimFunc fvsim_;
VTADPISimFunc ftsim_;
SimDevice sim_device_;
HostDevice host_device_;
MemDevice mem_device_;
std::thread vsim_thread_;
std::thread tsim_thread_;
void HostDPI(dpi8_t* exit,
dpi8_t* req_valid,
void SimDPI(dpi8_t* wait,
dpi8_t* exit) {
*wait = sim_device_.GetWaitStatus();
*exit = sim_device_.GetExitStatus();
}
void HostDPI(dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
dpi32_t* req_value,
......@@ -258,7 +298,6 @@ class DPIModule final : public DPIModuleNode {
dpi8_t resp_valid,
dpi32_t resp_value) {
HostRequest* r = new HostRequest;
*exit = host_device_.GetExitStatus();
*req_valid = host_device_.TryPopRequest(r, req_deq);
*req_opcode = r->opcode;
*req_addr = r->addr;
......@@ -290,9 +329,16 @@ class DPIModule final : public DPIModuleNode {
}
}
static void VTASimDPI(
VTAContextHandle self,
dpi8_t* wait,
dpi8_t* exit) {
static_cast<DPIModule*>(self)->SimDPI(
wait, exit);
}
static void VTAHostDPI(
VTAContextHandle self,
dpi8_t* exit,
dpi8_t* req_valid,
dpi8_t* req_opcode,
dpi8_t* req_addr,
......@@ -301,7 +347,7 @@ class DPIModule final : public DPIModuleNode {
dpi8_t resp_valid,
dpi32_t resp_value) {
static_cast<DPIModule*>(self)->HostDPI(
exit, req_valid, req_opcode, req_addr,
req_valid, req_opcode, req_addr,
req_value, req_deq, resp_valid, resp_value);
}
......
......@@ -17,32 +17,78 @@
* under the License.
*/
#include <vta/driver.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <vta/driver.h>
#include <vta/dpi/module.h>
namespace vta {
namespace tsim {
using vta::dpi::DPIModuleNode;
using tvm::runtime::Module;
using vta::dpi::DPIModuleNode;
class Profiler {
public:
/*! \brief cycle counter */
uint64_t cycle_count{0};
Profiler() {
counters_ = new int[num_counters_];
this->ClearAll();
}
~Profiler() {
delete [] counters_;
}
/*! \brief update one event counter */
void Update(uint32_t idx, uint32_t value) {
counters_[idx] += value;
}
/*! \brief clear one event counter*/
void Clear(uint32_t idx) {
counters_[idx] = 0;
}
/*! \brief clear all event counters */
void ClearAll() {
for (uint32_t i = 0; i < num_counters_; i++) {
counters_[i] = 0;
}
}
/*! \brief return counters as json */
std::string AsJSON() {
std::ostringstream os;
os << "{\n"
<< " \"cycle_count\":" << counters_[0] << "\n"
<<"}\n";
return os.str();
}
static Profiler* Global() {
static Profiler inst;
return &inst;
}
private:
/*! \brief total number of event counters */
uint32_t num_counters_{1};
/*! \brief event counters */
int* counters_{nullptr};
};
class DPILoader {
public:
~DPILoader() {
dpi_->SimResume();
dpi_->SimFinish();
}
void Init(Module module) {
mod_ = module;
dpi_ = this->Get();
dpi_->SimLaunch();
dpi_->SimWait();
}
DPIModuleNode* Get() {
......@@ -54,13 +100,16 @@ class DPILoader {
return &inst;
}
// TVM module
Module mod_;
// DPI Module
DPIModuleNode* dpi_{nullptr};
};
class Device {
public:
Device() {
dpi_ = DPILoader::Global();
loader_ = DPILoader::Global();
prof_ = Profiler::Global();
}
......@@ -82,13 +131,13 @@ class Device {
insn_count,
wait_cycles);
this->WaitForCompletion(wait_cycles);
dev_->Finish();
return 0;
}
private:
void Init() {
dev_ = dpi_->Get();
dpi_ = loader_->Get();
dpi_->SimResume();
}
void Launch(vta_phy_addr_t insn_phy_addr,
......@@ -99,57 +148,60 @@ class Device {
vta_phy_addr_t out_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
// launch simulation thread
dev_->Launch(wait_cycles);
// set counter to zero
dev_->WriteReg(0x04, 0);
dev_->WriteReg(0x08, insn_count);
dev_->WriteReg(0x0c, insn_phy_addr);
dev_->WriteReg(0x10, insn_phy_addr >> 32);
dev_->WriteReg(0x14, 0);
dev_->WriteReg(0x18, uop_phy_addr >> 32);
dev_->WriteReg(0x1c, 0);
dev_->WriteReg(0x20, inp_phy_addr >> 32);
dev_->WriteReg(0x24, 0);
dev_->WriteReg(0x28, wgt_phy_addr >> 32);
dev_->WriteReg(0x2c, 0);
dev_->WriteReg(0x30, acc_phy_addr >> 32);
dev_->WriteReg(0x34, 0);
dev_->WriteReg(0x38, out_phy_addr >> 32);
dpi_->WriteReg(0x04, 0);
dpi_->WriteReg(0x08, insn_count);
dpi_->WriteReg(0x0c, insn_phy_addr);
dpi_->WriteReg(0x10, insn_phy_addr >> 32);
dpi_->WriteReg(0x14, 0);
dpi_->WriteReg(0x18, uop_phy_addr >> 32);
dpi_->WriteReg(0x1c, 0);
dpi_->WriteReg(0x20, inp_phy_addr >> 32);
dpi_->WriteReg(0x24, 0);
dpi_->WriteReg(0x28, wgt_phy_addr >> 32);
dpi_->WriteReg(0x2c, 0);
dpi_->WriteReg(0x30, acc_phy_addr >> 32);
dpi_->WriteReg(0x34, 0);
dpi_->WriteReg(0x38, out_phy_addr >> 32);
// start
dev_->WriteReg(0x00, 0x1);
dpi_->WriteReg(0x00, 0x1);
}
void WaitForCompletion(uint32_t wait_cycles) {
uint32_t i, val;
for (i = 0; i < wait_cycles; i++) {
val = dev_->ReadReg(0x00);
val = dpi_->ReadReg(0x00);
val &= 0x2;
if (val == 0x2) break; // finish
}
prof_->cycle_count = dev_->ReadReg(0x04);
prof_->Update(0, dpi_->ReadReg(0x04));
dpi_->SimWait();
}
// Profiler
Profiler* prof_;
// DPI loader
DPILoader* dpi_;
DPILoader* loader_;
// DPI Module
DPIModuleNode* dev_;
DPIModuleNode* dpi_;
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("tvm.vta.tsim.init")
TVM_REGISTER_GLOBAL("vta.tsim.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Module m = args[0];
DPILoader::Global()->Init(m);
});
TVM_REGISTER_GLOBAL("tvm.vta.tsim.cycles")
TVM_REGISTER_GLOBAL("vta.tsim.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::Global()->ClearAll();
});
TVM_REGISTER_GLOBAL("vta.tsim.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int>(Profiler::Global()->cycle_count);
*rv = Profiler::Global()->AsJSON();
});
} // namespace tsim
......
......@@ -18,6 +18,7 @@ import tvm
import numpy as np
from tvm.contrib import util
import vta.testing
from vta.testing import simulator
def test_gemm():
......@@ -104,7 +105,14 @@ def test_gemm():
res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20)
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
cost = time_f(data_arr, weight_arr, res_arr)
if env.TARGET in ["sim", "tsim"]:
stats = simulator.stats()
print("Execution statistics:")
for k, v in stats.items():
print("\t{:<16}: {:>16}".format(k, v))
res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
......
......@@ -173,15 +173,21 @@ def run_conv2d(env, remote, wl, target,
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
if env.TARGET == "sim":
if env.TARGET in ["sim", "tsim"]:
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
if env.TARGET == "sim":
remote.get_function("vta.simulator.profiler_clear")()
else:
remote.get_function("vta.tsim.profiler_clear")()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
if env.TARGET == "sim":
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
stats = simulator.stats()
......@@ -215,7 +221,7 @@ def test_conv2d(device="vta"):
def _run(env, remote):
if device == "vta":
target = env.target
if env.TARGET != "sim":
if env.TARGET not in ["sim", "tsim"]:
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
......
......@@ -131,15 +131,21 @@ def run_gemm(env, remote, target,
# In vta sim mode, collect simulator runtime statistics
stats = {}
cost = None
if env.TARGET == "sim":
if env.TARGET in ["sim", "tsim"]:
# Check if we're in local RPC mode (allows us to rebuild the
# runtime on the fly when varying the VTA designs)
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
if env.TARGET == "sim":
remote.get_function("vta.simulator.profiler_clear")()
else:
remote.get_function("vta.tsim.profiler_clear")()
cost = time_f(data_arr, kernel_arr, res_arr)
if env.TARGET == "sim":
stats = json.loads(remote.get_function("vta.simulator.profiler_status")())
else:
stats = json.loads(remote.get_function("vta.tsim.profiler_status")())
else:
simulator.clear_stats()
cost = time_f(data_arr, kernel_arr, res_arr)
stats = simulator.stats()
......@@ -171,7 +177,7 @@ def test_gemm(device="vta", batch=128, in_feat=128, out_feat=128):
def _run(env, remote):
if device == "vta":
target = env.target
if env.TARGET != "sim":
if env.TARGET not in ["sim", "tsim"]:
assert tvm.module.enabled("rpc")
program_fpga(remote, bitstream=None)
reconfig_runtime(remote)
......
......@@ -69,15 +69,18 @@ def test_save_load_out():
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("Load/store test took {} clock cycles".format(simulator.tsim_cycles()))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Save load execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
......@@ -135,15 +138,18 @@ def test_padded_load():
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("Padded load test took {} clock cycles".format(simulator.tsim_cycles()))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Padded load execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
......@@ -213,20 +219,18 @@ def test_gemm():
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET == "sim":
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
f(x_nd, w_nd, y_nd)
print(simulator.stats())
else:
f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
if env.TARGET == "tsim":
print("GEMM schedule:{} test took {} clock cycles".format(name, simulator.tsim_cycles()))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("GEMM schedule:{} execution statistics:".format(name))
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
def test_schedule1():
# default schedule with no smt
......@@ -374,8 +378,8 @@ def test_alu():
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
if use_imm:
f(a_nd, res_nd)
......@@ -385,8 +389,11 @@ def test_alu():
np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("ALU {} imm:{} test took {} clock cycles".format(test_name, use_imm, simulator.tsim_cycles()))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("ALU {} execution statistics:".format(test_name))
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True, test_name="SHL")
check_alu(tvm.max, np.maximum, use_imm=True, test_name="MAX")
......@@ -451,15 +458,18 @@ def test_relu():
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("Relu test took {} clock cycles".format(simulator.tsim_cycles()))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Relu execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
......@@ -518,15 +528,18 @@ def test_shift_and_scale():
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if env.TARGET == "tsim":
simulator.tsim_init("libvta_hw")
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
if env.TARGET == "tsim":
print("Shift/scale test took {} clock cycles".format(simulator.tsim_cycles()))
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Shift and scale execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
vta.testing.run(_run)
......
......@@ -89,7 +89,7 @@ stop_pack="nn.global_avg_pool2d"
# When target is 'pynq', reconfigure FPGA and runtime.
# Otherwise, if target is 'sim', execute locally.
if env.TARGET != "sim":
if env.TARGET not in ["sim", "tsim"]:
# Get remote from tracker node if environment variable is set.
# To set up the tracker, you'll need to follow the "Auto-tuning
......@@ -235,7 +235,7 @@ num = 4 # number of times we run module for a single measurement
rep = 3 # number of measurements (we derive std dev from this)
timer = m.module.time_evaluator("run", ctx, number=num, repeat=rep)
if env.TARGET == "sim":
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
timer()
sim_stats = simulator.stats()
......
......@@ -66,7 +66,7 @@ if env.TARGET == "pynq":
vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
######################################################################
......@@ -437,6 +437,10 @@ A_nd = tvm.nd.array(A_packed, ctx)
B_nd = tvm.nd.array(B_packed, ctx)
C_nd = tvm.nd.array(np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(C.dtype), ctx)
# Clear stats
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
# Invoke the module to perform the computation
f(A_nd, B_nd, C_nd)
......@@ -452,8 +456,15 @@ C_ref = np.dot(A_orig.astype(env.acc_dtype),
C_ref = C_ref.reshape(
o, env.BATCH, m, env.BLOCK_OUT).transpose((0, 2, 1, 3))
np.testing.assert_equal(C_ref, C_nd.asnumpy())
print("Successful matrix multiply test!")
# Print stats
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
print("Successful matrix multiply test!")
######################################################################
# Summary
......
......@@ -70,7 +70,7 @@ if env.TARGET == "pynq":
vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
######################################################################
......@@ -412,6 +412,10 @@ data_nd = tvm.nd.array(data_packed, ctx)
kernel_nd = tvm.nd.array(kernel_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
# Clear stats
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
# Invoke the module to perform the computation
f(data_nd, kernel_nd, res_nd)
......@@ -430,6 +434,14 @@ res_ref = res_ref.reshape((batch_size // env.BATCH,
fout_height,
fout_width)).transpose((0, 2, 4, 5, 1, 3))
tvm.testing.assert_allclose(res_ref, res_nd.asnumpy())
# Print stats
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
print("Successful 2D convolution test!")
######################################################################
......
......@@ -69,7 +69,7 @@ if env.TARGET == "pynq":
vta.program_fpga(remote, bitstream=None)
# In simulation mode, host the RPC server locally.
elif env.TARGET == "sim":
elif env.TARGET in ["sim", "tsim"]:
remote = rpc.LocalSession()
######################################################################
......@@ -352,6 +352,10 @@ data_nd = tvm.nd.array(data_packed, ctx)
weight_nd = tvm.nd.array(weight_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
# Clear stats
if env.TARGET in ["sim", "tsim"]:
simulator.clear_stats()
# Invoke the module to perform the computation
f(data_nd, weight_nd, res_nd)
......@@ -366,6 +370,14 @@ res_ref = res_ref.reshape(batch_size // env.BATCH,
out_channels // env.BLOCK_OUT,
env.BLOCK_OUT).transpose((0, 2, 1, 3))
np.testing.assert_equal(res_ref, res_nd.asnumpy())
# Print stats
if env.TARGET in ["sim", "tsim"]:
sim_stats = simulator.stats()
print("Execution statistics:")
for k, v in sim_stats.items():
print("\t{:<16}: {:>16}".format(k, v))
print("Successful blocked matrix multiply test!")
######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment