Commit 33a309b2 by Tianqi Chen

[RPC][RUNTIME] Support dynamic reload of runtime API according to config (#19)

parent dc5bdb6b
...@@ -40,36 +40,31 @@ ifneq ($(ADD_LDFLAGS), NONE) ...@@ -40,36 +40,31 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS) LDFLAGS += $(ADD_LDFLAGS)
endif endif
ifeq ($(UNAME_S), Darwin)
SHARED_LIBRARY_SUFFIX := dylib
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
LDFLAGS += -undefined dynamic_lookup
else
SHARED_LIBRARY_SUFFIX := so
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif
all: lib/libvta.so lib/libvta_runtime.so
all: lib/libvta.$(SHARED_LIBRARY_SUFFIX)
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc) VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
ifeq ($(TARGET), VTA_PYNQ_TARGET) ifeq ($(TARGET), VTA_PYNQ_TARGET)
VTA_LIB_SRC += $(wildcard src/pynq/*.cc) VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
LDFLAGS += -L/usr/lib -lsds_lib LDFLAGS += -L/usr/lib -lsds_lib
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
LDFLAGS += -l:libdma.so
endif endif
VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC))
test: $(TEST) VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
build/src/%.o: src/%.cc build/%.o: src/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d $(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@ $(CXX) -c $(CFLAGS) -c $< -o $@
lib/libvta.$(SHARED_LIBRARY_SUFFIX): $(VTA_LIB_OBJ) lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ))
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
lib/libvta_runtime.so: build/runtime.o
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
...@@ -79,7 +74,7 @@ cpplint: ...@@ -79,7 +74,7 @@ cpplint:
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
pylint: pylint:
pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
doc: doc:
doxygen docs/Doxyfile doxygen docs/Doxyfile
......
#!/bin/bash #!/bin/bash
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python:/home/xilinx/vta/python
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so python -m vta.exec.rpc_server
...@@ -34,9 +34,7 @@ for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITST ...@@ -34,9 +34,7 @@ for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITST
# Program the FPGA remotely # Program the FPGA remotely
assert tvm.module.enabled("rpc") assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
remote.upload(BITSTREAM_FILE, BITSTREAM_FILE) vta.program_fpga(remote, BITSTREAM_FILE)
fprogram = remote.get_function("tvm.contrib.vta.init")
fprogram(BITSTREAM_FILE)
if verbose: if verbose:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
......
...@@ -23,40 +23,20 @@ extern "C" { ...@@ -23,40 +23,20 @@ extern "C" {
#define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4) #define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4)
#define VTA_DEBUG_FORCE_SERIAL (1 << 5) #define VTA_DEBUG_FORCE_SERIAL (1 << 5)
/*! \brief VTA command handle */
typedef void * VTACommandHandle;
/*! \brief Shutdown hook of VTA to cleanup resources */
void VTARuntimeShutdown();
/*!
* \brief Get thread local command handle.
* \return A thread local command handle.
*/
VTACommandHandle VTATLSCommandHandle();
/*! /*!
* \brief Allocate data buffer. * \brief Allocate data buffer.
* \param cmd The VTA command handle. * \param cmd The VTA command handle.
* \param size Buffer size. * \param size Buffer size.
* \return A pointer to the allocated buffer. * \return A pointer to the allocated buffer.
*/ */
void* VTABufferAlloc(VTACommandHandle cmd, size_t size); void* VTABufferAlloc(size_t size);
/*! /*!
* \brief Free data buffer. * \brief Free data buffer.
* \param cmd The VTA command handle. * \param cmd The VTA command handle.
* \param buffer The data buffer to be freed. * \param buffer The data buffer to be freed.
*/ */
void VTABufferFree(VTACommandHandle cmd, void* buffer); void VTABufferFree(void* buffer);
/*!
* \brief Get the buffer access pointer on CPU.
* \param cmd The VTA command handle.
* \param buffer The data buffer.
* \return The pointer that can be accessed by the CPU.
*/
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
/*! /*!
* \brief Copy data buffer from one location to another. * \brief Copy data buffer from one location to another.
...@@ -68,20 +48,32 @@ void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); ...@@ -68,20 +48,32 @@ void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
* \param size Size of copy. * \param size Size of copy.
* \param kind_mask The memory copy kind. * \param kind_mask The memory copy kind.
*/ */
void VTABufferCopy(VTACommandHandle cmd, void VTABufferCopy(const void* from,
const void* from,
size_t from_offset, size_t from_offset,
void* to, void* to,
size_t to_offset, size_t to_offset,
size_t size, size_t size,
int kind_mask); int kind_mask);
/*! \brief VTA command handle */
typedef void* VTACommandHandle;
/*! \brief Shutdown hook of VTA to cleanup resources */
void VTARuntimeShutdown();
/*! /*!
* \brief Set debug mode on the command handle. * \brief Get thread local command handle.
* \return A thread local command handle.
*/
VTACommandHandle VTATLSCommandHandle();
/*!
* \brief Get the buffer access pointer on CPU.
* \param cmd The VTA command handle. * \param cmd The VTA command handle.
* \param debug_flag The debug flag. * \param buffer The data buffer.
* \return The pointer that can be accessed by the CPU.
*/ */
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer);
/*! /*!
* \brief Perform a write barrier to make a memory region visible to the CPU. * \brief Perform a write barrier to make a memory region visible to the CPU.
...@@ -92,9 +84,10 @@ void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); ...@@ -92,9 +84,10 @@ void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
* \param extent The end of the region (in elements). * \param extent The end of the region (in elements).
*/ */
void VTAWriteBarrier(VTACommandHandle cmd, void VTAWriteBarrier(VTACommandHandle cmd,
void* buffer, uint32_t elem_bits, void* buffer,
uint32_t start, uint32_t extent); uint32_t elem_bits,
uint32_t start,
uint32_t extent);
/*! /*!
* \brief Perform a read barrier to a memory region visible to VTA. * \brief Perform a read barrier to a memory region visible to VTA.
* \param cmd The VTA command handle. * \param cmd The VTA command handle.
...@@ -104,8 +97,17 @@ void VTAWriteBarrier(VTACommandHandle cmd, ...@@ -104,8 +97,17 @@ void VTAWriteBarrier(VTACommandHandle cmd,
* \param extent The end of the region (in elements). * \param extent The end of the region (in elements).
*/ */
void VTAReadBarrier(VTACommandHandle cmd, void VTAReadBarrier(VTACommandHandle cmd,
void* buffer, uint32_t elem_bits, void* buffer,
uint32_t start, uint32_t extent); uint32_t elem_bits,
uint32_t start,
uint32_t extent);
/*!
* \brief Set debug mode on the command handle.
* \param cmd The VTA command handle.
* \param debug_flag The debug flag.
*/
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag);
/*! /*!
* \brief Perform a 2D data load from DRAM. * \brief Perform a 2D data load from DRAM.
......
...@@ -54,6 +54,7 @@ VTA_LOG_WGT_BUFF_SIZE = 15 ...@@ -54,6 +54,7 @@ VTA_LOG_WGT_BUFF_SIZE = 15
# Log of acc buffer size in Bytes # Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17 VTA_LOG_ACC_BUFF_SIZE = 17
#--------------------- #---------------------
# Derived VTA hardware parameters # Derived VTA hardware parameters
#-------------------- #--------------------
......
"""TVM VTA runtime""" """TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .hw_spec import * from .hw_spec import *
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU try:
from .intrin import GEVM, GEMM from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
from .build import debug_mode from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
from . import arm_conv2d, vta_conv2d
except AttributeError:
pass
from . import mock, ir_pass from .rpc_client import reconfig_runtime, program_fpga
from . import arm_conv2d, vta_conv2d
from . import graph try:
from . import graph
except ImportError:
pass
"""VTA Command line utils."""
"""VTA customized TVM RPC Server
Provides additional runtime function and library loading.
"""
from __future__ import absolute_import
import logging
import argparse
import os
import ctypes
import tvm
from tvm.contrib import rpc, util, cc
@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start():
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath(
os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
runtime_dll = []
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")
@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
if not runtime_dll:
runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
return _load_module(file_name)
@tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True)
def server_shutdown():
if runtime_dll:
runtime_dll[0].VTARuntimeShutdown()
runtime_dll.pop()
@tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True)
def reconfig_runtime(cflags):
"""Rebuild and reload runtime with new configuration.
Parameters
----------
cfg_json : str
JSON string used for configurations.
"""
if runtime_dll:
raise RuntimeError("Can only reconfig in the beginning of session...")
cflags = cflags.split()
cflags += ["-O2", "-std=c++11"]
lib_name = dll_path
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
runtime_source = os.path.join(proj_root, "src/runtime.cc")
cflags += ["-I%s/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root]
cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root]
logging.info("Rebuild runtime dll with %s", str(cflags))
cc.create_shared(lib_name, [runtime_source], cflags)
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument('--host', type=str, default="0.0.0.0",
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.")
parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))
libs = []
for file_name in [lib_path]:
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
if args.tracker:
url, port = args.tracker.split(":")
port = int(port)
tracker_addr = (url, port)
if not args.key:
raise RuntimeError(
"Need key to present type of resource when tracker is available")
else:
tracker_addr = None
server = rpc.Server(args.host,
args.port,
args.port_end,
key=args.key,
tracker_addr=tracker_addr)
server.libs += libs
server.proc.join()
if __name__ == "__main__":
main()
"""VTA configuration constants (should match hw_spec.h""" """VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
# Log of input/activation width in bits (default 3 -> 8 bits)
VTA_LOG_INP_WIDTH = 3
# Log of kernel weight width in bits (default 3 -> 8 bits)
VTA_LOG_WGT_WIDTH = 3
# Log of accum width in bits (default 5 -> 32 bits)
VTA_LOG_ACC_WIDTH = 5
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BATCH = 0
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_IN = 4
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_OUT = 4
VTA_LOG_OUT_WIDTH = VTA_LOG_INP_WIDTH
# Log of uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = 15
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17
# The Constants # The Constants
VTA_WGT_WIDTH = 8 VTA_WGT_WIDTH = 8
VTA_INP_WIDTH = VTA_WGT_WIDTH VTA_INP_WIDTH = VTA_WGT_WIDTH
VTA_OUT_WIDTH = 32 VTA_OUT_WIDTH = 32
VTA_TARGET = "VTA_PYNQ_TARGET"
# Dimensions of the GEMM unit # Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT) # (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1 VTA_BATCH = 1
...@@ -67,4 +87,4 @@ VTA_QID_STORE_INP = 3 ...@@ -67,4 +87,4 @@ VTA_QID_STORE_INP = 3
DEBUG_DUMP_INSN = (1 << 1) DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2) DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3) DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4) DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
\ No newline at end of file
"""VTA RPC client function"""
import os
from . import hw_spec as spec
def reconfig_runtime(remote):
"""Reconfigure remote runtime based on current hardware spec.
Parameters
----------
remote : RPCSession
The TVM RPC session
"""
keys = ["VTA_LOG_WGT_WIDTH",
"VTA_LOG_INP_WIDTH",
"VTA_LOG_ACC_WIDTH",
"VTA_LOG_OUT_WIDTH",
"VTA_LOG_BATCH",
"VTA_LOG_BLOCK_IN",
"VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE",
"VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"]
cflags = ["-D%s" % spec.VTA_TARGET]
for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(spec, k)))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
freconfig(" ".join(cflags))
def program_fpga(remote, bitstream):
"""Upload and program bistream
Parameters
----------
remote : RPCSession
The TVM RPC session
bitstream : str
Path to a local bistream file.
"""
fprogram = remote.get_function("tvm.contrib.vta.init")
remote.upload(bitstream)
fprogram(os.path.basename(bitstream))
...@@ -25,7 +25,6 @@ def get_task_qid(qid): ...@@ -25,7 +25,6 @@ def get_task_qid(qid):
"""Get transformed queue index.""" """Get transformed queue index."""
return 1 if DEBUG_NO_SYNC else qid return 1 if DEBUG_NO_SYNC else qid
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync") @tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op): def coproc_sync(op):
return tvm.call_extern( return tvm.call_extern(
......
/*!
* Copyright (c) 2018 by Contributors
* \file data_buffer.cc
* \brief Buffer related API for VTA.
* \note Buffer API remains stable across VTA designes.
*/
#include "./data_buffer.h"
void* VTABufferAlloc(size_t size) {
return vta::DataBuffer::Alloc(size);
}
void VTABufferFree(void* buffer) {
vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer));
}
void VTABufferCopy(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
int kind_mask) {
vta::DataBuffer* from_buffer = nullptr;
vta::DataBuffer* to_buffer = nullptr;
if (kind_mask & 2) {
from_buffer = vta::DataBuffer::FromHandle(from);
from = from_buffer->virt_addr();
}
if (kind_mask & 1) {
to_buffer = vta::DataBuffer::FromHandle(to);
to = to_buffer->virt_addr();
}
if (from_buffer) {
from_buffer->InvalidateCache(from_offset, size);
}
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
if (to_buffer) {
to_buffer->FlushCache(to_offset, size);
}
}
/*!
* Copyright (c) 2018 by Contributors
* \file data_buffer.h
* \brief VTA runtime internal data buffer structure.
*/
#ifndef VTA_DATA_BUFFER_H_
#define VTA_DATA_BUFFER_H_
#include <vta/driver.h>
#include <vta/runtime.h>
#include <cassert>
#include <cstring>
namespace vta {
/*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true;
/*!
* \brief Data buffer represents data on CMA.
*/
struct DataBuffer {
/*! \return Virtual address of the data. */
void* virt_addr() const {
return data_;
}
/*! \return Physical address of the data. */
uint32_t phy_addr() const {
return phy_addr_;
}
/*!
* \brief Invalidate the cache of given location in data buffer.
* \param offset The offset to the data.
* \param size The size of the data.
*/
void InvalidateCache(size_t offset, size_t size) {
if (!kBufferCoherent) {
VTAInvalidateCache(reinterpret_cast<void*>(phy_addr_ + offset), size);
}
}
/*!
* \brief Invalidate the cache of certain location in data buffer.
* \param offset The offset to the data.
* \param size The size of the data.
*/
void FlushCache(size_t offset, size_t size) {
if (!kBufferCoherent) {
VTAFlushCache(reinterpret_cast<void*>(phy_addr_ + offset), size);
}
}
/*!
* \brief Allocate a buffer of a given size.
* \param size The size of the buffer.
*/
static DataBuffer* Alloc(size_t size) {
void* data = VTAMemAlloc(size, 1);
assert(data != nullptr);
DataBuffer* buffer = new DataBuffer();
buffer->data_ = data;
buffer->phy_addr_ = VTAGetMemPhysAddr(data);
return buffer;
}
/*!
* \brief Free the data buffer.
* \param buffer The buffer to be freed.
*/
static void Free(DataBuffer* buffer) {
VTAMemFree(buffer->data_);
delete buffer;
}
/*!
* \brief Create data buffer header from buffer ptr.
* \param buffer The buffer pointer.
* \return The corresponding data buffer header.
*/
static DataBuffer* FromHandle(const void* buffer) {
return const_cast<DataBuffer*>(
reinterpret_cast<const DataBuffer*>(buffer));
}
private:
/*! \brief The internal data. */
void* data_;
/*! \brief The physical address of the buffer, excluding header. */
uint32_t phy_addr_;
};
} // namespace vta
#endif // VTA_DATA_BUFFER_H_
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file vta_runtime.cc * \file runtime.cc
* \brief VTA runtime for PYNQ in C++11 * \brief VTA runtime for PYNQ in C++11
*/ */
...@@ -13,85 +13,14 @@ ...@@ -13,85 +13,14 @@
#include <vta/runtime.h> #include <vta/runtime.h>
#include <cassert> #include <cassert>
#include <cstring>
#include <vector> #include <vector>
#include <thread> #include <thread>
#include <memory> #include <memory>
#include <atomic> #include <atomic>
namespace vta { #include "./data_buffer.h"
/*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true;
/*! namespace vta {
* \brief Data buffer represents data on CMA.
*/
struct DataBuffer {
/*! \return Virtual address of the data. */
void* virt_addr() const {
return data_;
}
/*! \return Physical address of the data. */
uint32_t phy_addr() const {
return phy_addr_;
}
/*!
* \brief Invalidate the cache of given location in data buffer.
* \param offset The offset to the data.
* \param size The size of the data.
*/
void InvalidateCache(size_t offset, size_t size) {
if (!kBufferCoherent) {
VTAInvalidateCache(reinterpret_cast<void*>(phy_addr_ + offset), size);
}
}
/*!
* \brief Invalidate the cache of certain location in data buffer.
* \param offset The offset to the data.
* \param size The size of the data.
*/
void FlushCache(size_t offset, size_t size) {
if (!kBufferCoherent) {
VTAFlushCache(reinterpret_cast<void*>(phy_addr_ + offset), size);
}
}
/*!
* \brief Allocate a buffer of a given size.
* \param size The size of the buffer.
*/
static DataBuffer* Alloc(size_t size) {
void* data = VTAMemAlloc(size, 1);
assert(data != nullptr);
DataBuffer* buffer = new DataBuffer();
buffer->data_ = data;
buffer->phy_addr_ = VTAGetMemPhysAddr(data);
return buffer;
}
/*!
* \brief Free the data buffer.
* \param buffer The buffer to be freed.
*/
static void Free(DataBuffer* buffer) {
VTAMemFree(buffer->data_);
delete buffer;
}
/*!
* \brief Create data buffer header from buffer ptr.
* \param buffer The buffer pointer.
* \return The corresponding data buffer header.
*/
static DataBuffer* FromHandle(const void* buffer) {
return const_cast<DataBuffer*>(
reinterpret_cast<const DataBuffer*>(buffer));
}
private:
/*! \brief The internal data. */
void* data_;
/*! \brief The physical address of the buffer, excluding header. */
uint32_t phy_addr_;
};
/*! /*!
* \brief Micro op kernel. * \brief Micro op kernel.
...@@ -1130,6 +1059,9 @@ class CommandQueue { ...@@ -1130,6 +1059,9 @@ class CommandQueue {
static std::shared_ptr<CommandQueue>& ThreadLocal() { static std::shared_ptr<CommandQueue>& ThreadLocal() {
static std::shared_ptr<CommandQueue> inst = static std::shared_ptr<CommandQueue> inst =
std::make_shared<CommandQueue>(); std::make_shared<CommandQueue>();
if (inst == nullptr) {
inst = std::make_shared<CommandQueue>();
}
return inst; return inst;
} }
...@@ -1254,63 +1186,29 @@ void VTARuntimeShutdown() { ...@@ -1254,63 +1186,29 @@ void VTARuntimeShutdown() {
vta::CommandQueue::Shutdown(); vta::CommandQueue::Shutdown();
} }
void* VTABufferAlloc(VTACommandHandle cmd, size_t size) { void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) {
return vta::DataBuffer::Alloc(size); static_cast<vta::CommandQueue*>(cmd)->
} SetDebugFlag(debug_flag);
void VTABufferFree(VTACommandHandle cmd, void* buffer) {
vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer));
} }
void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) { void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) {
return vta::DataBuffer::FromHandle(buffer)->virt_addr(); return vta::DataBuffer::FromHandle(buffer)->virt_addr();
} }
void VTABufferCopy(VTACommandHandle cmd,
const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
int kind_mask) {
vta::DataBuffer* from_buffer = nullptr;
vta::DataBuffer* to_buffer = nullptr;
if (kind_mask & 2) {
from_buffer = vta::DataBuffer::FromHandle(from);
from = from_buffer->virt_addr();
}
if (kind_mask & 1) {
to_buffer = vta::DataBuffer::FromHandle(to);
to = to_buffer->virt_addr();
}
if (from_buffer) {
from_buffer->InvalidateCache(from_offset, size);
}
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
if (to_buffer) {
to_buffer->FlushCache(to_offset, size);
}
}
void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) {
static_cast<vta::CommandQueue*>(cmd)->
SetDebugFlag(debug_flag);
}
void VTAWriteBarrier(VTACommandHandle cmd, void VTAWriteBarrier(VTACommandHandle cmd,
void* buffer, uint32_t elem_bits, void* buffer,
uint32_t start, uint32_t extent) { uint32_t elem_bits,
uint32_t start,
uint32_t extent) {
static_cast<vta::CommandQueue*>(cmd)-> static_cast<vta::CommandQueue*>(cmd)->
WriteBarrier(buffer, elem_bits, start, extent); WriteBarrier(buffer, elem_bits, start, extent);
} }
void VTAReadBarrier(VTACommandHandle cmd, void VTAReadBarrier(VTACommandHandle cmd,
void* buffer, uint32_t elem_bits, void* buffer,
uint32_t start, uint32_t extent) { uint32_t elem_bits,
uint32_t start,
uint32_t extent) {
static_cast<vta::CommandQueue*>(cmd)-> static_cast<vta::CommandQueue*>(cmd)->
ReadBarrier(buffer, elem_bits, start, extent); ReadBarrier(buffer, elem_bits, start, extent);
} }
...@@ -1409,3 +1307,11 @@ void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) { ...@@ -1409,3 +1307,11 @@ void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) {
static_cast<vta::CommandQueue*>(cmd)-> static_cast<vta::CommandQueue*>(cmd)->
Synchronize(wait_cycles); Synchronize(wait_cycles);
} }
extern "C" int VTARuntimeDynamicMagic() {
#ifdef VTA_DYNAMIC_MAGIC
return VTA_DYNAMIC_MAGIC;
#else
return 0;
#endif
}
...@@ -7,31 +7,17 @@ ...@@ -7,31 +7,17 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <vta/runtime.h> #include <vta/runtime.h>
#include <dlfcn.h>
#include "../../nnvm/tvm/src/runtime/workspace_pool.h" #include "../../nnvm/tvm/src/runtime/workspace_pool.h"
namespace tvm { extern "C" {
namespace runtime { typedef void (*FShutdown)();
typedef int (*FDynamicMagic)();
std::string VTARPCGetPath(const std::string& name) {
static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
return (*f)(name);
} }
// Global functions that can be called namespace tvm {
TVM_REGISTER_GLOBAL("tvm.contrib.vta.init") namespace runtime {
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string path = VTARPCGetPath(args[0]);
VTAProgram(path.c_str());
LOG(INFO) << "VTA initialization end with bistream " << path;
});
TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.shutdown")
.set_body([](TVMArgs args, TVMRetValue* rv) {
VTARuntimeShutdown();
});
class VTADeviceAPI final : public DeviceAPI { class VTADeviceAPI final : public DeviceAPI {
public: public:
...@@ -46,11 +32,11 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -46,11 +32,11 @@ class VTADeviceAPI final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t size, size_t alignment, size_t size, size_t alignment,
TVMType type_hint) final { TVMType type_hint) final {
return VTABufferAlloc(VTATLSCommandHandle(), size); return VTABufferAlloc(size);
} }
void FreeDataSpace(TVMContext ctx, void* ptr) final { void FreeDataSpace(TVMContext ctx, void* ptr) final {
VTABufferFree(VTATLSCommandHandle(), ptr); VTABufferFree(ptr);
} }
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from,
...@@ -68,8 +54,7 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -68,8 +54,7 @@ class VTADeviceAPI final : public DeviceAPI {
if (ctx_to.device_type != kDLCPU) { if (ctx_to.device_type != kDLCPU) {
kind_mask |= 1; kind_mask |= 1;
} }
VTABufferCopy(VTATLSCommandHandle(), VTABufferCopy(from, from_offset,
from, from_offset,
to, to_offset, to, to_offset,
size, kind_mask); size, kind_mask);
} }
...@@ -86,6 +71,9 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -86,6 +71,9 @@ class VTADeviceAPI final : public DeviceAPI {
std::make_shared<VTADeviceAPI>(); std::make_shared<VTADeviceAPI>();
return inst; return inst;
} }
private:
void* runtime_dll_{nullptr};
}; };
struct VTAWorkspacePool : public WorkspacePool { struct VTAWorkspacePool : public WorkspacePool {
...@@ -103,6 +91,21 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { ...@@ -103,6 +91,21 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()->FreeWorkspace(ctx, data); dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()->FreeWorkspace(ctx, data);
} }
std::string VTARPCGetPath(const std::string& name) {
static const PackedFunc* f =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath");
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
return (*f)(name);
}
// Global functions that can be called
TVM_REGISTER_GLOBAL("tvm.contrib.vta.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string path = VTARPCGetPath(args[0]);
VTAProgram(path.c_str());
LOG(INFO) << "VTA initialization end with bistream " << path;
});
TVM_REGISTER_GLOBAL("device_api.ext_dev") TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VTADeviceAPI::Global().get(); DeviceAPI* ptr = VTADeviceAPI::Global().get();
......
...@@ -27,6 +27,7 @@ inp_dtype = "int%d" % vta.VTA_INP_WIDTH ...@@ -27,6 +27,7 @@ inp_dtype = "int%d" % vta.VTA_INP_WIDTH
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon" target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
print_ir = False print_ir = False
def test_vta_conv2d(key, batch_size, wl, profile=True): def test_vta_conv2d(key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN, data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN,
wl.height, wl.width, vta.VTA_BLOCK_IN) wl.height, wl.width, vta.VTA_BLOCK_IN)
...@@ -54,6 +55,7 @@ def test_vta_conv2d(key, batch_size, wl, profile=True): ...@@ -54,6 +55,7 @@ def test_vta_conv2d(key, batch_size, wl, profile=True):
mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d") mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d")
temp = util.tempdir() temp = util.tempdir()
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
mod.save(temp.relpath("conv2d.o")) mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o")) remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o") f = remote.load_module("conv2d.o")
......
...@@ -14,8 +14,12 @@ bitstream = os.path.join(curr_path, "./", bit) ...@@ -14,8 +14,12 @@ bitstream = os.path.join(curr_path, "./", bit)
def test_program_rpc(): def test_program_rpc():
assert tvm.module.enabled("rpc") assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port) remote = rpc.connect(host, port)
remote.upload(bitstream, bit) vta.program_fpga(remote, bit)
fprogram = remote.get_function("tvm.contrib.vta.init")
fprogram(bit) def test_reconfig_runtime():
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
vta.reconfig_runtime(remote)
test_program_rpc() test_program_rpc()
test_reconfig_runtime()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment