Commit 9c44e4b4 by Tianqi Chen

[DRIVER] Add simulator, unify testcase to unittest (#25)

parent 0faefafc
@@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE)
 LDFLAGS += $(ADD_LDFLAGS)
 endif
+UNAME_S := $(shell uname -s)
+ifeq ($(UNAME_S), Darwin)
+SHARED_LIBRARY_SUFFIX := dylib
+WHOLE_ARCH= -all_load
+NO_WHOLE_ARCH= -noall_load
+LDFLAGS += -undefined dynamic_lookup
+else
+SHARED_LIBRARY_SUFFIX := so
+WHOLE_ARCH= --whole-archive
+NO_WHOLE_ARCH= --no-whole-archive
+endif
 all: lib/libvta.so lib/libvta_runtime.so
@@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET)
 LDFLAGS += -l:libdma.so
 endif
+ifeq ($(TARGET), sim)
+VTA_LIB_SRC += $(wildcard src/sim/*.cc)
+endif
 VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
 build/%.o: src/%.cc
@@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o
 lint: pylint cpplint
 cpplint:
-	python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
+	python nnvm/dmlc-core/scripts/lint.py vta cpp include src
 pylint:
 	pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
@@ -86,3 +103,4 @@ clean:
 -include build/*.d
 -include build/*/*.d
 -include build/*/*/*.d
+-include build/*/*/*/*.d
@@ -77,7 +77,7 @@ void VTAMemFree(void* buf);
  * \param buf Pointer to memory region allocated with VTAMemAlloc.
  * \return The physical address of the memory region.
  */
-vta_phy_addr_t VTAGetMemPhysAddr(void* buf);
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf);
 /*!
  * \brief Flushes the region of memory out of the CPU cache to DRAM.
...
@@ -519,8 +519,8 @@ typedef struct {
   uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH;
   /*! \brief Use immediate is true */
   uint64_t use_imm : 1;
-  /*! \brief Immediate value */
-  uint64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
+  /*! \brief Immediate value: allow negative value */
+  int64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
 } VTAAluInsn;
 /*! \brief VTA ALU instruction converter */
...
@@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode,
                 uint32_t wgt_index,
                 uint32_t opcode,
                 uint32_t use_imm,
-                uint32_t imm_val);
+                int32_t imm_val);
 /*!
  * \brief Mark start of a micro op loop.
...
@@ -27,7 +27,7 @@ ADD_LDFLAGS=
 ADD_CFLAGS=
 # the hardware target
-TARGET = VTA_PYNQ_TARGET
+TARGET = pynq
 #---------------------
 # VTA hardware parameters
@@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
 # Update ADD_CFLAGS
 ADD_CFLAGS += \
-	-D$(TARGET) \
 	-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
 	-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
 	-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
...
@@ -8,6 +8,7 @@ try:
     from . import arm_conv2d, vta_conv2d
     from .build_module import build_config, lower, build
     from .rpc_client import reconfig_runtime, program_fpga
     from . import graph
 except ImportError:
     pass
@@ -89,7 +89,7 @@ class Environment(object):
     """
     current = None
     cfg_keys = [
-        "target",
+        "TARGET",
         "LOG_INP_WIDTH",
        "LOG_WGT_WIDTH",
        "LOG_ACC_WIDTH",
@@ -204,9 +204,19 @@ class Environment(object):
     @property
     def gevm(self):
-        """GEMM intrinsic"""
+        """GEVM intrinsic"""
         return self.dev.gevm
+
+    @property
+    def target_host(self):
+        """The target host"""
+        if self.TARGET == "pynq":
+            return "llvm -target=armv7-none-linux-gnueabihf"
+        elif self.TARGET == "sim":
+            return "llvm"
+        else:
+            raise ValueError("Unknown target %s" % self.TARGET)
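
For reference, a minimal sketch of how a build consumes this property, mirroring the updated unit tests in this commit (the tiny copy-through compute is illustrative only):

import tvm
import vta

env = vta.get_env()
n = 2
x = tvm.placeholder((n, n, env.BATCH, env.BLOCK_OUT), name="x", dtype=env.acc_dtype)
x_buf = tvm.compute((n, n, env.BATCH, env.BLOCK_OUT), lambda *i: x(*i), "x_buf")
y = tvm.compute((n, n, env.BATCH, env.BLOCK_OUT),
                lambda *i: x_buf(*i).astype(env.inp_dtype), "y")
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y].pragma(y.op.axis[0], env.dma_copy)
with vta.build_config():
    # "ext_dev" is the accelerator target; target_host picks the matching
    # host LLVM triple ("llvm -target=armv7..." on pynq, plain "llvm" on sim).
    mod = vta.build(s, [x, y], "ext_dev", env.target_host)
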
 def get_env():
     """Get the current VTA Environment.
@@ -278,6 +288,7 @@ def _init_env():
     for k in Environment.cfg_keys:
         keys.add("VTA_" + k)
+    keys.add("TARGET")
     if not os.path.isfile(filename):
         raise RuntimeError(
@@ -290,8 +301,11 @@ def _init_env():
     for k in keys:
         if k +" =" in line:
             val = line.split("=")[1].strip()
-            cfg[k[4:]] = int(val)
-    cfg["target"] = "pynq"
+            if k.startswith("VTA_"):
+                k = k[4:]
+                cfg[k] = int(val)
+            else:
+                cfg[k] = val
     return Environment(cfg)
 Environment.current = _init_env()
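
For reference, a small sketch of what this parsing yields for two sample config lines (the `KEY = value` format comes from make/config.mk above):

cfg = {}
for line in ["VTA_LOG_INP_WIDTH = 3", "TARGET = sim"]:
    k, val = [s.strip() for s in line.split("=")]
    if k.startswith("VTA_"):
        cfg[k[4:]] = int(val)   # hardware parameters become ints
    else:
        cfg[k] = val            # TARGET stays a string
print(cfg)  # {'LOG_INP_WIDTH': 3, 'TARGET': 'sim'}
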
@@ -78,8 +78,7 @@ def fold_uop_loop(stmt_in):
         if not fail[0]:
             begin = tvm.call_extern(
                 "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
-            end = tvm.call_extern(
-                "int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets)
+            end = tvm.call_extern("int32", "VTAUopLoopEnd")
             return [begin, ret, end]
     raise ValueError("Failed to fold the GEMM instructions..")
@@ -683,8 +682,14 @@ def inject_alu_intrin(stmt_in):
             else:
                 raise RuntimeError(
                     "Function call not recognized %s" % (loop_body.value.name))
+        elif isinstance(loop_body.value, tvm.expr.Load):
+            alu_opcode = env.dev.ALU_OPCODE_SHR
+            lhs = loop_body.value
+            rhs = tvm.const(0)
         else:
-            raise RuntimeError("Expression not recognized %s" % (type(loop_body.value)))
+            raise RuntimeError(
+                "Expression not recognized %s, %s, %s" % (
+                    type(loop_body.value), str(loop_body.value), str(stmt)))
         # Derive array index coefficients
         dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
@@ -772,7 +777,9 @@ def inject_alu_intrin(stmt_in):
         irb = tvm.ir_builder.create()
         for idx, extent in enumerate(extents):
             irb.emit(tvm.call_extern(
-                "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx]))
+                "int32", "VTAUopLoopBegin",
+                extent, dst_coeff[idx], src_coeff[idx], 0))
+        use_imm = int(use_imm)
         irb.emit(tvm.call_extern(
             "int32", "VTAUopPush",
             1, 0,
@@ -804,5 +811,6 @@ def debug_print(stmt):
     stmt : Stmt
         The
     """
+    # pylint: disable=superfluous-parens
     print(stmt)
     return stmt
@@ -24,8 +24,7 @@ def reconfig_runtime(remote):
         "VTA_LOG_WGT_BUFF_SIZE",
         "VTA_LOG_ACC_BUFF_SIZE",
         "VTA_LOG_OUT_BUFF_SIZE"]
-    cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
+    cflags = []
     for k in keys:
         cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
     freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
...
"""Testing utilities, this namespace is not imported by default."""
from . util import run
"""Utilities to start simulator."""
import os
import ctypes
import json
import tvm
def _load_lib():
"""Load local library, assuming they are simulator."""
# pylint: disable=unused-variable
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
]
runtime_dll = []
if not all(os.path.exists(f) for f in dll_path):
return []
try:
for fname in dll_path:
runtime_dll.append(ctypes.CDLL(fname, ctypes.RTLD_GLOBAL))
return runtime_dll
except OSError:
return []
def enabled():
"""Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
return f is not None
def clear_stats():
"""Clear profiler statistics"""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
if f:
f()
def stats():
"""Clear profiler statistics
Returns
-------
stats : dict
Current profiler statistics
"""
x = tvm.get_global_func("vta.simulator.profiler_status")()
return json.loads(x)
LIBS = _load_lib()
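
A short usage sketch of these hooks (assuming the libraries were built with `TARGET = sim` so that `_load_lib` succeeds):

from vta.testing import simulator

if simulator.enabled():
    simulator.clear_stats()
    # ... execute a module built for "ext_dev" on a local RPC session ...
    print(simulator.stats())  # e.g. {"inp_load_nbytes": ..., "gemm_counter": ..., ...}
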
"""Test Utilities"""
from __future__ import absolute_import as _abs
import os
from tvm.contrib import rpc
from ..environment import get_env
from . import simulator
def run(run_func):
"""Run test function on all available env.
Parameters
----------
run_func : function(env, remote)
"""
env = get_env()
# run on simulator
if simulator.enabled():
env.TARGET = "sim"
run_func(env, rpc.LocalSession())
# Run on PYNQ if env variable exists
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
if pynq_host:
env.TARGET = "pynq"
port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
port = int(port)
remote = rpc.connect(pynq_host, port)
run_func(env, remote)
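
A test written against this harness only provides a `_run(env, remote)` closure; `run` dispatches it once per available target. A sketch, mirroring the unit tests further below:

import vta.testing

def test_smoke():
    def _run(env, remote):
        # env.TARGET is "sim" or "pynq"; remote is an RPC session.
        # A real test would build with env.target_host and upload to remote.
        assert env.TARGET in ("sim", "pynq")
    vta.testing.run(_run)
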
@@ -57,7 +57,7 @@ struct DataBuffer {
     assert(data != nullptr);
     DataBuffer* buffer = new DataBuffer();
     buffer->data_ = data;
-    buffer->phy_addr_ = VTAGetMemPhysAddr(data);
+    buffer->phy_addr_ = VTAMemGetPhyAddr(data);
     return buffer;
   }
   /*!
...
 /*!
  * Copyright (c) 2018 by Contributors
- * \file vta_pynq_driver.c
+ * \file pynq_driver.c
  * \brief VTA driver for Pynq board.
  */
@@ -17,7 +17,7 @@ void VTAMemFree(void* buf) {
   cma_free(buf);
 }
-vta_phy_addr_t VTAGetMemPhysAddr(void* buf) {
+vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
   return cma_get_phy_addr(buf);
 }
...
@@ -11,6 +11,7 @@
 #include <vta/driver.h>
 #include <vta/hw_spec.h>
 #include <vta/runtime.h>
+#include <dmlc/logging.h>
 #include <cassert>
 #include <vector>
@@ -113,7 +114,7 @@ class UopKernel {
             uint32_t wgt_index,
             uint32_t opcode,
             uint32_t use_imm,
-            uint32_t imm_val) {
+            int32_t imm_val) {
     // The loop nest structure
     VerifyDep(dst_index);
     VTAUop op;
@@ -166,7 +167,7 @@ class UopKernel {
   uint32_t opcode_{0xFFFFFFFF};
   uint32_t reset_out_{0xFFFFFFFF};
   bool use_imm_{false};
-  uint16_t imm_val_{0};
+  int16_t imm_val_{0};
  private:
   // Verify that we don't write to the same acc_mem index two cycles in a row
@@ -195,10 +196,6 @@ class UopKernel {
 /*!
  * \brief Base class of all queues to send and recv serial data.
- * \param kElemBytes Element unit bytes.
- * \param kMaxBytes Maximum number of bytes.
- * \param kCoherent Whether we have coherent access to the buffer.
- * \param kAlwaysCache Wether we should use cached memory.
  */
 class BaseQueue {
  public:
@@ -227,7 +224,7 @@ class BaseQueue {
     dram_buffer_ = static_cast<char*>(VTAMemAlloc(
         max_bytes, coherent || always_cache_));
     assert(dram_buffer_ != nullptr);
-    dram_phy_addr_ = VTAGetMemPhysAddr(dram_buffer_);
+    dram_phy_addr_ = VTAMemGetPhyAddr(dram_buffer_);
   }
   /*!
    * \brief Reset the pointer of the buffer.
@@ -597,14 +594,14 @@ class InsnQueue : public BaseQueue {
     }
     // Print instruction field information
     if (c.mem.opcode == VTA_OPCODE_LOAD) {
       printf("LOAD ");
       if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
       if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
       if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
       if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
     }
     if (c.mem.opcode == VTA_OPCODE_STORE) {
-      printf("STORE\n");
+      printf("STORE:\n");
     }
     printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
            static_cast<int>(c.mem.pop_prev_dep),
@@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode,
                 uint32_t wgt_index,
                 uint32_t opcode,
                 uint32_t use_imm,
-                uint32_t imm_val) {
+                int32_t imm_val) {
   vta::CommandQueue::ThreadLocal()->record_kernel()
       ->Push(mode, reset_out, dst_index, src_index,
              wgt_index, opcode, use_imm, imm_val);
...
/*!
* Copyright (c) 2018 by Contributors
* \file sim_driver.cc
* \brief VTA driver for simulated backend.
*/
#include <vta/driver.h>
#include <vta/hw_spec.h>
#include <tvm/runtime/registry.h>
#include <type_traits>
#include <mutex>
#include <map>
#include <unordered_map>
#include <cstring>
#include <sstream>
namespace vta {
namespace sim {
/*!
 * \brief Helper class to pack and unpack bits.
 *  Applies truncation when packing to a narrower bit width.
 *
 * \tparam bits The number of bits in the integer.
 * \note This implementation relies on little endian.
 */
template<uint32_t bits>
class BitPacker {
public:
explicit BitPacker(void* data) {
data_ = static_cast<uint32_t*>(data);
}
uint32_t GetUnsigned(uint32_t index) const {
if (bits == 32) {
return data_[index];
} else if (bits == 16) {
return reinterpret_cast<uint16_t*>(data_)[index];
} else if (bits == 8) {
return reinterpret_cast<uint8_t*>(data_)[index];
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
return (data_[offset] >> shift) & kMask;
}
}
int32_t GetSigned(uint32_t index) const {
if (bits == 32) {
return reinterpret_cast<int32_t*>(data_)[index];
} else if (bits == 16) {
return reinterpret_cast<int16_t*>(data_)[index];
} else if (bits == 8) {
return reinterpret_cast<int8_t*>(data_)[index];
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
int32_t uvalue = static_cast<int32_t>(
(data_[offset] >> shift) & kMask);
int kleft = 32 - bits;
return (uvalue << kleft) >> kleft;
}
}
void SetUnsigned(uint32_t index, uint32_t value) {
if (bits == 32) {
data_[index] = value;
} else if (bits == 16) {
reinterpret_cast<uint16_t*>(data_)[index] = value;
} else if (bits == 8) {
reinterpret_cast<uint8_t*>(data_)[index] = value;
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
data_[offset] &= (~(kMask << shift));
data_[offset] |= (value & kMask) << shift;
}
}
void SetSigned(uint32_t index, int32_t value) {
if (bits == 32) {
reinterpret_cast<int32_t*>(data_)[index] = value;
} else if (bits == 16) {
reinterpret_cast<int16_t*>(data_)[index] = value;
} else if (bits == 8) {
reinterpret_cast<int8_t*>(data_)[index] = value;
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
data_[offset] &= (~(kMask << shift));
data_[offset] |= static_cast<uint32_t>(value & kMask) << shift;
}
}
private:
uint32_t* data_;
static constexpr uint32_t kNumPackElem = 32 / bits;
static constexpr uint32_t kMask = (1U << (bits >= 32U ? 31U : bits)) - 1U;
};
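
In the narrow-width branch of GetSigned, the shift pair `(uvalue << kleft) >> kleft` sign-extends the masked field. A Python sketch of the same arithmetic (Python ints are unbounded, so the extension must be explicit; this is not part of the commit):

def sign_extend(value, bits):
    """Interpret the low `bits` bits of `value` as a two's-complement integer."""
    mask = (1 << bits) - 1
    value &= mask
    sign_bit = 1 << (bits - 1)
    return value - (1 << bits) if value & sign_bit else value

assert sign_extend(0b1111, 4) == -1  # an all-ones 4-bit field decodes to -1
assert sign_extend(0b0111, 4) == 7
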
/*!
* \brief DRAM memory manager
* Implements simple paging to allow physical address translation.
*/
class DRAM {
public:
/*!
 * \brief Get virtual address given physical address.
 * \param phy_addr The simulator physical address.
 * \return The corresponding virtual address.
 */
void* GetAddr(uint64_t phy_addr) {
std::lock_guard<std::mutex> lock(mutex_);
uint64_t loc = (phy_addr >> kPageBits) - 1;
CHECK_LT(loc, ptable_.size());
Page* p = ptable_[loc];
CHECK(p != nullptr);
size_t offset = (loc - p->ptable_begin) << kPageBits;
offset += phy_addr & (kPageSize - 1);
return reinterpret_cast<char*>(p->data) + offset;
}
/*!
 * \brief Get physical address.
 * \param buf The virtual address.
 * \return The corresponding simulator physical address.
 */
vta_phy_addr_t GetPhyAddr(void* buf) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = pmap_.find(buf);
CHECK(it != pmap_.end());
Page* p = it->second.get();
return (p->ptable_begin + 1) << kPageBits;
}
/*!
* \brief Allocate memory from manager
* \param size The size of memory
* \return The virtual address
*/
void* Alloc(size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
size_t npage = (size + kPageSize - 1) / kPageSize;
auto it = free_map_.lower_bound(npage);
if (it != free_map_.end()) {
Page* p = it->second;
free_map_.erase(it);
return p->data;
}
size_t start = ptable_.size();
std::unique_ptr<Page> p(new Page(start, npage));
// insert page entry
ptable_.resize(start + npage, p.get());
void* data = p->data;
pmap_[data] = std::move(p);
return data;
}
/*!
 * \brief Free the memory.
 * \param data The virtual address returned by Alloc.
 */
void Free(void* data) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = pmap_.find(data);
CHECK(it != pmap_.end());
Page* p = it->second.get();
free_map_.insert(std::make_pair(p->num_pages, p));
}
static DRAM* Global() {
static DRAM inst;
return &inst;
}
private:
// The number of bits in a page offset
static constexpr vta_phy_addr_t kPageBits = 16;
// Page size in bytes (1 << 16 = 64 KiB), also the allocation granularity
static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits;
/*! \brief A page in the DRAM */
struct Page {
/*! \brief Data Type */
using DType = typename std::aligned_storage<kPageSize, 256>::type;
/*! \brief Start location in page table */
size_t ptable_begin;
/*! \brief The total number of pages */
size_t num_pages;
/*! \brief Data */
DType* data{nullptr};
// construct a new page
explicit Page(size_t ptable_begin, size_t num_pages)
: ptable_begin(ptable_begin), num_pages(num_pages) {
data = new DType[num_pages];
}
~Page() {
delete [] data;
}
};
// Internal lock
std::mutex mutex_;
// Physical address -> page
std::vector<Page*> ptable_;
// virtual addres -> page
std::unordered_map<void*, std::unique_ptr<Page> > pmap_;
// Free map
std::multimap<size_t, Page*> free_map_;
};
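
The paging arithmetic above is the core of the simulated DRAM: a buffer whose first page-table slot is `ptable_begin` gets physical address `(ptable_begin + 1) << kPageBits`, so physical address 0 is never handed out. A Python sketch of the forward and inverse mapping for a single-page allocation (multi-page spans add an extra offset term, see GetAddr):

K_PAGE_BITS = 16
K_PAGE_SIZE = 1 << K_PAGE_BITS  # 64 KiB pages

def phy_addr(ptable_begin):
    # Slot 0 of the page table maps to physical address 1 << 16, never 0.
    return (ptable_begin + 1) << K_PAGE_BITS

def locate(phy):
    loc = (phy >> K_PAGE_BITS) - 1      # page-table slot
    offset = phy & (K_PAGE_SIZE - 1)    # byte offset within the page
    return loc, offset

assert locate(phy_addr(3) + 100) == (3, 100)
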
/*!
* \brief Register file.
* \tparam kBits Number of bits of one value.
* \tparam kLane Number of lanes in one element.
* \tparam kMaxNumElem Maximum number of element.
*/
template<int kBits, int kLane, int kMaxNumElem>
class SRAM {
public:
/*! \brief Bytes of single vector element */
static const int kElemBytes = (kBits * kLane + 7) / 8;
/*! \brief content data type */
using DType = typename std::aligned_storage<kElemBytes, kElemBytes>::type;
SRAM() {
data_ = new DType[kMaxNumElem];
}
~SRAM() {
delete [] data_;
}
// Get the i-th index
void* BeginPtr(uint32_t index) {
CHECK_LT(index, kMaxNumElem);
return &(data_[index]);
}
// Execute the load instruction on this SRAM
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
DType* sram_ptr = data_ + op->sram_base;
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
op->dram_base * kElemBytes));
uint64_t xtotal = op->x_size + op->x_pad_0 + op->x_pad_1;
uint32_t ytotal = op->y_size + op->y_pad_0 + op->y_pad_1;
uint64_t sram_end = op->sram_base + xtotal * ytotal;
CHECK_LE(sram_end, kMaxNumElem);
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0);
sram_ptr += xtotal * op->y_pad_0;
for (uint32_t y = 0; y < op->y_size; ++y) {
memset(sram_ptr, 0, kElemBytes * op->x_pad_0);
sram_ptr += op->x_pad_0;
memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size);
sram_ptr += op->x_size;
BitPacker<kBits> src(sram_ptr);
memset(sram_ptr, 0, kElemBytes * op->x_pad_1);
sram_ptr += op->x_pad_1;
dram_ptr += kElemBytes * op->x_stride;
}
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_1);
}
// Execute the store instruction on this SRAM, applying truncation.
// This relies on the elements being 32 bits.
template<int target_bits>
void TruncStore(const VTAMemInsn* op, DRAM* dram) {
CHECK_EQ(op->x_pad_0, 0);
CHECK_EQ(op->x_pad_1, 0);
CHECK_EQ(op->y_pad_0, 0);
CHECK_EQ(op->y_pad_1, 0);
int target_width = (target_bits * kLane + 7) / 8;
BitPacker<kBits> src(data_ + op->sram_base);
BitPacker<target_bits> dst(dram->GetAddr(op->dram_base * target_width));
for (uint32_t y = 0; y < op->y_size; ++y) {
for (uint32_t x = 0; x < op->x_size; ++x) {
uint32_t sram_base = y * op->x_size + x;
uint32_t dram_base = y * op->x_stride + x;
for (int i = 0; i < kLane; ++i) {
dst.SetSigned(dram_base * kLane + i,
src.GetSigned(sram_base * kLane +i));
}
}
}
}
private:
/*! \brief internal data content */
DType* data_;
};
/*!
 * \brief Profiler that collects runtime statistics of the simulator.
 */
class Profiler {
public:
/*! \brief Bytes loaded into the INP memory */
uint64_t inp_load_nbytes{0};
/*! \brief Bytes loaded into the WGT memory */
uint64_t wgt_load_nbytes{0};
/*! \brief Bytes loaded into the ACC memory */
uint64_t acc_load_nbytes{0};
/*! \brief Bytes loaded into the UOP memory */
uint64_t uop_load_nbytes{0};
/*! \brief Bytes stored from the OUT memory */
uint64_t out_store_nbytes{0};
/*! \brief instr counter for gemm */
uint64_t gemm_counter{0};
/*! \brief instr counter for ALU ops */
uint64_t alu_counter{0};
/*! \brief clear the profiler */
void Clear() {
inp_load_nbytes = 0;
wgt_load_nbytes = 0;
acc_load_nbytes = 0;
uop_load_nbytes = 0;
out_store_nbytes = 0;
gemm_counter = 0;
alu_counter = 0;
}
std::string AsJSON() {
std::ostringstream os;
os << "{\n"
<< " \"inp_load_nbytes\":" << inp_load_nbytes << ",\n"
<< " \"wgt_load_nbytes\":" << wgt_load_nbytes << ",\n"
<< " \"acc_load_nbytes\":" << acc_load_nbytes << ",\n"
<< " \"uop_load_nbytes\":" << uop_load_nbytes << ",\n"
<< " \"out_store_nbytes\":" << out_store_nbytes << ",\n"
<< " \"gemm_counter\":" << gemm_counter << ",\n"
<< " \"alu_counter\":" << alu_counter << "\n"
<<"}\n";
return os.str();
}
static Profiler* ThreadLocal() {
static thread_local Profiler inst;
return &inst;
}
};
// Simulate device
// TODO(tqchen,thierry): queue based event driven simulation.
class Device {
public:
Device() {
prof_ = Profiler::ThreadLocal();
dram_ = DRAM::Global();
}
int Run(vta_phy_addr_t insn_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
VTAGenericInsn* insn = static_cast<VTAGenericInsn*>(
dram_->GetAddr(insn_phy_addr));
finish_counter_ = 0;
for (uint32_t i = 0; i < insn_count; ++i) {
this->Run(insn + i);
}
return 0;
}
private:
void Run(const VTAGenericInsn* insn) {
const VTAMemInsn* mem = reinterpret_cast<const VTAMemInsn*>(insn);
const VTAGemInsn* gem = reinterpret_cast<const VTAGemInsn*>(insn);
const VTAAluInsn* alu = reinterpret_cast<const VTAAluInsn*>(insn);
switch (mem->opcode) {
case VTA_OPCODE_LOAD: RunLoad(mem); break;
case VTA_OPCODE_STORE: RunStore(mem); break;
case VTA_OPCODE_GEMM: RunGEMM(gem); break;
case VTA_OPCODE_ALU: RunALU(alu); break;
case VTA_OPCODE_FINISH: ++finish_counter_; break;
default: {
LOG(FATAL) << "Unknown op_code" << mem->opcode;
}
}
}
void RunLoad(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_INP) {
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
} else if (op->memory_type == VTA_MEM_ID_WGT) {
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
} else if (op->memory_type == VTA_MEM_ID_ACC) {
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
} else if (op->memory_type == VTA_MEM_ID_UOP) {
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
} else {
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
}
}
void RunStore(const VTAMemInsn* op) {
if (op->memory_type == VTA_MEM_ID_ACC ||
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
} else {
LOG(FATAL) << "Store do not support memory_type="
<< op->memory_type;
}
}
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
// Read in memory indices
uint32_t acc_idx = uop_ptr->dst_idx;
uint32_t inp_idx = uop_ptr->src_idx;
uint32_t wgt_idx = uop_ptr->wgt_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
inp_idx += y * op->src_factor_out + x * op->src_factor_in;
wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));
// gemm loop
for (uint32_t i = 0; i < VTA_BATCH; ++i) {
for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
uint32_t acc_offset = i * VTA_BLOCK_OUT + j;
int32_t sum = acc.GetSigned(acc_offset);
for (uint32_t k = 0; k < VTA_BLOCK_IN; ++k) {
sum +=
inp.GetSigned(i * VTA_BLOCK_IN + k) *
wgt.GetSigned(j * VTA_BLOCK_IN + k);
}
acc.SetSigned(acc_offset, sum);
}
}
}
}
}
} else {
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
uint32_t acc_idx = uop_ptr->dst_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
for (uint32_t i = 0; i < VTA_BATCH * VTA_BLOCK_OUT; ++i) {
acc.SetSigned(i, 0);
}
}
}
}
}
}
void RunALU(const VTAAluInsn* op) {
prof_->alu_counter += op->iter_out * op->iter_in;
if (op->use_imm) {
RunALU_<true>(op);
} else {
RunALU_<false>(op);
}
}
template<bool use_imm>
void RunALU_(const VTAAluInsn* op) {
switch (op->alu_opcode) {
case VTA_ALU_OPCODE_ADD: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
return x + y;
});
}
case VTA_ALU_OPCODE_MAX: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
return std::max(x, y);
});
}
case VTA_ALU_OPCODE_MIN: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
return std::min(x, y);
});
}
case VTA_ALU_OPCODE_SHR: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
if (y >= 0) {
return x >> y;
} else {
return x << (-y);
}
});
}
default: {
LOG(FATAL) << "Unknown ALU code " << op->alu_opcode;
}
}
}
template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
// Read micro op
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(k));
uint32_t dst_index = uop_ptr->dst_idx;
uint32_t src_index = uop_ptr->src_idx;
dst_index += y * op->dst_factor_out + x * op->dst_factor_in;
src_index += y * op->src_factor_out + x * op->src_factor_in;
BitPacker<VTA_ACC_WIDTH> dst(acc_.BeginPtr(dst_index));
BitPacker<VTA_ACC_WIDTH> src(acc_.BeginPtr(src_index));
for (int k = 0; k < VTA_BLOCK_OUT; ++k) {
if (use_imm) {
dst.SetSigned(k, func(dst.GetSigned(k), op->imm));
} else {
dst.SetSigned(k, func(dst.GetSigned(k), src.GetSigned(k)));
}
}
}
}
}
}
// the finish counter
int finish_counter_{0};
// Profiler for statistics collection
Profiler* prof_;
// The DRAM interface
DRAM* dram_;
// The SRAM
SRAM<VTA_INP_WIDTH, VTA_BATCH * VTA_BLOCK_IN, VTA_INP_BUFF_DEPTH> inp_;
SRAM<VTA_WGT_WIDTH, VTA_BLOCK_IN * VTA_BLOCK_OUT, VTA_WGT_BUFF_DEPTH> wgt_;
SRAM<VTA_ACC_WIDTH, VTA_BATCH * VTA_BLOCK_OUT, VTA_ACC_BUFF_DEPTH> acc_;
SRAM<VTA_UOP_WIDTH, 1, VTA_UOP_BUFF_DEPTH> uop_;
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->Clear();
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Profiler::ThreadLocal()->AsJSON();
});
} // namespace sim
} // namespace vta
void* VTAMemAlloc(size_t size, int cached) {
return vta::sim::DRAM::Global()->Alloc(size);
}
void VTAMemFree(void* buf) {
vta::sim::DRAM::Global()->Free(buf);
}
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
return vta::sim::DRAM::Global()->GetPhyAddr(buf);
}
void VTAFlushCache(vta_phy_addr_t buf, int size) {
  // No-op: the simulated DRAM is ordinary host memory, so there is no cache to flush.
}
void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
  // No-op: nothing to invalidate in simulation.
}
VTADeviceHandle VTADeviceAlloc() {
return new vta::sim::Device();
}
void VTADeviceFree(VTADeviceHandle handle) {
delete static_cast<vta::sim::Device*>(handle);
}
int VTADeviceRun(VTADeviceHandle handle,
vta_phy_addr_t insn_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
return static_cast<vta::sim::Device*>(handle)->Run(
insn_phy_addr, insn_count, wait_cycles);
}
void VTAProgram(const char* bitstream) {
  // No-op: there is no FPGA bitstream to program in simulation.
}
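
What RunGEMM does per micro-op is an integer GEMM accumulate over one (BATCH, BLOCK_IN, BLOCK_OUT) tile, and the SHR ALU op treats negative shift amounts as left shifts. A numpy sketch of both semantics (tile sizes here are illustrative, not fixed by the commit):

import numpy as np

BATCH, BLOCK_IN, BLOCK_OUT = 1, 16, 16  # illustrative tile shape
inp = np.random.randint(-8, 8, size=(BATCH, BLOCK_IN)).astype(np.int32)
wgt = np.random.randint(-8, 8, size=(BLOCK_OUT, BLOCK_IN)).astype(np.int32)
acc = np.zeros((BATCH, BLOCK_OUT), dtype=np.int32)

# One GEMM micro-op: acc[i, j] += sum_k inp[i, k] * wgt[j, k]
acc += inp.dot(wgt.T)

# VTA_ALU_OPCODE_SHR with a negative operand becomes a left shift:
def alu_shr(x, y):
    return x >> y if y >= 0 else x << (-y)

assert alu_shr(-8, 2) == -2 and alu_shr(3, -2) == 12
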
@@ -67,9 +67,6 @@ class VTADeviceAPI final : public DeviceAPI {
         std::make_shared<VTADeviceAPI>();
     return inst;
   }
- private:
-  void* runtime_dll_{nullptr};
 };
 struct VTAWorkspacePool : public WorkspacePool {
...
"""Unit test VTA's instructions """ """Unit test VTA's instructions """
import tvm import tvm
import vta
import mxnet as mx
import numpy as np import numpy as np
import topi import topi
from tvm.contrib import rpc, util from tvm.contrib import rpc, util
host = "pynq" import vta
port = 9091 import vta.testing
target = "llvm -target=armv7-none-linux-gnueabihf" from vta.testing import simulator
do_verify = True
print_ir = False
 def test_save_load_out():
-    env = vta.get_env()
     """Test save/store output command"""
-    n = 4
-    x = tvm.placeholder(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        name="x",
-        dtype=env.acc_dtype)
-    x_buf = tvm.compute(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: x(*i),
-        "x_buf")
-    # insert no-op that won't be optimized away
-    y_buf = tvm.compute(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: x_buf(*i)>>0,
-        "y_buf")
-    y = tvm.compute(
-        (n, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: y_buf(*i).astype(env.inp_dtype),
-        "y")
-    # schedule
-    s = tvm.create_schedule(y.op)
-    s[x_buf].set_scope(env.acc_scope)
-    s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
-    s[y_buf].set_scope(env.acc_scope)
-    s[y_buf].pragma(y_buf.op.axis[0], env.alu)
-    s[y].pragma(y.op.axis[0], env.dma_copy)
-    def verify():
-        # build
-        with vta.build_config(env.DEBUG_DUMP_INSN):
-            m = vta.build(s, [x, y], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        m.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
+    def _run(env, remote):
+        n = 6
+        x = tvm.placeholder(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            name="x",
+            dtype=env.acc_dtype)
+        x_buf = tvm.compute(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: x(*i), "x_buf")
+        # insert no-op that won't be optimized away
+        y_buf = tvm.compute(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: x_buf(*i)>>0, "y_buf")
+        y = tvm.compute(
+            (n, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+        # schedule
+        s = tvm.create_schedule(y.op)
+        s[x_buf].set_scope(env.acc_scope)
+        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
+        s[y_buf].set_scope(env.acc_scope)
+        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
+        s[y].pragma(y.op.axis[0], env.dma_copy)
+        # verification
+        with vta.build_config():
+            m = vta.build(s, [x, y], "ext_dev", env.target_host)
+        if not remote:
+            return
+        temp = util.tempdir()
+        m.save(temp.relpath("load_act.o"))
+        remote.upload(temp.relpath("load_act.o"))
+        f = remote.load_module("load_act.o")
@@ -59,47 +54,46 @@ def test_save_load_out():
         y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
         f(x_nd, y_nd)
         np.testing.assert_equal(y_np, y_nd.asnumpy())
-        print("\tFinished verification...")
-    if do_verify:
-        verify()
+    vta.testing.run(_run)
 def test_padded_load():
     """Test padded load."""
-    env = vta.get_env()
-    # declare
-    n = 21
-    m = 20
-    pad_before = [0, 1, 0, 0]
-    pad_after = [1, 3, 0, 0]
-    x = tvm.placeholder(
-        (n, m, env.BATCH, env.BLOCK_OUT),
-        name="x",
-        dtype=env.acc_dtype)
-    x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
-    # insert no-op that won't be optimized away
-    y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
-                         m + pad_before[1] + pad_after[1],
-                         env.BATCH,
-                         env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
-    y = tvm.compute((n + pad_before[0] + pad_after[0],
-                     m + pad_before[1] + pad_after[1],
-                     env.BATCH,
-                     env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
-    # schedule
-    s = tvm.create_schedule(y.op)
-    s[x_buf].set_scope(env.acc_scope)
-    s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
-    s[y_buf].set_scope(env.acc_scope)
-    s[y_buf].pragma(y_buf.op.axis[0], env.alu)
-    s[y].pragma(y.op.axis[0], env.dma_copy)
-    def verify():
-        # build
-        with vta.build_config(env.DEBUG_DUMP_INSN):
-            mod = vta.build(s, [x, y], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("padded_load.o"))
-        remote.upload(temp.relpath("padded_load.o"))
-        f = remote.load_module("padded_load.o")
+    def _run(env, remote):
+        # declare
+        n = 21
+        m = 20
+        pad_before = [0, 1, 0, 0]
+        pad_after = [1, 3, 0, 0]
+        x = tvm.placeholder(
+            (n, m, env.BATCH, env.BLOCK_OUT),
+            name="x",
+            dtype=env.acc_dtype)
+        x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
+        # insert no-op that won't be optimized away
+        y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
+                             m + pad_before[1] + pad_after[1],
+                             env.BATCH,
+                             env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
+        y = tvm.compute((n + pad_before[0] + pad_after[0],
+                         m + pad_before[1] + pad_after[1],
+                         env.BATCH,
+                         env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
+        # schedule
+        s = tvm.create_schedule(y.op)
+        s[x_buf].set_scope(env.acc_scope)
+        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
+        s[y_buf].set_scope(env.acc_scope)
+        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
+        s[y].pragma(y.op.axis[0], env.dma_copy)
+        # build
+        with vta.build_config():
+            mod = vta.build(s, [x, y], "ext_dev", env.target_host)
+        if not remote:
+            return
+        temp = util.tempdir()
+        mod.save(temp.relpath("padded_load.o"))
+        remote.upload(temp.relpath("padded_load.o"))
+        f = remote.load_module("padded_load.o")
@@ -118,285 +112,287 @@ def test_padded_load():
         y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
         f(x_nd, y_nd)
         np.testing.assert_equal(y_np, y_nd.asnumpy())
-        print("\tFinished verification...")
-    if print_ir:
-        print(vta.lower(s, [y, x], simple_mode=True))
+    vta.testing.run(_run)
 def test_gemm():
     """Test GEMM."""
-    env = vta.get_env()
-    # declare
-    o = 4
-    n = 4
-    m = 4
-    x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
-    w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
-    x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
-    w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
-    ko = tvm.reduce_axis((0, n), name="ko")
-    ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
-    y_gem = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda bo, co, bi, ci:
-            tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
-                    w_buf[co, ko, ci, ki].astype(env.acc_dtype),
-                    axis=[ko, ki]),
-        name="y_gem")
-    y_shf = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: y_gem(*i)>>8,
-        name="y_shf")
-    y_max = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.max(y_shf(*i), 0),
-        "y_max") #relu
-    y_min = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
-        "y_min") #relu
-    y = tvm.compute(
-        (o, m, env.BATCH, env.BLOCK_OUT),
-        lambda *i: y_min(*i).astype(env.inp_dtype),
-        name="y")
-    def verify(s):
-        mod = vta.build(s, [x, w, y], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("gemm.o"))
-        remote.upload(temp.relpath("gemm.o"))
-        f = remote.load_module("gemm.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        x_np = np.random.randint(
-            -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
-        w_np = np.random.randint(
-            -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
-        y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
-        x_nd = tvm.nd.array(x_np, ctx)
-        w_nd = tvm.nd.array(w_np, ctx)
-        y_nd = tvm.nd.array(y_np, ctx)
-        y_np = y_np.astype(env.acc_dtype)
-        for b in range(o):
-            for i in range(m):
-                for j in range(n):
-                    y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
-                                          w_np[i,j].T.astype(env.acc_dtype))
-        y_np = np.right_shift(y_np, 8)
-        y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
-        f(x_nd, w_nd, y_nd)
-        np.testing.assert_equal(y_np, y_nd.asnumpy())
-        print("\tFinished verification...")
-    def test_schedule1():
-        # default schedule with no smt
-        s = tvm.create_schedule(y.op)
-        # set the scope of the SRAM buffers
-        s[x_buf].set_scope(env.SCOPE_INP)
-        s[w_buf].set_scope(env.SCOPE_WGT)
-        s[y_gem].set_scope(env.acc_scope)
-        s[y_shf].set_scope(env.acc_scope)
-        s[y_max].set_scope(env.acc_scope)
-        s[y_min].set_scope(env.acc_scope)
-        # set pragmas for DMA transfer and ALU ops
-        s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
-        s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
-        s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
-        s[y_max].pragma(s[y_max].op.axis[0], env.alu)
-        s[y_min].pragma(s[y_min].op.axis[0], env.alu)
-        s[y].pragma(s[y].op.axis[0], env.dma_copy)
-        # tensorization
-        s[y_gem].reorder(
-            ko,
-            s[y_gem].op.axis[0],
-            s[y_gem].op.axis[1],
-            s[y_gem].op.axis[2],
-            s[y_gem].op.axis[3],
-            ki)
-        s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
-        if print_ir:
-            print(vta.lower(s, [x, w, y], simple_mode=True))
-        if do_verify:
-            with vta.build_config(env.DEBUG_DUMP_INSN):
-                verify(s)
-    def test_smt():
-        # test smt schedule
-        s = tvm.create_schedule(y.op)
-        s[x_buf].set_scope(env.SCOPE_INP)
-        s[w_buf].set_scope(env.SCOPE_WGT)
-        s[y_gem].set_scope(env.acc_scope)
-        s[y_shf].set_scope(env.acc_scope)
-        s[y_max].set_scope(env.acc_scope)
-        s[y_min].set_scope(env.acc_scope)
-        abo, aco, abi, aci = s[y].op.axis
-        abo1, abo2 = s[y].split(abo, nparts=2)
-        s[y].bind(abo1, tvm.thread_axis("cthread"))
-        s[y_gem].compute_at(s[y], abo1)
-        s[y_shf].compute_at(s[y], abo1)
-        s[y_max].compute_at(s[y], abo1)
-        s[y_min].compute_at(s[y], abo1)
-        s[y_gem].reorder(
-            ko,
-            s[y_gem].op.axis[0],
-            s[y_gem].op.axis[1],
-            s[y_gem].op.axis[2],
-            s[y_gem].op.axis[3],
-            ki)
-        s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
-        s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
-        s[y_max].pragma(s[y_max].op.axis[0], env.alu)
-        s[y_min].pragma(s[y_min].op.axis[0], env.alu)
-        s[x_buf].compute_at(s[y_gem], ko)
-        s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
-        s[w_buf].compute_at(s[y_gem], ko)
-        s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
-        s[y].pragma(abo2, env.dma_copy)
-        if print_ir:
-            print(vta.lower(s, [x, y, w], simple_mode=True))
-        if do_verify:
-            with vta.build_config(env.DEBUG_DUMP_INSN):
-                verify(s)
-    test_schedule1()
-    test_smt()
+    def _run(env, remote):
+        # declare
+        o = 4
+        n = 1
+        m = 4
+        x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
+        w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
+        x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
+        w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
+        ko = tvm.reduce_axis((0, n), name="ko")
+        ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
+        y_gem = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda bo, co, bi, ci:
+                tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
+                        w_buf[co, ko, ci, ki].astype(env.acc_dtype),
+                        axis=[ko, ki]),
+            name="y_gem")
+        y_shf = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: y_gem(*i)>>8,
+            name="y_shf")
+        y_max = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.max(y_shf(*i), 0),
+            "y_max") #relu
+        y_min = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
+            "y_min") #relu
+        y = tvm.compute(
+            (o, m, env.BATCH, env.BLOCK_OUT),
+            lambda *i: y_min(*i).astype(env.inp_dtype),
+            name="y")
+        if not remote:
+            return
+        def verify(s):
+            mod = vta.build(s, [x, w, y], "ext_dev", env.target_host)
+            temp = util.tempdir()
+            mod.save(temp.relpath("gemm.o"))
+            remote.upload(temp.relpath("gemm.o"))
+            f = remote.load_module("gemm.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            x_np = np.random.randint(
+                -128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
+            w_np = np.random.randint(
+                -128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
+            y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
+            x_nd = tvm.nd.array(x_np, ctx)
+            w_nd = tvm.nd.array(w_np, ctx)
+            y_nd = tvm.nd.array(y_np, ctx)
+            y_np = y_np.astype(env.acc_dtype)
+            for b in range(o):
+                for i in range(m):
+                    for j in range(n):
+                        y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
+                                              w_np[i,j].T.astype(env.acc_dtype))
+            y_np = np.right_shift(y_np, 8)
+            y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
+            if env.TARGET == "sim":
+                simulator.clear_stats()
+                f(x_nd, w_nd, y_nd)
+                print(simulator.stats())
+            else:
+                f(x_nd, w_nd, y_nd)
+            np.testing.assert_equal(y_np, y_nd.asnumpy())
+        def test_schedule1():
+            # default schedule with no smt
+            s = tvm.create_schedule(y.op)
+            # set the scope of the SRAM buffers
+            s[x_buf].set_scope(env.inp_scope)
+            s[w_buf].set_scope(env.wgt_scope)
+            s[y_gem].set_scope(env.acc_scope)
+            s[y_shf].set_scope(env.acc_scope)
+            s[y_max].set_scope(env.acc_scope)
+            s[y_min].set_scope(env.acc_scope)
+            # set pragmas for DMA transfer and ALU ops
+            s[x_buf].compute_at(s[y_gem], ko)
+            s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
+            s[w_buf].compute_at(s[y_gem], ko)
+            s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
+            s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
+            s[y_max].pragma(s[y_max].op.axis[0], env.alu)
+            s[y_min].pragma(s[y_min].op.axis[0], env.alu)
+            s[y].pragma(s[y].op.axis[0], env.dma_copy)
+            # tensorization
+            s[y_gem].reorder(
+                ko,
+                s[y_gem].op.axis[0],
+                s[y_gem].op.axis[1],
+                s[y_gem].op.axis[2],
+                s[y_gem].op.axis[3],
+                ki)
+            s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
+            verify(s)
+        def test_smt():
+            # test smt schedule
+            s = tvm.create_schedule(y.op)
+            s[x_buf].set_scope(env.inp_scope)
+            s[w_buf].set_scope(env.wgt_scope)
+            s[y_gem].set_scope(env.acc_scope)
+            s[y_shf].set_scope(env.acc_scope)
+            s[y_max].set_scope(env.acc_scope)
+            s[y_min].set_scope(env.acc_scope)
+            abo, aco, abi, aci = s[y].op.axis
+            abo1, abo2 = s[y].split(abo, nparts=2)
+            s[y].bind(abo1, tvm.thread_axis("cthread"))
+            s[y_gem].compute_at(s[y], abo1)
+            s[y_shf].compute_at(s[y], abo1)
+            s[y_max].compute_at(s[y], abo1)
+            s[y_min].compute_at(s[y], abo1)
+            s[y_gem].reorder(
+                ko,
+                s[y_gem].op.axis[0],
+                s[y_gem].op.axis[1],
+                s[y_gem].op.axis[2],
+                s[y_gem].op.axis[3],
+                ki)
+            s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
+            s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
+            s[y_max].pragma(s[y_max].op.axis[0], env.alu)
+            s[y_min].pragma(s[y_min].op.axis[0], env.alu)
+            s[x_buf].compute_at(s[y_gem], ko)
+            s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
+            s[w_buf].compute_at(s[y_gem], ko)
+            s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
+            s[y].pragma(abo2, env.dma_copy)
+            verify(s)
+        test_schedule1()
+        test_smt()
+    vta.testing.run(_run)
 
-def test_alu(tvm_op, np_op=None, use_imm=False):
-    """Test ALU"""
-    env = vta.get_env()
-    m = 8
-    n = 8
-    imm = np.random.randint(1,5)
-    # compute
-    a = tvm.placeholder(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        name="a",
-        dtype=env.acc_dtype)
-    a_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a(*i),
-        "a_buf") #DRAM->SRAM
-    if use_imm:
-        res_buf = tvm.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm_op(a_buf(*i), imm),
-            "res_buf") #compute
-    else:
-        b = tvm.placeholder(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            name="b",
-            dtype=env.acc_dtype)
-        b_buf = tvm.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: b(*i),
-            "b_buf") #DRAM->SRAM
-        res_buf = tvm.compute(
-            (m, n, env.BATCH, env.BLOCK_OUT),
-            lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
-            "res_buf") #compute
-    res = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: res_buf(*i).astype(env.inp_dtype),
-        "res") #SRAM->DRAM
-    # schedule
-    s = tvm.create_schedule(res.op)
-    s[a_buf].set_scope(env.acc_scope) # SRAM
-    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-    s[res_buf].set_scope(env.acc_scope) # SRAM
-    s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
-    s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
-    if use_imm:
-        if print_ir:
-            print(vta.lower(s, [a, res], simple_mode=True))
-    else:
-        s[b_buf].set_scope(env.acc_scope) # SRAM
-        s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-        if print_ir:
-            print(vta.lower(s, [a, b, res], simple_mode=True))
-    def verify():
-        # build
-        with vta.build_config():
-            if use_imm:
-                mod = vta.build(s, [a, res], "ext_dev", target)
-            else:
-                mod = vta.build(s, [a, b, res], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        a_np = np.random.randint(
-            -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
-        if use_imm:
-            res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
-        else:
-            b_np = np.random.randint(
-                -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
-            res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
-        res_np = res_np.astype(res.dtype)
-        a_nd = tvm.nd.array(a_np, ctx)
-        res_nd = tvm.nd.array(
-            np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
-        if use_imm:
-            f(a_nd, res_nd)
-        else:
-            b_nd = tvm.nd.array(b_np, ctx)
-            f(a_nd, b_nd, res_nd)
-        np.testing.assert_equal(res_np, res_nd.asnumpy())
-        print("\tFinished verification...")
-    if do_verify:
-        verify()
+def test_alu():
+    def _run(env, remote):
+        def check_alu(tvm_op, np_op=None, use_imm=False):
+            """Test ALU"""
+            m = 8
+            n = 8
+            imm = np.random.randint(1,5)
+            # compute
+            a = tvm.placeholder(
+                (m, n, env.BATCH, env.BLOCK_OUT),
+                name="a",
+                dtype=env.acc_dtype)
+            a_buf = tvm.compute(
+                (m, n, env.BATCH, env.BLOCK_OUT),
+                lambda *i: a(*i),
+                "a_buf") #DRAM->SRAM
+            if use_imm:
+                res_buf = tvm.compute(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    lambda *i: tvm_op(a_buf(*i), imm),
+                    "res_buf") #compute
+            else:
+                b = tvm.placeholder(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    name="b",
+                    dtype=env.acc_dtype)
+                b_buf = tvm.compute(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    lambda *i: b(*i),
+                    "b_buf") #DRAM->SRAM
+                res_buf = tvm.compute(
+                    (m, n, env.BATCH, env.BLOCK_OUT),
+                    lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
+                    "res_buf") #compute
+            res = tvm.compute(
+                (m, n, env.BATCH, env.BLOCK_OUT),
+                lambda *i: res_buf(*i).astype(env.inp_dtype),
+                "res") #SRAM->DRAM
+            # schedule
+            s = tvm.create_schedule(res.op)
+            s[a_buf].set_scope(env.acc_scope) # SRAM
+            s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+            s[res_buf].set_scope(env.acc_scope) # SRAM
+            s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
+            s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+            if not use_imm:
+                s[b_buf].set_scope(env.acc_scope) # SRAM
+                s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+            if not remote:
+                return
+            # build
+            with vta.build_config():
+                if use_imm:
+                    mod = vta.build(s, [a, res], "ext_dev", env.target_host)
+                else:
+                    mod = vta.build(s, [a, b, res], "ext_dev", env.target_host)
+            temp = util.tempdir()
+            mod.save(temp.relpath("load_act.o"))
+            remote.upload(temp.relpath("load_act.o"))
+            f = remote.load_module("load_act.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            a_np = np.random.randint(
+                -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
+            if use_imm:
+                res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
+            else:
+                b_np = np.random.randint(
+                    -16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
+                res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
+            res_np = res_np.astype(res.dtype)
+            a_nd = tvm.nd.array(a_np, ctx)
+            res_nd = tvm.nd.array(
+                np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
+            if use_imm:
+                f(a_nd, res_nd)
+            else:
+                b_nd = tvm.nd.array(b_np, ctx)
+                f(a_nd, b_nd, res_nd)
+            np.testing.assert_equal(res_np, res_nd.asnumpy())
+        check_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
+        check_alu(tvm.max, np.maximum, use_imm=True)
+        check_alu(tvm.max, np.maximum)
+        check_alu(lambda x, y: x + y, use_imm=True)
+        check_alu(lambda x, y: x + y)
+        check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
+    vta.testing.run(_run)
 def test_relu():
     """Test RELU on ALU"""
-    env = vta.get_env()
-    m = 8
-    n = 8
-    # compute
-    a = tvm.placeholder(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        name="a",
-        dtype=env.acc_dtype)
-    a_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a(*i),
-        "a_buf") # DRAM->SRAM
-    max_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.max(a_buf(*i), 0),
-        "res_buf") # relu
-    min_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
-        "max_buf") # relu
-    res = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: min_buf(*i).astype(env.inp_dtype),
-        "min_buf") # SRAM->DRAM
-    # schedule
-    s = tvm.create_schedule(res.op)
-    s[a_buf].set_scope(env.acc_scope) # SRAM
-    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-    s[max_buf].set_scope(env.acc_scope) # SRAM
-    s[min_buf].set_scope(env.acc_scope) # SRAM
-    s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
-    s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
-    s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
-    if print_ir:
-        print(vta.lower(s, [a, res], simple_mode=True))
-    def verify():
-        # build
-        with vta.build_config(env.DEBUG_DUMP_INSN):
-            mod = vta.build(s, [a, res], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
+    def _run(env, remote):
+        m = 8
+        n = 10
+        # compute
+        a = tvm.placeholder(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            name="a",
+            dtype=env.acc_dtype)
+        a_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: a(*i),
+            "a_buf") # DRAM->SRAM
+        max_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.max(a_buf(*i), 0),
+            "res_buf") # relu
+        min_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
+            "max_buf") # relu
+        res = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: min_buf(*i).astype(env.inp_dtype),
+            "min_buf") # SRAM->DRAM
+        # schedule
+        s = tvm.create_schedule(res.op)
+        s[a_buf].set_scope(env.acc_scope) # SRAM
+        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+        s[max_buf].set_scope(env.acc_scope) # SRAM
+        s[min_buf].set_scope(env.acc_scope) # SRAM
+        s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
+        s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
+        s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+        # build
+        with vta.build_config():
+            mod = vta.build(s, [a, res], "ext_dev", env.target_host)
+        if not remote:
+            return
+        temp = util.tempdir()
+        mod.save(temp.relpath("load_act.o"))
+        remote.upload(temp.relpath("load_act.o"))
+        f = remote.load_module("load_act.o")
@@ -410,56 +406,51 @@ def test_relu():
             np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
         f(a_nd, res_nd)
         np.testing.assert_equal(res_np, res_nd.asnumpy())
-        print("\tFinished verification...")
-    if do_verify:
-        verify()
+    vta.testing.run(_run)
 def test_shift_and_scale():
     """Test shift and scale on ALU"""
-    env = vta.get_env()
-    m = 8
-    n = 8
-    imm_shift = np.random.randint(-10,10)
-    imm_scale = np.random.randint(1,5)
-    # compute
-    a = tvm.placeholder(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        name="a", dtype=env.acc_dtype)
-    a_buf = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a(*i),
-        "a_buf") # DRAM->SRAM
-    res_shift = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: a_buf(*i)+imm_shift,
-        "res_shift") # compute
-    res_scale = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: res_shift(*i)>>imm_scale,
-        "res_scale") # compute
-    res = tvm.compute(
-        (m, n, env.BATCH, env.BLOCK_OUT),
-        lambda *i: res_scale(*i).astype(env.inp_dtype),
-        "res") # SRAM->DRAM
-    # schedule
-    s = tvm.create_schedule(res.op)
-    s[a_buf].set_scope(env.acc_scope) # SRAM
-    s[res_shift].set_scope(env.acc_scope) # SRAM
-    s[res_scale].set_scope(env.acc_scope) # SRAM
-    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
-    s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
-    s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
-    s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
-    if print_ir:
-        print(vta.lower(s, [a, res], simple_mode=True))
-    def verify():
-        # build
-        mod = vta.build(s, [a, res], "ext_dev", target)
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("load_act.o"))
-        remote.upload(temp.relpath("load_act.o"))
-        f = remote.load_module("load_act.o")
+    def _run(env, remote):
+        m = 2
+        n = 8
+        imm_shift = np.random.randint(0,8)
+        imm_scale = np.random.randint(1,5)
+        # compute
+        a = tvm.placeholder(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            name="a", dtype=env.acc_dtype)
+        a_buf = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: a(*i),
+            "a_buf") # DRAM->SRAM
+        res_shift = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: a_buf(*i)+imm_shift,
+            "res_shift") # compute
+        res_scale = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: res_shift(*i)>>imm_scale,
+            "res_scale") # compute
+        res = tvm.compute(
+            (m, n, env.BATCH, env.BLOCK_OUT),
+            lambda *i: res_scale(*i).astype(env.inp_dtype),
+            "res") # SRAM->DRAM
+        # schedule
+        s = tvm.create_schedule(res.op)
+        s[a_buf].set_scope(env.acc_scope) # SRAM
+        s[res_shift].set_scope(env.acc_scope) # SRAM
+        s[res_scale].set_scope(env.acc_scope) # SRAM
+        s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
+        s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
+        s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
+        s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
+        # build
+        mod = vta.build(s, [a, res], "ext_dev", env.target_host)
+        if not remote:
+            return
+        temp = util.tempdir()
+        mod.save(temp.relpath("load_act.o"))
+        remote.upload(temp.relpath("load_act.o"))
+        f = remote.load_module("load_act.o")
@@ -474,31 +465,18 @@ def test_shift_and_scale():
             np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
         f(a_nd, res_nd)
         np.testing.assert_equal(res_np, res_nd.asnumpy())
-        print("\tFinished verification...")
-    if do_verify:
-        verify()
+    vta.testing.run(_run)
 if __name__ == "__main__":
     print("Load/store test")
     test_save_load_out()
     print("Padded load test")
-    test_padded_load()
-    # print("GEMM test")
-    # test_gemm()
-    print("Max immediate")
-    test_alu(tvm.max, np.maximum, use_imm=True)
-    print("Max")
-    test_alu(tvm.max, np.maximum)
-    print("Add immediate")
-    test_alu(lambda x, y: x + y, use_imm=True)
-    print("Add")
-    test_alu(lambda x, y: x + y)
-    print("Shift right immediate")
-    test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
-    print("Shift left immediate")
-    test_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
-    print("Relu")
+    #test_padded_load()
+    print("GEMM test")
+    test_gemm()
+    print("ALU test")
+    test_alu()
     test_relu()
-    # print("Shift and scale")
-    # test_shift_and_scale()
+    print("Shift and scale")
+    test_shift_and_scale()
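
With the old hard-coded host/port gone, running this file exercises the simulator automatically and reaches a PYNQ board only when the RPC environment variables are set. A sketch (the address below is a placeholder, not from the commit):

import os

os.environ["VTA_PYNQ_RPC_HOST"] = "192.168.2.99"  # hypothetical board address
os.environ["VTA_PYNQ_RPC_PORT"] = "9091"          # default used by vta.testing.run

test_save_load_out()  # now runs on the simulator, then on the board
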