Commit 666f32d6 by Tianqi Chen

[RUNTIME] Simplify dynamic library and code path. (#27)

* [RUNTIME] Simplify dynamic library and code path.

* reword the readme
parent 9c44e4b4
...@@ -53,12 +53,9 @@ else ...@@ -53,12 +53,9 @@ else
NO_WHOLE_ARCH= --no-whole-archive NO_WHOLE_ARCH= --no-whole-archive
endif endif
all: lib/libvta.so lib/libvta_runtime.so
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc) VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
ifeq ($(TARGET), VTA_PYNQ_TARGET) ifeq ($(VTA_TARGET), pynq)
VTA_LIB_SRC += $(wildcard src/pynq/*.cc) VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
LDFLAGS += -L/usr/lib -lsds_lib LDFLAGS += -L/usr/lib -lsds_lib
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
...@@ -66,24 +63,23 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET) ...@@ -66,24 +63,23 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET)
LDFLAGS += -l:libdma.so LDFLAGS += -l:libdma.so
endif endif
ifeq ($(TARGET), sim) ifeq ($(VTA_TARGET), sim)
VTA_LIB_SRC += $(wildcard src/sim/*.cc) VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif endif
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC)) VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
all: lib/libvta.so
build/%.o: src/%.cc build/%.o: src/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@ $(CXX) -c $(CFLAGS) -c $< -o $@
lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ)) lib/libvta.so: $(VTA_LIB_OBJ)
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
lib/libvta_runtime.so: build/runtime.o
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
lint: pylint cpplint lint: pylint cpplint
......
Open Hardware/Software Stack for Vertical Deep Learning System Optimization VTA: Open, Modular, Deep Learning Accelerator Stack
============================================== ===================================================
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE) [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)
VTA is an open hardware/software co-design stack for deep learning systems. VTA (Versatile Tensor Accelerator) is an open-source deep learning accelerator stack.
It provides a customizable hardware accelerator template for deep learning inference workloads, It is not just open-source hardware, but an end-to-end solution that includes
combined with a fully functional compiler stack built with TVM. the entire software stack on top of the VTA open-source hardware.
The key features include:
- Generic, modular open-source hardware
- Streamlined workflow to deploy to FPGAs.
- Simulator support
- Driver and JIT runtime for both simulated backend and FPGA.
- End-to-end TVM stack integration
- Directly optimize and deploy models from deep learning frameworks via the TVM stack.
- Customized and extensible TVM compiler backend
- Flexible RPC support to ease deployment; you can program it with Python :)
VTA is part of our effort on [TVM Stack](http://www.tvmlang.org/).
License License
------- -------
......
...@@ -26,8 +26,8 @@ ADD_LDFLAGS= ...@@ -26,8 +26,8 @@ ADD_LDFLAGS=
# the additional compile flags you want to add # the additional compile flags you want to add
ADD_CFLAGS= ADD_CFLAGS=
# the hardware target # the hardware target, can be [sim, pynq]
TARGET = pynq VTA_TARGET = pynq
#--------------------- #---------------------
# VTA hardware parameters # VTA hardware parameters
...@@ -88,7 +88,8 @@ $(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_A ...@@ -88,7 +88,8 @@ $(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_A
VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" ) VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
# Update ADD_CFLAGS # Update ADD_CFLAGS
ADD_CFLAGS += \ ADD_CFLAGS +=
-DVTA_TARGET=$(VTA_TARGET)\
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \ -DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \ -DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \ -DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
......
"""TVM-based VTA Compiler Toolchain""" """TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import sys
from .environment import get_env, Environment from .environment import get_env, Environment
...@@ -10,5 +11,5 @@ try: ...@@ -10,5 +11,5 @@ try:
from .rpc_client import reconfig_runtime, program_fpga from .rpc_client import reconfig_runtime, program_fpga
from . import graph from . import graph
except ImportError: except (ImportError, RuntimeError):
pass pass
...@@ -3,14 +3,10 @@ ...@@ -3,14 +3,10 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import os import os
import glob
import copy import copy
import tvm
try: from . import intrin
# Allow missing import in config mode.
import tvm
from . import intrin
except ImportError:
pass
class DevContext(object): class DevContext(object):
...@@ -65,6 +61,45 @@ class DevContext(object): ...@@ -65,6 +61,45 @@ class DevContext(object):
return 1 if self.DEBUG_NO_SYNC else qid return 1 if self.DEBUG_NO_SYNC else qid
class PkgConfig(object):
    """Small package-config helper for VTA.

    Gathers, from an environment object, the pieces needed to build the
    runtime: include flags, macro definitions, candidate source files and
    linker flags.

    Parameters
    ----------
    env : object
        Must expose ``TARGET``, ``cfg_keys`` and one attribute per entry
        of ``cfg_keys``.
    """
    def __init__(self, env):
        this_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
        proj_root = os.path.abspath(os.path.join(this_dir, "../../"))

        # Compiler include flags, all rooted at the project directory.
        include_templates = (
            "-I%s/include",
            "-I%s/nnvm/tvm/include",
            "-I%s/nnvm/tvm/dlpack/include",
            "-I%s/nnvm/dmlc-core/include",
        )
        self.include_path = [template % proj_root for template in include_templates]

        # Sources usable for a standalone library build: the common files
        # first, then the ones under the target-specific directory.
        common_sources = glob.glob("%s/src/*.cc" % proj_root)
        target_sources = glob.glob("%s/src/%s/*.cc" % (proj_root, env.TARGET))
        self.lib_source = common_sources + target_sources

        # One -DVTA_<key>=<value> definition per configuration key.
        self.macro_defs = [
            "-DVTA_%s=%s" % (name, str(getattr(env, name)))
            for name in env.cfg_keys
        ]

        # Only the pynq target adds linker flags.
        pynq_ldflags = [
            "-L/usr/lib",
            "-lsds_lib",
            "-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
            "-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
            "-l:libdma.so"]
        self.ldflags = pynq_ldflags if env.TARGET == "pynq" else []

    @property
    def cflags(self):
        """Include flags followed by macro definitions, as a new list."""
        return self.include_path + self.macro_defs
class Environment(object): class Environment(object):
"""Hardware configuration object. """Hardware configuration object.
...@@ -160,6 +195,7 @@ class Environment(object): ...@@ -160,6 +195,7 @@ class Environment(object):
self.mock_mode = False self.mock_mode = False
self._mock_env = None self._mock_env = None
self._dev_ctx = None self._dev_ctx = None
self._pkg_config = None
@property @property
def dev(self): def dev(self):
...@@ -169,6 +205,13 @@ class Environment(object): ...@@ -169,6 +205,13 @@ class Environment(object):
return self._dev_ctx return self._dev_ctx
@property @property
def pkg_config(self):
"""PkgConfig instance"""
if self._pkg_config is None:
self._pkg_config = PkgConfig(self)
return self._pkg_config
@property
def mock(self): def mock(self):
"""A mock version of the Environment """A mock version of the Environment
...@@ -249,7 +292,7 @@ def mem_info_wgt_buffer(): ...@@ -249,7 +292,7 @@ def mem_info_wgt_buffer():
head_address=None) head_address=None)
@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope) @tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_out_buffer(): def mem_info_acc_buffer():
spec = get_env() spec = get_env()
return tvm.make.node("MemoryInfo", return tvm.make.node("MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS, unit_bits=spec.ACC_ELEM_BITS,
...@@ -265,6 +308,7 @@ def coproc_sync(op): ...@@ -265,6 +308,7 @@ def coproc_sync(op):
"int32", "VTASynchronize", "int32", "VTASynchronize",
get_env().dev.command_handle, 1<<31) get_env().dev.command_handle, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push") @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op): def coproc_dep_push(op):
return tvm.call_extern( return tvm.call_extern(
...@@ -272,6 +316,7 @@ def coproc_dep_push(op): ...@@ -272,6 +316,7 @@ def coproc_dep_push(op):
get_env().dev.command_handle, get_env().dev.command_handle,
op.args[0], op.args[1]) op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop") @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op): def coproc_dep_pop(op):
return tvm.call_extern( return tvm.call_extern(
...@@ -288,7 +333,6 @@ def _init_env(): ...@@ -288,7 +333,6 @@ def _init_env():
for k in Environment.cfg_keys: for k in Environment.cfg_keys:
keys.add("VTA_" + k) keys.add("VTA_" + k)
keys.add("TARGET")
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise RuntimeError( raise RuntimeError(
...@@ -303,8 +347,9 @@ def _init_env(): ...@@ -303,8 +347,9 @@ def _init_env():
val = line.split("=")[1].strip() val = line.split("=")[1].strip()
if k.startswith("VTA_"): if k.startswith("VTA_"):
k = k[4:] k = k[4:]
try:
cfg[k] = int(val) cfg[k] = int(val)
else: except ValueError:
cfg[k] = val cfg[k] = val
return Environment(cfg) return Environment(cfg)
......
...@@ -9,26 +9,46 @@ import argparse ...@@ -9,26 +9,46 @@ import argparse
import os import os
import ctypes import ctypes
import tvm import tvm
from tvm._ffi.base import c_str
from tvm.contrib import rpc, cc from tvm.contrib import rpc, cc
from ..environment import get_env
@tvm.register_func("tvm.contrib.rpc.server.start", override=True) @tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start(): def server_start():
"""callback when server starts.""" """VTA RPC server extension."""
# pylint: disable=unused-variable # pylint: disable=unused-variable
curr_path = os.path.dirname( curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__))) os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath( dll_path = os.path.abspath(
os.path.join(curr_path, "../../../lib/libvta_runtime.so")) os.path.join(curr_path, "../../../lib/libvta.so"))
runtime_dll = [] runtime_dll = []
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module") _load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")
@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True) def load_vta_dll():
def load_module(file_name): """Try to load vta dll"""
if not runtime_dll: if not runtime_dll:
runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL)) runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
logging.info("Loading VTA library: %s", dll_path)
return runtime_dll[0]
@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
load_vta_dll()
return _load_module(file_name) return _load_module(file_name)
@tvm.register_func("device_api.ext_dev")
def ext_dev_callback():
load_vta_dll()
return tvm.get_global_func("device_api.ext_dev")()
@tvm.register_func("tvm.contrib.vta.init", override=True)
def program_fpga(file_name):
path = tvm.get_global_func("tvm.contrib.rpc.server.workpath")(file_name)
load_vta_dll().VTAProgram(c_str(path))
logging.info("Program FPGA with %s", file_name)
@tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True) @tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True)
def server_shutdown(): def server_shutdown():
if runtime_dll: if runtime_dll:
...@@ -47,17 +67,15 @@ def server_start(): ...@@ -47,17 +67,15 @@ def server_start():
if runtime_dll: if runtime_dll:
raise RuntimeError("Can only reconfig in the beginning of session...") raise RuntimeError("Can only reconfig in the beginning of session...")
cflags = cflags.split() cflags = cflags.split()
env = get_env()
cflags += ["-O2", "-std=c++11"] cflags += ["-O2", "-std=c++11"]
cflags += env.pkg_config.include_path
ldflags = env.pkg_config.ldflags
lib_name = dll_path lib_name = dll_path
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) source = env.pkg_config.lib_source
proj_root = os.path.abspath(os.path.join(curr_path, "../../../")) logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
runtime_source = os.path.join(proj_root, "src/runtime.cc") dll_path, str(cflags), str(source), str(ldflags))
cflags += ["-I%s/include" % proj_root] cc.create_shared(lib_name, source, cflags + ldflags)
cflags += ["-I%s/nnvm/tvm/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root]
cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root]
logging.info("Rebuild runtime dll with %s", str(cflags))
cc.create_shared(lib_name, [runtime_source], cflags)
def main(): def main():
...@@ -75,14 +93,6 @@ def main(): ...@@ -75,14 +93,6 @@ def main():
help="Report to RPC tracker") help="Report to RPC tracker")
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))
libs = []
for file_name in [lib_path]:
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)
if args.tracker: if args.tracker:
url, port = args.tracker.split(":") url, port = args.tracker.split(":")
...@@ -99,7 +109,6 @@ def main(): ...@@ -99,7 +109,6 @@ def main():
args.port_end, args.port_end,
key=args.key, key=args.key,
tracker_addr=tracker_addr) tracker_addr=tracker_addr)
server.libs += libs
server.proc.join() server.proc.join()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,7 +10,6 @@ def _load_lib(): ...@@ -10,7 +10,6 @@ def _load_lib():
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [ dll_path = [
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")), os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
] ]
runtime_dll = [] runtime_dll = []
if not all(os.path.exists(f) for f in dll_path): if not all(os.path.exists(f) for f in dll_path):
......
...@@ -15,6 +15,14 @@ def run(run_func): ...@@ -15,6 +15,14 @@ def run(run_func):
run_func : function(env, remote) run_func : function(env, remote)
""" """
env = get_env() env = get_env()
# Run on local sim rpc if necessary
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
env.TARGET = "sim"
remote = rpc.connect("localhost", local_rpc)
run_func(env, remote)
else:
# run on simulator # run on simulator
if simulator.enabled(): if simulator.enabled():
env.TARGET = "sim" env.TARGET = "sim"
......
/*!
* Copyright (c) 2018 by Contributors
* \file data_buffer.cc
* \brief Buffer related API for VTA.
* \note Buffer API remains stable across VTA designs.
*/
#include "./data_buffer.h"
void* VTABufferAlloc(size_t size) {
return vta::DataBuffer::Alloc(size);
}
void VTABufferFree(void* buffer) {
vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer));
}
void VTABufferCopy(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
int kind_mask) {
vta::DataBuffer* from_buffer = nullptr;
vta::DataBuffer* to_buffer = nullptr;
if (kind_mask & 2) {
from_buffer = vta::DataBuffer::FromHandle(from);
from = from_buffer->virt_addr();
}
if (kind_mask & 1) {
to_buffer = vta::DataBuffer::FromHandle(to);
to = to_buffer->virt_addr();
}
if (from_buffer) {
from_buffer->InvalidateCache(from_offset, size);
}
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
if (to_buffer) {
to_buffer->FlushCache(to_offset, size);
}
}
/*!
* Copyright (c) 2018 by Contributors
* \file data_buffer.h
* \brief VTA runtime internal data buffer structure.
*/
#ifndef VTA_DATA_BUFFER_H_
#define VTA_DATA_BUFFER_H_
#include <vta/driver.h>
#include <vta/runtime.h>
#include <cassert>
#include <cstring>
namespace vta {
/*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true;
/*!
* \brief Data buffer represents data on CMA.
*/
struct DataBuffer {
  /*! \return Virtual address of the data. */
  void* virt_addr() const {
    return data_;
  }
  /*! \return Physical address of the data. */
  uint32_t phy_addr() const {
    return phy_addr_;
  }
  /*!
   * \brief Invalidate the cache of given location in data buffer.
   * \note Does nothing while kBufferCoherent is true.
   * \param offset The offset to the data.
   * \param size The size of the data.
   */
  void InvalidateCache(size_t offset, size_t size) {
    if (!kBufferCoherent) {
      VTAInvalidateCache(phy_addr_ + offset, size);
    }
  }
  /*!
   * \brief Flush the cache of given location in data buffer.
   * \note Does nothing while kBufferCoherent is true.
   * \param offset The offset to the data.
   * \param size The size of the data.
   */
  void FlushCache(size_t offset, size_t size) {
    if (!kBufferCoherent) {
      VTAFlushCache(phy_addr_ + offset, size);
    }
  }
  /*!
   * \brief Allocate a buffer of a given size.
   * \param size The size of the buffer.
   * \return A newly created DataBuffer header; release it with Free().
   */
  static DataBuffer* Alloc(size_t size) {
    void* data = VTAMemAlloc(size, 1);
    // NOTE: assert is compiled out when NDEBUG is defined.
    assert(data != nullptr);
    DataBuffer* buffer = new DataBuffer();
    buffer->data_ = data;
    buffer->phy_addr_ = VTAMemGetPhyAddr(data);
    return buffer;
  }
  /*!
   * \brief Free the data buffer.
   * \param buffer The buffer to be freed; a null pointer is ignored.
   */
  static void Free(DataBuffer* buffer) {
    // Guard against a null handle instead of dereferencing it.
    if (buffer == nullptr) {
      return;
    }
    VTAMemFree(buffer->data_);
    delete buffer;
  }
  /*!
   * \brief Create data buffer header from buffer ptr.
   * \param buffer The buffer pointer.
   * \return The corresponding data buffer header.
   */
  static DataBuffer* FromHandle(const void* buffer) {
    // The handle is the header pointer itself, so this is only a cast.
    return const_cast<DataBuffer*>(
        reinterpret_cast<const DataBuffer*>(buffer));
  }

 private:
  /*! \brief The internal data. */
  void* data_{nullptr};
  /*! \brief The physical address of the buffer, excluding header. */
  uint32_t phy_addr_{0};
};
} // namespace vta
#endif // VTA_DATA_BUFFER_H_
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file vta_device_api.cc * \file device_api.cc
* \brief VTA device API for TVM * \brief TVM device API for VTA
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <vta/runtime.h> #include <vta/runtime.h>
#include <dlfcn.h>
#include "../../nnvm/tvm/src/runtime/workspace_pool.h" #include "../nnvm/tvm/src/runtime/workspace_pool.h"
namespace tvm { namespace tvm {
...@@ -26,7 +25,8 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -26,7 +25,8 @@ class VTADeviceAPI final : public DeviceAPI {
} }
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t size, size_t alignment, size_t size,
size_t alignment,
TVMType type_hint) final { TVMType type_hint) final {
return VTABufferAlloc(size); return VTABufferAlloc(size);
} }
...@@ -84,22 +84,9 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { ...@@ -84,22 +84,9 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()->FreeWorkspace(ctx, data); dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()->FreeWorkspace(ctx, data);
} }
std::string VTARPCGetPath(const std::string& name) { // Register device api with override.
static const PackedFunc* f = static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ =
runtime::Registry::Get("tvm.contrib.rpc.server.workpath"); ::tvm::runtime::Registry::Register("device_api.ext_dev", true)
CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath";
return (*f)(name);
}
// Global functions that can be called
TVM_REGISTER_GLOBAL("tvm.contrib.vta.init")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::string path = VTARPCGetPath(args[0]);
VTAProgram(path.c_str());
LOG(INFO) << "VTA initialization end with bistream " << path;
});
TVM_REGISTER_GLOBAL("device_api.ext_dev")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VTADeviceAPI::Global().get(); DeviceAPI* ptr = VTADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
......
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
* *
* The runtime depends on specific instruction * The runtime depends on specific instruction
* stream spec as specified in hw_spec.h * stream spec as specified in hw_spec.h
* It is intended to be used as a dynamic library
* to enable hot swapping of hardware configurations.
*/ */
#include <vta/driver.h> #include <vta/driver.h>
#include <vta/hw_spec.h> #include <vta/hw_spec.h>
...@@ -14,15 +12,87 @@ ...@@ -14,15 +12,87 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <cassert> #include <cassert>
#include <cstring>
#include <vector> #include <vector>
#include <thread> #include <thread>
#include <memory> #include <memory>
#include <atomic> #include <atomic>
#include "./data_buffer.h"
namespace vta { namespace vta {
/*! \brief Enable coherent access between VTA and CPU. */
static const bool kBufferCoherent = true;
/*!
* \brief Data buffer represents data on CMA.
*/
struct DataBuffer {
  /*! \return Virtual address of the data. */
  void* virt_addr() const {
    return data_;
  }
  /*! \return Physical address of the data. */
  uint32_t phy_addr() const {
    return phy_addr_;
  }
  /*!
   * \brief Invalidate the cache of given location in data buffer.
   * \note Does nothing while kBufferCoherent is true.
   * \param offset The offset to the data.
   * \param size The size of the data.
   */
  void InvalidateCache(size_t offset, size_t size) {
    if (!kBufferCoherent) {
      VTAInvalidateCache(phy_addr_ + offset, size);
    }
  }
  /*!
   * \brief Flush the cache of given location in data buffer.
   * \note Does nothing while kBufferCoherent is true.
   * \param offset The offset to the data.
   * \param size The size of the data.
   */
  void FlushCache(size_t offset, size_t size) {
    if (!kBufferCoherent) {
      VTAFlushCache(phy_addr_ + offset, size);
    }
  }
  /*!
   * \brief Allocate a buffer of a given size.
   * \param size The size of the buffer.
   * \return A newly created DataBuffer header; release it with Free().
   */
  static DataBuffer* Alloc(size_t size) {
    void* data = VTAMemAlloc(size, 1);
    CHECK(data != nullptr);
    DataBuffer* buffer = new DataBuffer();
    buffer->data_ = data;
    buffer->phy_addr_ = VTAMemGetPhyAddr(data);
    return buffer;
  }
  /*!
   * \brief Free the data buffer.
   * \param buffer The buffer to be freed; a null pointer is ignored.
   */
  static void Free(DataBuffer* buffer) {
    // Guard against a null handle instead of dereferencing it.
    if (buffer == nullptr) {
      return;
    }
    VTAMemFree(buffer->data_);
    delete buffer;
  }
  /*!
   * \brief Create data buffer header from buffer ptr.
   * \param buffer The buffer pointer.
   * \return The corresponding data buffer header.
   */
  static DataBuffer* FromHandle(const void* buffer) {
    // The handle is the header pointer itself, so this is only a cast.
    return const_cast<DataBuffer*>(
        reinterpret_cast<const DataBuffer*>(buffer));
  }

 private:
  /*! \brief The internal data. */
  void* data_{nullptr};
  /*! \brief The physical address of the buffer, excluding header. */
  uint32_t phy_addr_{0};
};
/*! /*!
* \brief Micro op kernel. * \brief Micro op kernel.
* Contains functions to construct the kernel with prefix Push. * Contains functions to construct the kernel with prefix Push.
...@@ -1130,6 +1200,42 @@ class CommandQueue { ...@@ -1130,6 +1200,42 @@ class CommandQueue {
} // namespace vta } // namespace vta
void* VTABufferAlloc(size_t size) {
return vta::DataBuffer::Alloc(size);
}
void VTABufferFree(void* buffer) {
vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer));
}
void VTABufferCopy(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
int kind_mask) {
vta::DataBuffer* from_buffer = nullptr;
vta::DataBuffer* to_buffer = nullptr;
if (kind_mask & 2) {
from_buffer = vta::DataBuffer::FromHandle(from);
from = from_buffer->virt_addr();
}
if (kind_mask & 1) {
to_buffer = vta::DataBuffer::FromHandle(to);
to = to_buffer->virt_addr();
}
if (from_buffer) {
from_buffer->InvalidateCache(from_offset, size);
}
memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset,
size);
if (to_buffer) {
to_buffer->FlushCache(to_offset, size);
}
}
VTACommandHandle VTATLSCommandHandle() { VTACommandHandle VTATLSCommandHandle() {
return vta::CommandQueue::ThreadLocal().get(); return vta::CommandQueue::ThreadLocal().get();
......
...@@ -468,7 +468,22 @@ def test_shift_and_scale(): ...@@ -468,7 +468,22 @@ def test_shift_and_scale():
vta.testing.run(_run) vta.testing.run(_run)
def test_runtime_array():
    """Round-trip a random int8 array through the remote ext_dev context.

    The array is created on the host, copied to ``remote.ext_dev(0)`` and
    read back; the two copies must compare equal.
    """
    def _check_roundtrip(env, remote):
        side = 100
        dev_ctx = remote.ext_dev(0)
        shape = (side, side, env.BATCH, env.BLOCK_OUT)
        host_data = np.random.randint(1, 10, size=shape).astype("int8")
        dev_data = tvm.nd.array(host_data, dev_ctx)
        np.testing.assert_equal(host_data, dev_data.asnumpy())
    vta.testing.run(_check_roundtrip)
if __name__ == "__main__": if __name__ == "__main__":
print("Array test")
test_runtime_array()
print("Load/store test") print("Load/store test")
test_save_load_out() test_save_load_out()
print("Padded load test") print("Padded load test")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment