Commit 9c44e4b4 by Tianqi Chen

[DRIVER] Add simulator, unify testcase to unittest (#25)

parent 0faefafc
...@@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE) ...@@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS) LDFLAGS += $(ADD_LDFLAGS)
endif endif
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
SHARED_LIBRARY_SUFFIX := dylib
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
LDFLAGS += -undefined dynamic_lookup
else
SHARED_LIBRARY_SUFFIX := so
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif
all: lib/libvta.so lib/libvta_runtime.so all: lib/libvta.so lib/libvta_runtime.so
...@@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET) ...@@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET)
LDFLAGS += -l:libdma.so LDFLAGS += -l:libdma.so
endif endif
ifeq ($(TARGET), sim)
VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC)) VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
build/%.o: src/%.cc build/%.o: src/%.cc
...@@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o ...@@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o
lint: pylint cpplint lint: pylint cpplint
cpplint: cpplint:
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests python nnvm/dmlc-core/scripts/lint.py vta cpp include src
pylint: pylint:
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
...@@ -86,3 +103,4 @@ clean: ...@@ -86,3 +103,4 @@ clean:
-include build/*.d -include build/*.d
-include build/*/*.d -include build/*/*.d
-include build/*/*/*.d -include build/*/*/*.d
-include build/*/*/*/*.d
...@@ -77,7 +77,7 @@ void VTAMemFree(void* buf); ...@@ -77,7 +77,7 @@ void VTAMemFree(void* buf);
* \param buf Pointer to memory region allocated with VTAMemAlloc. * \param buf Pointer to memory region allocated with VTAMemAlloc.
* \return The physical address of the memory region. * \return The physical address of the memory region.
*/ */
vta_phy_addr_t VTAGetMemPhysAddr(void* buf); vta_phy_addr_t VTAMemGetPhyAddr(void* buf);
/*! /*!
* \brief Flushes the region of memory out of the CPU cache to DRAM. * \brief Flushes the region of memory out of the CPU cache to DRAM.
......
...@@ -519,8 +519,8 @@ typedef struct { ...@@ -519,8 +519,8 @@ typedef struct {
uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH; uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH;
/*! \brief Use immediate is true */ /*! \brief Use immediate is true */
uint64_t use_imm : 1; uint64_t use_imm : 1;
/*! \brief Immediate value */ /*! \brief Immediate value: allow negative value */
uint64_t imm : VTA_ALUOP_IMM_BIT_WIDTH; int64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
} VTAAluInsn; } VTAAluInsn;
/*! \brief VTA ALU instruction converter */ /*! \brief VTA ALU instruction converter */
......
...@@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode, ...@@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode,
uint32_t wgt_index, uint32_t wgt_index,
uint32_t opcode, uint32_t opcode,
uint32_t use_imm, uint32_t use_imm,
uint32_t imm_val); int32_t imm_val);
/*! /*!
* \brief Mark start of a micro op loop. * \brief Mark start of a micro op loop.
......
...@@ -27,7 +27,7 @@ ADD_LDFLAGS= ...@@ -27,7 +27,7 @@ ADD_LDFLAGS=
ADD_CFLAGS= ADD_CFLAGS=
# the hardware target # the hardware target
TARGET = VTA_PYNQ_TARGET TARGET = pynq
#--------------------- #---------------------
# VTA hardware parameters # VTA hardware parameters
...@@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" ) ...@@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
# Update ADD_CFLAGS # Update ADD_CFLAGS
ADD_CFLAGS += \ ADD_CFLAGS += \
-D$(TARGET) \
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \ -DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \ -DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \ -DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
......
...@@ -8,6 +8,7 @@ try: ...@@ -8,6 +8,7 @@ try:
from . import arm_conv2d, vta_conv2d from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga from .rpc_client import reconfig_runtime, program_fpga
from . import graph from . import graph
except ImportError: except ImportError:
pass pass
...@@ -89,7 +89,7 @@ class Environment(object): ...@@ -89,7 +89,7 @@ class Environment(object):
""" """
current = None current = None
cfg_keys = [ cfg_keys = [
"target", "TARGET",
"LOG_INP_WIDTH", "LOG_INP_WIDTH",
"LOG_WGT_WIDTH", "LOG_WGT_WIDTH",
"LOG_ACC_WIDTH", "LOG_ACC_WIDTH",
...@@ -204,9 +204,19 @@ class Environment(object): ...@@ -204,9 +204,19 @@ class Environment(object):
@property @property
def gevm(self): def gevm(self):
"""GEMM intrinsic""" """GEVM intrinsic"""
return self.dev.gevm return self.dev.gevm
@property
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "sim":
return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET)
def get_env(): def get_env():
"""Get the current VTA Environment. """Get the current VTA Environment.
...@@ -278,6 +288,7 @@ def _init_env(): ...@@ -278,6 +288,7 @@ def _init_env():
for k in Environment.cfg_keys: for k in Environment.cfg_keys:
keys.add("VTA_" + k) keys.add("VTA_" + k)
keys.add("TARGET")
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise RuntimeError( raise RuntimeError(
...@@ -290,8 +301,11 @@ def _init_env(): ...@@ -290,8 +301,11 @@ def _init_env():
for k in keys: for k in keys:
if k +" =" in line: if k +" =" in line:
val = line.split("=")[1].strip() val = line.split("=")[1].strip()
cfg[k[4:]] = int(val) if k.startswith("VTA_"):
cfg["target"] = "pynq" k = k[4:]
cfg[k] = int(val)
else:
cfg[k] = val
return Environment(cfg) return Environment(cfg)
Environment.current = _init_env() Environment.current = _init_env()
...@@ -78,8 +78,7 @@ def fold_uop_loop(stmt_in): ...@@ -78,8 +78,7 @@ def fold_uop_loop(stmt_in):
if not fail[0]: if not fail[0]:
begin = tvm.call_extern( begin = tvm.call_extern(
"int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
end = tvm.call_extern( end = tvm.call_extern("int32", "VTAUopLoopEnd")
"int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets)
return [begin, ret, end] return [begin, ret, end]
raise ValueError("Failed to fold the GEMM instructions..") raise ValueError("Failed to fold the GEMM instructions..")
...@@ -683,8 +682,14 @@ def inject_alu_intrin(stmt_in): ...@@ -683,8 +682,14 @@ def inject_alu_intrin(stmt_in):
else: else:
raise RuntimeError( raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name)) "Function call not recognized %s" % (loop_body.value.name))
elif isinstance(loop_body.value, tvm.expr.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value
rhs = tvm.const(0)
else: else:
raise RuntimeError("Expression not recognized %s" % (type(loop_body.value))) raise RuntimeError(
"Expression not recognized %s, %s, %s" % (
type(loop_body.value), str(loop_body.value), str(stmt)))
# Derive array index coefficients # Derive array index coefficients
dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices) dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
...@@ -772,7 +777,9 @@ def inject_alu_intrin(stmt_in): ...@@ -772,7 +777,9 @@ def inject_alu_intrin(stmt_in):
irb = tvm.ir_builder.create() irb = tvm.ir_builder.create()
for idx, extent in enumerate(extents): for idx, extent in enumerate(extents):
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx])) "int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
irb.emit(tvm.call_extern( irb.emit(tvm.call_extern(
"int32", "VTAUopPush", "int32", "VTAUopPush",
1, 0, 1, 0,
...@@ -804,5 +811,6 @@ def debug_print(stmt): ...@@ -804,5 +811,6 @@ def debug_print(stmt):
stmt : Stmt stmt : Stmt
The The
""" """
# pylint: disable=superfluous-parens
print(stmt) print(stmt)
return stmt return stmt
...@@ -24,8 +24,7 @@ def reconfig_runtime(remote): ...@@ -24,8 +24,7 @@ def reconfig_runtime(remote):
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"] "VTA_LOG_OUT_BUFF_SIZE"]
cflags = []
cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
for k in keys: for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))] cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime") freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
......
"""Testing utilities, this namespace is not imported by default."""
from . util import run
"""Utilities to start simulator."""
import os
import ctypes
import json
import tvm
def _load_lib():
"""Load local library, assuming they are simulator."""
# pylint: disable=unused-variable
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
]
runtime_dll = []
if not all(os.path.exists(f) for f in dll_path):
return []
try:
for fname in dll_path:
runtime_dll.append(ctypes.CDLL(fname, ctypes.RTLD_GLOBAL))
return runtime_dll
except OSError:
return []
def enabled():
"""Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
return f is not None
def clear_stats():
"""Clear profiler statistics"""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
if f:
f()
def stats():
"""Clear profiler statistics
Returns
-------
stats : dict
Current profiler statistics
"""
x = tvm.get_global_func("vta.simulator.profiler_status")()
return json.loads(x)
LIBS = _load_lib()
"""Test Utilities"""
from __future__ import absolute_import as _abs
import os
from tvm.contrib import rpc
from ..environment import get_env
from . import simulator
def run(run_func):
"""Run test function on all available env.
Parameters
----------
run_func : function(env, remote)
"""
env = get_env()
# run on simulator
if simulator.enabled():
env.TARGET = "sim"
run_func(env, rpc.LocalSession())
# Run on PYNQ if env variable exists
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
if pynq_host:
env.TARGET = "pynq"
port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
port = int(port)
remote = rpc.connect(pynq_host, port)
run_func(env, remote)
...@@ -57,7 +57,7 @@ struct DataBuffer { ...@@ -57,7 +57,7 @@ struct DataBuffer {
assert(data != nullptr); assert(data != nullptr);
DataBuffer* buffer = new DataBuffer(); DataBuffer* buffer = new DataBuffer();
buffer->data_ = data; buffer->data_ = data;
buffer->phy_addr_ = VTAGetMemPhysAddr(data); buffer->phy_addr_ = VTAMemGetPhyAddr(data);
return buffer; return buffer;
} }
/*! /*!
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file vta_pynq_driver.c * \file pynq_driver.c
* \brief VTA driver for Pynq board. * \brief VTA driver for Pynq board.
*/ */
...@@ -17,7 +17,7 @@ void VTAMemFree(void* buf) { ...@@ -17,7 +17,7 @@ void VTAMemFree(void* buf) {
cma_free(buf); cma_free(buf);
} }
vta_phy_addr_t VTAGetMemPhysAddr(void* buf) { vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
return cma_get_phy_addr(buf); return cma_get_phy_addr(buf);
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <vta/driver.h> #include <vta/driver.h>
#include <vta/hw_spec.h> #include <vta/hw_spec.h>
#include <vta/runtime.h> #include <vta/runtime.h>
#include <dmlc/logging.h>
#include <cassert> #include <cassert>
#include <vector> #include <vector>
...@@ -113,7 +114,7 @@ class UopKernel { ...@@ -113,7 +114,7 @@ class UopKernel {
uint32_t wgt_index, uint32_t wgt_index,
uint32_t opcode, uint32_t opcode,
uint32_t use_imm, uint32_t use_imm,
uint32_t imm_val) { int32_t imm_val) {
// The loop nest structure // The loop nest structure
VerifyDep(dst_index); VerifyDep(dst_index);
VTAUop op; VTAUop op;
...@@ -166,7 +167,7 @@ class UopKernel { ...@@ -166,7 +167,7 @@ class UopKernel {
uint32_t opcode_{0xFFFFFFFF}; uint32_t opcode_{0xFFFFFFFF};
uint32_t reset_out_{0xFFFFFFFF}; uint32_t reset_out_{0xFFFFFFFF};
bool use_imm_{false}; bool use_imm_{false};
uint16_t imm_val_{0}; int16_t imm_val_{0};
private: private:
// Verify that we don't write to the same acc_mem index two cycles in a row // Verify that we don't write to the same acc_mem index two cycles in a row
...@@ -195,10 +196,6 @@ class UopKernel { ...@@ -195,10 +196,6 @@ class UopKernel {
/*! /*!
* \brief Base class of all queues to send and recv serial data. * \brief Base class of all queues to send and recv serial data.
* \param kElemBytes Element unit bytes.
* \param kMaxBytes Maximum number of bytes.
* \param kCoherent Whether we have coherent access to the buffer.
* \param kAlwaysCache Wether we should use cached memory.
*/ */
class BaseQueue { class BaseQueue {
public: public:
...@@ -227,7 +224,7 @@ class BaseQueue { ...@@ -227,7 +224,7 @@ class BaseQueue {
dram_buffer_ = static_cast<char*>(VTAMemAlloc( dram_buffer_ = static_cast<char*>(VTAMemAlloc(
max_bytes, coherent || always_cache_)); max_bytes, coherent || always_cache_));
assert(dram_buffer_ != nullptr); assert(dram_buffer_ != nullptr);
dram_phy_addr_ = VTAGetMemPhysAddr(dram_buffer_); dram_phy_addr_ = VTAMemGetPhyAddr(dram_buffer_);
} }
/*! /*!
* \brief Reset the pointer of the buffer. * \brief Reset the pointer of the buffer.
...@@ -604,7 +601,7 @@ class InsnQueue : public BaseQueue { ...@@ -604,7 +601,7 @@ class InsnQueue : public BaseQueue {
if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n"); if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
} }
if (c.mem.opcode == VTA_OPCODE_STORE) { if (c.mem.opcode == VTA_OPCODE_STORE) {
printf("STORE\n"); printf("STORE:\n");
} }
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_prev_dep),
...@@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode, ...@@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode,
uint32_t wgt_index, uint32_t wgt_index,
uint32_t opcode, uint32_t opcode,
uint32_t use_imm, uint32_t use_imm,
uint32_t imm_val) { int32_t imm_val) {
vta::CommandQueue::ThreadLocal()->record_kernel() vta::CommandQueue::ThreadLocal()->record_kernel()
->Push(mode, reset_out, dst_index, src_index, ->Push(mode, reset_out, dst_index, src_index,
wgt_index, opcode, use_imm, imm_val); wgt_index, opcode, use_imm, imm_val);
......
...@@ -67,9 +67,6 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -67,9 +67,6 @@ class VTADeviceAPI final : public DeviceAPI {
std::make_shared<VTADeviceAPI>(); std::make_shared<VTADeviceAPI>();
return inst; return inst;
} }
private:
void* runtime_dll_{nullptr};
}; };
struct VTAWorkspacePool : public WorkspacePool { struct VTAWorkspacePool : public WorkspacePool {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment