Commit 9c44e4b4 by Tianqi Chen

[DRIVER] Add simulator, unify testcase to unittest (#25)

parent 0faefafc
......@@ -40,6 +40,19 @@ ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
SHARED_LIBRARY_SUFFIX := dylib
WHOLE_ARCH= -all_load
NO_WHOLE_ARCH= -noall_load
LDFLAGS += -undefined dynamic_lookup
else
SHARED_LIBRARY_SUFFIX := so
WHOLE_ARCH= --whole-archive
NO_WHOLE_ARCH= --no-whole-archive
endif
all: lib/libvta.so lib/libvta_runtime.so
......@@ -53,6 +66,10 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET)
LDFLAGS += -l:libdma.so
endif
ifeq ($(TARGET), sim)
VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
build/%.o: src/%.cc
......@@ -71,7 +88,7 @@ lib/libvta_runtime.so: build/runtime.o
lint: pylint cpplint
cpplint:
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
python nnvm/dmlc-core/scripts/lint.py vta cpp include src
pylint:
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
......@@ -86,3 +103,4 @@ clean:
-include build/*.d
-include build/*/*.d
-include build/*/*/*.d
-include build/*/*/*/*.d
......@@ -77,7 +77,7 @@ void VTAMemFree(void* buf);
* \param buf Pointer to memory region allocated with VTAMemAlloc.
* \return The physical address of the memory region.
*/
vta_phy_addr_t VTAGetMemPhysAddr(void* buf);
vta_phy_addr_t VTAMemGetPhyAddr(void* buf);
/*!
* \brief Flushes the region of memory out of the CPU cache to DRAM.
......
......@@ -519,8 +519,8 @@ typedef struct {
uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH;
/*! \brief Use immediate is true */
uint64_t use_imm : 1;
/*! \brief Immediate value */
uint64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
/*! \brief Immediate value: allow negative value */
int64_t imm : VTA_ALUOP_IMM_BIT_WIDTH;
} VTAAluInsn;
/*! \brief VTA ALU instruction converter */
......
......@@ -196,7 +196,7 @@ void VTAUopPush(uint32_t mode,
uint32_t wgt_index,
uint32_t opcode,
uint32_t use_imm,
uint32_t imm_val);
int32_t imm_val);
/*!
* \brief Mark start of a micro op loop.
......
......@@ -27,7 +27,7 @@ ADD_LDFLAGS=
ADD_CFLAGS=
# the hardware target
TARGET = VTA_PYNQ_TARGET
TARGET = pynq
#---------------------
# VTA hardware parameters
......@@ -89,7 +89,6 @@ VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
# Update ADD_CFLAGS
ADD_CFLAGS += \
-D$(TARGET) \
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
......
......@@ -8,6 +8,7 @@ try:
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
from . import graph
except ImportError:
pass
......@@ -89,7 +89,7 @@ class Environment(object):
"""
current = None
cfg_keys = [
"target",
"TARGET",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
......@@ -204,9 +204,19 @@ class Environment(object):
@property
def gevm(self):
"""GEMM intrinsic"""
"""GEVM intrinsic"""
return self.dev.gevm
@property
def target_host(self):
"""The target host"""
if self.TARGET == "pynq":
return "llvm -target=armv7-none-linux-gnueabihf"
elif self.TARGET == "sim":
return "llvm"
else:
raise ValueError("Unknown target %s" % self.TARGET)
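A minimal sketch of how the new target_host property resolves (hypothetical snippet; env.TARGET is normally set from the config file or by the test harness):

    env = vta.get_env()
    env.TARGET = "sim"
    assert env.target_host == "llvm"
    env.TARGET = "pynq"
    assert env.target_host == "llvm -target=armv7-none-linux-gnueabihf"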
def get_env():
"""Get the current VTA Environment.
......@@ -278,6 +288,7 @@ def _init_env():
for k in Environment.cfg_keys:
keys.add("VTA_" + k)
keys.add("TARGET")
if not os.path.isfile(filename):
raise RuntimeError(
......@@ -290,8 +301,11 @@ def _init_env():
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
cfg[k[4:]] = int(val)
cfg["target"] = "pynq"
if k.startswith("VTA_"):
k = k[4:]
cfg[k] = int(val)
else:
cfg[k] = val
return Environment(cfg)
Environment.current = _init_env()
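A stand-alone sketch of the parsing rule above (the helper name parse_line is illustrative only): VTA_-prefixed keys are stored as integers with the prefix stripped, while the TARGET key is kept as a string.

    def parse_line(line, cfg):
        key, val = [s.strip() for s in line.split("=")]
        if key.startswith("VTA_"):
            cfg[key[4:]] = int(val)
        else:
            cfg[key] = val

    cfg = {}
    parse_line("VTA_LOG_INP_WIDTH = 3", cfg)   # -> cfg["LOG_INP_WIDTH"] == 3
    parse_line("TARGET = pynq", cfg)           # -> cfg["TARGET"] == "pynq"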
......@@ -78,8 +78,7 @@ def fold_uop_loop(stmt_in):
if not fail[0]:
begin = tvm.call_extern(
"int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
end = tvm.call_extern(
"int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets)
end = tvm.call_extern("int32", "VTAUopLoopEnd")
return [begin, ret, end]
raise ValueError("Failed to fold the GEMM instructions..")
......@@ -683,8 +682,14 @@ def inject_alu_intrin(stmt_in):
else:
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
elif isinstance(loop_body.value, tvm.expr.Load):
alu_opcode = env.dev.ALU_OPCODE_SHR
lhs = loop_body.value
rhs = tvm.const(0)
else:
raise RuntimeError("Expression not recognized %s" % (type(loop_body.value)))
raise RuntimeError(
"Expression not recognized %s, %s, %s" % (
type(loop_body.value), str(loop_body.value), str(stmt)))
# Derive array index coefficients
dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
......@@ -772,7 +777,9 @@ def inject_alu_intrin(stmt_in):
irb = tvm.ir_builder.create()
for idx, extent in enumerate(extents):
irb.emit(tvm.call_extern(
"int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx]))
"int32", "VTAUopLoopBegin",
extent, dst_coeff[idx], src_coeff[idx], 0))
use_imm = int(use_imm)
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
1, 0,
......@@ -804,5 +811,6 @@ def debug_print(stmt):
stmt : Stmt
The statement.
"""
# pylint: disable=superfluous-parens
print(stmt)
return stmt
......@@ -24,8 +24,7 @@ def reconfig_runtime(remote):
"VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"]
cflags = ["-DVTA_%s_TARGET" % env.target.upper()]
cflags = []
for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
......
"""Testing utilities, this namespace is not imported by default."""
from . util import run
"""Utilities to start simulator."""
import os
import ctypes
import json
import tvm
def _load_lib():
"""Load local library, assuming they are simulator."""
# pylint: disable=unused-variable
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
]
runtime_dll = []
if not all(os.path.exists(f) for f in dll_path):
return []
try:
for fname in dll_path:
runtime_dll.append(ctypes.CDLL(fname, ctypes.RTLD_GLOBAL))
return runtime_dll
except OSError:
return []
def enabled():
"""Check if simulator is enabled."""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
return f is not None
def clear_stats():
"""Clear profiler statistics"""
f = tvm.get_global_func("vta.simulator.profiler_clear", True)
if f:
f()
def stats():
"""Clear profiler statistics
Returns
-------
stats : dict
Current profiler statistics
"""
x = tvm.get_global_func("vta.simulator.profiler_status")()
return json.loads(x)
LIBS = _load_lib()
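A minimal usage sketch of this module (assuming libvta.so and libvta_runtime.so were built with TARGET=sim and are present under lib/):

    from vta.testing import simulator

    if simulator.enabled():
        simulator.clear_stats()
        # ... run a module built for the "ext_dev" device here ...
        print(simulator.stats())
        # -> {"inp_load_nbytes": ..., "wgt_load_nbytes": ..., "acc_load_nbytes": ...,
        #     "uop_load_nbytes": ..., "out_store_nbytes": ..., "gemm_counter": ..., "alu_counter": ...}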
"""Test Utilities"""
from __future__ import absolute_import as _abs
import os
from tvm.contrib import rpc
from ..environment import get_env
from . import simulator
def run(run_func):
"""Run test function on all available env.
Parameters
----------
run_func : function(env, remote)
"""
env = get_env()
# run on simulator
if simulator.enabled():
env.TARGET = "sim"
run_func(env, rpc.LocalSession())
# Run on PYNQ if env variable exists
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
if pynq_host:
env.TARGET = "pynq"
port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
port = int(port)
remote = rpc.connect(pynq_host, port)
run_func(env, remote)
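A sketch of how a unit test plugs into run(); the test body here is a placeholder:

    import vta.testing

    def _my_check(env, remote):
        # build with env.target_host, then upload and execute through remote
        pass

    vta.testing.run(_my_check)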
......@@ -57,7 +57,7 @@ struct DataBuffer {
assert(data != nullptr);
DataBuffer* buffer = new DataBuffer();
buffer->data_ = data;
buffer->phy_addr_ = VTAGetMemPhysAddr(data);
buffer->phy_addr_ = VTAMemGetPhyAddr(data);
return buffer;
}
/*!
......
/*!
* Copyright (c) 2018 by Contributors
* \file vta_pynq_driver.c
* \file pynq_driver.c
* \brief VTA driver for Pynq board.
*/
......@@ -17,7 +17,7 @@ void VTAMemFree(void* buf) {
cma_free(buf);
}
vta_phy_addr_t VTAGetMemPhysAddr(void* buf) {
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
return cma_get_phy_addr(buf);
}
......
......@@ -11,6 +11,7 @@
#include <vta/driver.h>
#include <vta/hw_spec.h>
#include <vta/runtime.h>
#include <dmlc/logging.h>
#include <cassert>
#include <vector>
......@@ -113,7 +114,7 @@ class UopKernel {
uint32_t wgt_index,
uint32_t opcode,
uint32_t use_imm,
uint32_t imm_val) {
int32_t imm_val) {
// The loop nest structure
VerifyDep(dst_index);
VTAUop op;
......@@ -166,7 +167,7 @@ class UopKernel {
uint32_t opcode_{0xFFFFFFFF};
uint32_t reset_out_{0xFFFFFFFF};
bool use_imm_{false};
uint16_t imm_val_{0};
int16_t imm_val_{0};
private:
// Verify that we don't write to the same acc_mem index two cycles in a row
......@@ -195,10 +196,6 @@ class UopKernel {
/*!
* \brief Base class of all queues to send and recv serial data.
* \param kElemBytes Element unit bytes.
* \param kMaxBytes Maximum number of bytes.
* \param kCoherent Whether we have coherent access to the buffer.
* \param kAlwaysCache Whether we should use cached memory.
*/
class BaseQueue {
public:
......@@ -227,7 +224,7 @@ class BaseQueue {
dram_buffer_ = static_cast<char*>(VTAMemAlloc(
max_bytes, coherent || always_cache_));
assert(dram_buffer_ != nullptr);
dram_phy_addr_ = VTAGetMemPhysAddr(dram_buffer_);
dram_phy_addr_ = VTAMemGetPhyAddr(dram_buffer_);
}
/*!
* \brief Reset the pointer of the buffer.
......@@ -597,14 +594,14 @@ class InsnQueue : public BaseQueue {
}
// Print instruction field information
if (c.mem.opcode == VTA_OPCODE_LOAD) {
printf("LOAD ");
if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
printf("LOAD ");
if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
}
if (c.mem.opcode == VTA_OPCODE_STORE) {
printf("STORE\n");
printf("STORE:\n");
}
printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
static_cast<int>(c.mem.pop_prev_dep),
......@@ -1210,7 +1207,7 @@ void VTAUopPush(uint32_t mode,
uint32_t wgt_index,
uint32_t opcode,
uint32_t use_imm,
uint32_t imm_val) {
int32_t imm_val) {
vta::CommandQueue::ThreadLocal()->record_kernel()
->Push(mode, reset_out, dst_index, src_index,
wgt_index, opcode, use_imm, imm_val);
......
/*!
* Copyright (c) 2018 by Contributors
* \file sim_driver.cc
* \brief VTA driver for simulated backend.
*/
#include <vta/driver.h>
#include <vta/hw_spec.h>
#include <tvm/runtime/registry.h>
#include <type_traits>
#include <mutex>
#include <map>
#include <unordered_map>
#include <cstring>
#include <sstream>
namespace vta {
namespace sim {
/*!
* \brief Helper class to pack and unpack bits
* Applies truncation when packing to low-level bits.
*
* \tparam bits The number of bits in integer.
* \note This implementation relies on little endian.
*/
template<uint32_t bits>
class BitPacker {
public:
explicit BitPacker(void* data) {
data_ = static_cast<uint32_t*>(data);
}
uint32_t GetUnsigned(uint32_t index) const {
if (bits == 32) {
return data_[index];
} else if (bits == 16) {
return reinterpret_cast<uint16_t*>(data_)[index];
} else if (bits == 8) {
return reinterpret_cast<uint8_t*>(data_)[index];
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = index % kNumPackElem;
return (data_[offset] >> shift) & kMask;
}
}
int32_t GetSigned(uint32_t index) const {
if (bits == 32) {
return reinterpret_cast<int32_t*>(data_)[index];
} else if (bits == 16) {
return reinterpret_cast<int16_t*>(data_)[index];
} else if (bits == 8) {
return reinterpret_cast<int8_t*>(data_)[index];
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
int32_t uvalue = static_cast<int32_t>(
(data_[offset] >> shift) & kMask);
int kleft = 32 - bits;
return (uvalue << kleft) >> kleft;
}
}
void SetUnsigned(uint32_t index, uint32_t value) {
if (bits == 32) {
data_[index] = value;
} else if (bits == 16) {
reinterpret_cast<uint16_t*>(data_)[index] = value;
} else if (bits == 8) {
reinterpret_cast<uint8_t*>(data_)[index] = value;
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
data_[offset] &= (~(kMask << shift));
data_[offset] |= (value & kMask) << shift;
}
}
void SetSigned(uint32_t index, int32_t value) {
if (bits == 32) {
reinterpret_cast<int32_t*>(data_)[index] = value;
} else if (bits == 16) {
reinterpret_cast<int16_t*>(data_)[index] = value;
} else if (bits == 8) {
reinterpret_cast<int8_t*>(data_)[index] = value;
} else {
uint32_t offset = index / kNumPackElem;
uint32_t shift = (index % kNumPackElem) * bits;
data_[offset] &= (~(kMask << shift));
data_[offset] |= static_cast<uint32_t>(value & kMask) << shift;
}
}
private:
uint32_t* data_;
static constexpr uint32_t kNumPackElem = 32 / bits;
static constexpr uint32_t kMask = (1U << (bits >= 32U ? 31U : bits)) - 1U;
};
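The sub-word GetSigned path above sign-extends with a shift pair; a Python sketch of the same rule for an assumed 4-bit field:

    def sign_extend(raw, bits=4):
        mask = (1 << bits) - 1
        value = raw & mask
        return value - (1 << bits) if value & (1 << (bits - 1)) else value

    assert sign_extend(0b1010) == -6   # 0b1010 read as a signed 4-bit value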
/*!
* \brief DRAM memory manager
* Implements simple paging to allow physical address translation.
*/
class DRAM {
public:
/*!
* \brief Get virtual address given physical address.
* \param phy_addr The simulator physical address.
* \return The true virtual address.
*/
void* GetAddr(uint64_t phy_addr) {
std::lock_guard<std::mutex> lock(mutex_);
uint64_t loc = (phy_addr >> kPageBits) - 1;
CHECK_LT(loc, ptable_.size());
Page* p = ptable_[loc];
CHECK(p != nullptr);
size_t offset = (loc - p->ptable_begin) << kPageBits;
offset += phy_addr & (kPageSize - 1);
return reinterpret_cast<char*>(p->data) + offset;
}
/*!
* \brief Get physical address
* \param buf The virtual address.
* \return The true physical address.
*/
vta_phy_addr_t GetPhyAddr(void* buf) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = pmap_.find(buf);
CHECK(it != pmap_.end());
Page* p = it->second.get();
return (p->ptable_begin + 1) << kPageBits;
}
/*!
* \brief Allocate memory from manager
* \param size The size of memory
* \return The virtual address
*/
void* Alloc(size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
size_t npage = (size + kPageSize - 1) / kPageSize;
auto it = free_map_.lower_bound(npage);
if (it != free_map_.end()) {
Page* p = it->second;
free_map_.erase(it);
return p->data;
}
size_t start = ptable_.size();
std::unique_ptr<Page> p(new Page(start, npage));
// insert page entry
ptable_.resize(start + npage, p.get());
void* data = p->data;
pmap_[data] = std::move(p);
return data;
}
/*!
* \brief Free the memory.
* \param size The size of memory
* \return The virtual address
*/
void Free(void* data) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = pmap_.find(data);
CHECK(it != pmap_.end());
Page* p = it->second.get();
free_map_.insert(std::make_pair(p->num_pages, p));
}
static DRAM* Global() {
static DRAM inst;
return &inst;
}
private:
// The bits in page table
static constexpr vta_phy_addr_t kPageBits = 16;
// page size, also the maximum allocatable size (64 KB)
static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits;
/*! \brief A page in the DRAM */
struct Page {
/*! \brief Data Type */
using DType = typename std::aligned_storage<kPageSize, 256>::type;
/*! \brief Start location in page table */
size_t ptable_begin;
/*! \brief The total number of pages */
size_t num_pages;
/*! \brief Data */
DType* data{nullptr};
// construct a new page
explicit Page(size_t ptable_begin, size_t num_pages)
: ptable_begin(ptable_begin), num_pages(num_pages) {
data = new DType[num_pages];
}
~Page() {
delete [] data;
}
};
// Internal lock
std::mutex mutex_;
// Physical address -> page
std::vector<Page*> ptable_;
// virtual address -> page
std::unordered_map<void*, std::unique_ptr<Page> > pmap_;
// Free map
std::multimap<size_t, Page*> free_map_;
};
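A sketch of the paging arithmetic above, assuming the 16-bit page size: GetPhyAddr reports a buffer whose pages start at slot ptable_begin as (ptable_begin + 1) << 16, and GetAddr inverts that offset to recover the slot:

    kPageBits = 16
    ptable_begin = 3                              # hypothetical page-table slot
    phy_addr = (ptable_begin + 1) << kPageBits    # value returned by GetPhyAddr
    loc = (phy_addr >> kPageBits) - 1             # index recovered by GetAddr
    assert loc == ptable_begin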
/*!
* \brief Register file.
* \tparam kBits Number of bits of one value.
* \tparam kLane Number of lanes in one element.
* \tparam kMaxNumElem Maximum number of elements.
*/
template<int kBits, int kLane, int kMaxNumElem>
class SRAM {
public:
/*! \brief Bytes of single vector element */
static const int kElemBytes = (kBits * kLane + 7) / 8;
/*! \brief content data type */
using DType = typename std::aligned_storage<kElemBytes, kElemBytes>::type;
SRAM() {
data_ = new DType[kMaxNumElem];
}
~SRAM() {
delete [] data_;
}
// Get the i-th index
void* BeginPtr(uint32_t index) {
CHECK_LT(index, kMaxNumElem);
return &(data_[index]);
}
// Execute the load instruction on this SRAM
void Load(const VTAMemInsn* op, DRAM* dram, uint64_t* load_counter) {
load_counter[0] += (op->x_size * op->y_size) * kElemBytes;
DType* sram_ptr = data_ + op->sram_base;
uint8_t* dram_ptr = static_cast<uint8_t*>(dram->GetAddr(
op->dram_base * kElemBytes));
uint64_t xtotal = op->x_size + op->x_pad_0 + op->x_pad_1;
uint32_t ytotal = op->y_size + op->y_pad_0 + op->y_pad_1;
uint64_t sram_end = op->sram_base + xtotal * ytotal;
CHECK_LE(sram_end, kMaxNumElem);
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_0);
sram_ptr += xtotal * op->y_pad_0;
for (uint32_t y = 0; y < op->y_size; ++y) {
memset(sram_ptr, 0, kElemBytes * op->x_pad_0);
sram_ptr += op->x_pad_0;
memcpy(sram_ptr, dram_ptr, kElemBytes * op->x_size);
sram_ptr += op->x_size;
BitPacker<kBits> src(sram_ptr);
memset(sram_ptr, 0, kElemBytes * op->x_pad_1);
sram_ptr += op->x_pad_1;
dram_ptr += kElemBytes * op->x_stride;
}
memset(sram_ptr, 0, kElemBytes * xtotal * op->y_pad_1);
}
// Execute the store instruction on this SRAM, applying truncation.
// This relies on the elements being 32 bits.
template<int target_bits>
void TruncStore(const VTAMemInsn* op, DRAM* dram) {
CHECK_EQ(op->x_pad_0, 0);
CHECK_EQ(op->x_pad_1, 0);
CHECK_EQ(op->y_pad_0, 0);
CHECK_EQ(op->y_pad_1, 0);
int target_width = (target_bits * kLane + 7) / 8;
BitPacker<kBits> src(data_ + op->sram_base);
BitPacker<target_bits> dst(dram->GetAddr(op->dram_base * target_width));
for (uint32_t y = 0; y < op->y_size; ++y) {
for (uint32_t x = 0; x < op->x_size; ++x) {
uint32_t sram_base = y * op->x_size + x;
uint32_t dram_base = y * op->x_stride + x;
for (int i = 0; i < kLane; ++i) {
dst.SetSigned(dram_base * kLane + i,
src.GetSigned(sram_base * kLane + i));
}
}
}
}
private:
/*! \brief internal data content */
DType* data_;
};
/*!
* \brief Profiler that collects statistics of the simulated device run.
*/
class Profiler {
public:
/*! \brief The INP memory load statistics */
uint64_t inp_load_nbytes{0};
/*! \brief The WGT memory load statistics */
uint64_t wgt_load_nbytes{0};
/*! \brief The ACC memory load statistics */
uint64_t acc_load_nbytes{0};
/*! \brief The UOP memory load statistics */
uint64_t uop_load_nbytes{0};
/*! \brief The output store statistics */
uint64_t out_store_nbytes{0};
/*! \brief instr counter for gemm */
uint64_t gemm_counter{0};
/*! \brief instr counter for ALU ops */
uint64_t alu_counter{0};
/*! \brief clear the profiler */
void Clear() {
inp_load_nbytes = 0;
wgt_load_nbytes = 0;
acc_load_nbytes = 0;
uop_load_nbytes = 0;
out_store_nbytes = 0;
gemm_counter = 0;
alu_counter = 0;
}
std::string AsJSON() {
std::ostringstream os;
os << "{\n"
<< " \"inp_load_nbytes\":" << inp_load_nbytes << ",\n"
<< " \"wgt_load_nbytes\":" << wgt_load_nbytes << ",\n"
<< " \"acc_load_nbytes\":" << acc_load_nbytes << ",\n"
<< " \"uop_load_nbytes\":" << uop_load_nbytes << ",\n"
<< " \"out_store_nbytes\":" << out_store_nbytes << ",\n"
<< " \"gemm_counter\":" << gemm_counter << ",\n"
<< " \"alu_counter\":" << alu_counter << "\n"
<<"}\n";
return os.str();
}
static Profiler* ThreadLocal() {
static thread_local Profiler inst;
return &inst;
}
};
// Simulate device
// TODO(tqchen,thierry): queue based event driven simulation.
class Device {
public:
Device() {
prof_ = Profiler::ThreadLocal();
dram_ = DRAM::Global();
}
int Run(vta_phy_addr_t insn_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
VTAGenericInsn* insn = static_cast<VTAGenericInsn*>(
dram_->GetAddr(insn_phy_addr));
finish_counter_ = 0;
for (uint32_t i = 0; i < insn_count; ++i) {
this->Run(insn + i);
}
return 0;
}
private:
void Run(const VTAGenericInsn* insn) {
const VTAMemInsn* mem = reinterpret_cast<const VTAMemInsn*>(insn);
const VTAGemInsn* gem = reinterpret_cast<const VTAGemInsn*>(insn);
const VTAAluInsn* alu = reinterpret_cast<const VTAAluInsn*>(insn);
switch (mem->opcode) {
case VTA_OPCODE_LOAD: RunLoad(mem); break;
case VTA_OPCODE_STORE: RunStore(mem); break;
case VTA_OPCODE_GEMM: RunGEMM(gem); break;
case VTA_OPCODE_ALU: RunALU(alu); break;
case VTA_OPCODE_FINISH: ++finish_counter_; break;
default: {
LOG(FATAL) << "Unknown op_code" << mem->opcode;
}
}
}
void RunLoad(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_INP) {
inp_.Load(op, dram_, &(prof_->inp_load_nbytes));
} else if (op->memory_type == VTA_MEM_ID_WGT) {
wgt_.Load(op, dram_, &(prof_->wgt_load_nbytes));
} else if (op->memory_type == VTA_MEM_ID_ACC) {
acc_.Load(op, dram_, &(prof_->acc_load_nbytes));
} else if (op->memory_type == VTA_MEM_ID_UOP) {
uop_.Load(op, dram_, &(prof_->uop_load_nbytes));
} else {
LOG(FATAL) << "Unknown memory_type=" << op->memory_type;
}
}
void RunStore(const VTAMemInsn* op) {
if (op->memory_type == VTA_MEM_ID_ACC ||
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
op->x_size * op->y_size * VTA_BATCH * VTA_BLOCK_OUT * VTA_OUT_WIDTH / 8);
acc_.TruncStore<VTA_OUT_WIDTH>(op, dram_);
} else {
LOG(FATAL) << "Store do not support memory_type="
<< op->memory_type;
}
}
void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in;
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
// Read in memory indices
uint32_t acc_idx = uop_ptr->dst_idx;
uint32_t inp_idx = uop_ptr->src_idx;
uint32_t wgt_idx = uop_ptr->wgt_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
inp_idx += y * op->src_factor_out + x * op->src_factor_in;
wgt_idx += y * op->wgt_factor_out + x * op->wgt_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
BitPacker<VTA_INP_WIDTH> inp(inp_.BeginPtr(inp_idx));
BitPacker<VTA_WGT_WIDTH> wgt(wgt_.BeginPtr(wgt_idx));
// gemm loop
for (uint32_t i = 0; i < VTA_BATCH; ++i) {
for (uint32_t j = 0; j < VTA_BLOCK_OUT; ++j) {
uint32_t acc_offset = i * VTA_BLOCK_OUT + j;
int32_t sum = acc.GetSigned(acc_offset);
for (uint32_t k = 0; k < VTA_BLOCK_IN; ++k) {
sum +=
inp.GetSigned(i * VTA_BLOCK_IN + k) *
wgt.GetSigned(j * VTA_BLOCK_IN + k);
}
acc.SetSigned(acc_offset, sum);
}
}
}
}
}
} else {
// reset
for (uint32_t y = 0; y < op->iter_out; ++y) {
for (uint32_t x = 0; x < op->iter_in; ++x) {
for (uint32_t uindex = op->uop_bgn; uindex < op->uop_end; ++uindex) {
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(uindex));
uint32_t acc_idx = uop_ptr->dst_idx;
acc_idx += y * op->dst_factor_out + x * op->dst_factor_in;
BitPacker<VTA_ACC_WIDTH> acc(acc_.BeginPtr(acc_idx));
for (uint32_t i = 0; i < VTA_BATCH * VTA_BLOCK_OUT; ++i) {
acc.SetSigned(i, 0);
}
}
}
}
}
}
void RunALU(const VTAAluInsn* op) {
prof_->alu_counter += op->iter_out * op->iter_in;
if (op->use_imm) {
RunALU_<true>(op);
} else {
RunALU_<false>(op);
}
}
template<bool use_imm>
void RunALU_(const VTAAluInsn* op) {
switch (op->alu_opcode) {
case VTA_ALU_OPCODE_ADD: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
return x + y;
});
}
case VTA_ALU_OPCODE_MAX: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
return std::max(x, y);
});
}
case VTA_ALU_OPCODE_MIN: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
return std::min(x, y);
});
}
case VTA_ALU_OPCODE_SHR: {
return RunALULoop<use_imm>(op, [](int32_t x, int32_t y) {
if (y >= 0) {
return x >> y;
} else {
return x << (-y);
}
});
}
default: {
LOG(FATAL) << "Unknown ALU code " << op->alu_opcode;
}
}
}
template<bool use_imm, typename F>
void RunALULoop(const VTAAluInsn* op, F func) {
for (int y = 0; y < op->iter_out; ++y) {
for (int x = 0; x < op->iter_in; ++x) {
for (int k = op->uop_bgn; k < op->uop_end; ++k) {
// Read micro op
VTAUop* uop_ptr = static_cast<VTAUop*>(uop_.BeginPtr(k));
uint32_t dst_index = uop_ptr->dst_idx;
uint32_t src_index = uop_ptr->src_idx;
dst_index += y * op->dst_factor_out + x * op->dst_factor_in;
src_index += y * op->src_factor_out + x * op->src_factor_in;
BitPacker<VTA_ACC_WIDTH> dst(acc_.BeginPtr(dst_index));
BitPacker<VTA_ACC_WIDTH> src(acc_.BeginPtr(src_index));
for (int k = 0; k < VTA_BLOCK_OUT; ++k) {
if (use_imm) {
dst.SetSigned(k, func(dst.GetSigned(k), op->imm));
} else {
dst.SetSigned(k, func(dst.GetSigned(k), src.GetSigned(k)));
}
}
}
}
}
}
// the finish counter
int finish_counter_{0};
// The profiler
Profiler* prof_;
// The DRAM interface
DRAM* dram_;
// The SRAM
SRAM<VTA_INP_WIDTH, VTA_BATCH * VTA_BLOCK_IN, VTA_INP_BUFF_DEPTH> inp_;
SRAM<VTA_WGT_WIDTH, VTA_BLOCK_IN * VTA_BLOCK_OUT, VTA_WGT_BUFF_DEPTH> wgt_;
SRAM<VTA_ACC_WIDTH, VTA_BATCH * VTA_BLOCK_OUT, VTA_ACC_BUFF_DEPTH> acc_;
SRAM<VTA_UOP_WIDTH, 1, VTA_UOP_BUFF_DEPTH> uop_;
};
using tvm::runtime::TVMRetValue;
using tvm::runtime::TVMArgs;
TVM_REGISTER_GLOBAL("vta.simulator.profiler_clear")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Profiler::ThreadLocal()->Clear();
});
TVM_REGISTER_GLOBAL("vta.simulator.profiler_status")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Profiler::ThreadLocal()->AsJSON();
});
} // namespace sim
} // namespace vta
void* VTAMemAlloc(size_t size, int cached) {
return vta::sim::DRAM::Global()->Alloc(size);
}
void VTAMemFree(void* buf) {
vta::sim::DRAM::Global()->Free(buf);
}
vta_phy_addr_t VTAMemGetPhyAddr(void* buf) {
return vta::sim::DRAM::Global()->GetPhyAddr(buf);
}
void VTAFlushCache(vta_phy_addr_t buf, int size) {
}
void VTAInvalidateCache(vta_phy_addr_t buf, int size) {
}
VTADeviceHandle VTADeviceAlloc() {
return new vta::sim::Device();
}
void VTADeviceFree(VTADeviceHandle handle) {
delete static_cast<vta::sim::Device*>(handle);
}
int VTADeviceRun(VTADeviceHandle handle,
vta_phy_addr_t insn_phy_addr,
uint32_t insn_count,
uint32_t wait_cycles) {
return static_cast<vta::sim::Device*>(handle)->Run(
insn_phy_addr, insn_count, wait_cycles);
}
void VTAProgram(const char* bitstream) {
}
......@@ -67,9 +67,6 @@ class VTADeviceAPI final : public DeviceAPI {
std::make_shared<VTADeviceAPI>();
return inst;
}
private:
void* runtime_dll_{nullptr};
};
struct VTAWorkspacePool : public WorkspacePool {
......
"""Unit test VTA's instructions """
import tvm
import vta
import mxnet as mx
import numpy as np
import topi
from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
do_verify = True
print_ir = False
import vta
import vta.testing
from vta.testing import simulator
def test_save_load_out():
env = vta.get_env()
"""Test save/store output command"""
n = 4
x = tvm.placeholder(
(n, n, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=env.acc_dtype)
x_buf = tvm.compute(
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x(*i),
"x_buf")
# insert no-op that won't be optimized away
y_buf = tvm.compute(
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x_buf(*i)>>0,
"y_buf")
y = tvm.compute(
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: y_buf(*i).astype(env.inp_dtype),
"y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
def verify():
# build
with vta.build_config(env.DEBUG_DUMP_INSN):
m = vta.build(s, [x, y], "ext_dev", target)
def _run(env, remote):
n = 6
x = tvm.placeholder(
(n, n, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=env.acc_dtype)
x_buf = tvm.compute(
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x(*i), "x_buf")
# insert no-op that won't be optimized away
y_buf = tvm.compute(
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute(
(n, n, env.BATCH, env.BLOCK_OUT),
lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
# verification
with vta.build_config():
m = vta.build(s, [x, y], "ext_dev", env.target_host)
if not remote:
return
temp = util.tempdir()
remote = rpc.connect(host, port)
m.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
......@@ -59,47 +54,46 @@ def test_save_load_out():
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
vta.testing.run(_run)
def test_padded_load():
"""Test padded load."""
env = vta.get_env()
# declare
n = 21
m = 20
pad_before = [0, 1, 0, 0]
pad_after = [1, 3, 0, 0]
x = tvm.placeholder(
(n, m, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=env.acc_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
def _run(env, remote):
# declare
n = 21
m = 20
pad_before = [0, 1, 0, 0]
pad_after = [1, 3, 0, 0]
x = tvm.placeholder(
(n, m, env.BATCH, env.BLOCK_OUT),
name="x",
dtype=env.acc_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
env.BATCH,
env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
def verify():
env.BLOCK_OUT), lambda *i: y_buf(*i).astype(env.inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.acc_scope)
s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
s[y_buf].set_scope(env.acc_scope)
s[y_buf].pragma(y_buf.op.axis[0], env.alu)
s[y].pragma(y.op.axis[0], env.dma_copy)
# build
with vta.build_config(env.DEBUG_DUMP_INSN):
mod = vta.build(s, [x, y], "ext_dev", target)
with vta.build_config():
mod = vta.build(s, [x, y], "ext_dev", env.target_host)
if not remote:
return
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("padded_load.o"))
remote.upload(temp.relpath("padded_load.o"))
f = remote.load_module("padded_load.o")
......@@ -118,285 +112,287 @@ def test_padded_load():
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
if print_ir:
print(vta.lower(s, [y, x], simple_mode=True))
vta.testing.run(_run)
def test_gemm():
"""Test GEMM."""
env = vta.get_env()
# declare
o = 4
n = 4
m = 4
x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
ko = tvm.reduce_axis((0, n), name="ko")
ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
y_gem = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda bo, co, bi, ci:
def _run(env, remote):
# declare
o = 4
n = 1
m = 4
x = tvm.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
w = tvm.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
x_buf = tvm.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
w_buf = tvm.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
ko = tvm.reduce_axis((0, n), name="ko")
ki = tvm.reduce_axis((0, env.BLOCK_IN), name="ki")
y_gem = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda bo, co, bi, ci:
tvm.sum(x_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
w_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]),
name="y_gem")
y_shf = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_gem(*i)>>8,
name="y_shf")
y_max = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(y_shf(*i), 0),
"y_max") #relu
y_min = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
"y_min") #relu
y = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_min(*i).astype(env.inp_dtype),
name="y")
def verify(s):
mod = vta.build(s, [x, w, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(
-128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
w_np = np.random.randint(
-128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
w_nd = tvm.nd.array(w_np, ctx)
y_nd = tvm.nd.array(y_np, ctx)
y_np = y_np.astype(env.acc_dtype)
for b in range(o):
for i in range(m):
for j in range(n):
y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
w_np[i,j].T.astype(env.acc_dtype))
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
name="y_gem")
y_shf = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_gem(*i)>>8,
name="y_shf")
y_max = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(y_shf(*i), 0),
"y_max") #relu
y_min = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(y_max(*i), (1<<(env.INP_WIDTH-1))-1),
"y_min") #relu
y = tvm.compute(
(o, m, env.BATCH, env.BLOCK_OUT),
lambda *i: y_min(*i).astype(env.inp_dtype),
name="y")
def test_schedule1():
# default schedule with no smt
s = tvm.create_schedule(y.op)
# set the scope of the SRAM buffers
s[x_buf].set_scope(env.SCOPE_INP)
s[w_buf].set_scope(env.SCOPE_WGT)
s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(env.acc_scope)
# set pragmas for DMA transfer and ALU ops
s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[y].pragma(s[y].op.axis[0], env.dma_copy)
# tensorization
s[y_gem].reorder(
ko,
s[y_gem].op.axis[0],
s[y_gem].op.axis[1],
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
if print_ir:
print(vta.lower(s, [x, w, y], simple_mode=True))
if do_verify:
with vta.build_config(env.DEBUG_DUMP_INSN):
verify(s)
def test_smt():
# test smt schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.SCOPE_INP)
s[w_buf].set_scope(env.SCOPE_WGT)
s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(env.acc_scope)
abo, aco, abi, aci = s[y].op.axis
abo1, abo2 = s[y].split(abo, nparts=2)
s[y].bind(abo1, tvm.thread_axis("cthread"))
s[y_gem].compute_at(s[y], abo1)
s[y_shf].compute_at(s[y], abo1)
s[y_max].compute_at(s[y], abo1)
s[y_min].compute_at(s[y], abo1)
s[y_gem].reorder(
ko,
s[y_gem].op.axis[0],
s[y_gem].op.axis[1],
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.GEMM)
s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[x_buf].compute_at(s[y_gem], ko)
s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y].pragma(abo2, env.dma_copy)
if print_ir:
print(vta.lower(s, [x, y, w], simple_mode=True))
if do_verify:
with vta.build_config(env.DEBUG_DUMP_INSN):
verify(s)
test_schedule1()
test_smt()
def test_alu(tvm_op, np_op=None, use_imm=False):
"""Test ALU"""
env = vta.get_env()
m = 8
n = 8
imm = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="a",
dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") #DRAM->SRAM
if use_imm:
res_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), imm),
"res_buf") #compute
else:
b = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="b",
dtype=env.acc_dtype)
b_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: b(*i),
"b_buf") #DRAM->SRAM
res_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
"res_buf") #compute
res = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_buf(*i).astype(env.inp_dtype),
"res") #SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_buf].set_scope(env.acc_scope) # SRAM
s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if use_imm:
if print_ir:
print(vta.lower(s, [a, res], simple_mode=True))
else:
s[b_buf].set_scope(env.acc_scope) # SRAM
s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
if print_ir:
print(vta.lower(s, [a, b, res], simple_mode=True))
def verify():
# build
with vta.build_config():
if not remote:
return
def verify(s):
mod = vta.build(s, [x, w, y], "ext_dev", env.target_host)
temp = util.tempdir()
mod.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(
-128, 128, size=(o, n, env.BATCH, env.BLOCK_IN)).astype(x.dtype)
w_np = np.random.randint(
-128, 128, size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)).astype(w.dtype)
y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
w_nd = tvm.nd.array(w_np, ctx)
y_nd = tvm.nd.array(y_np, ctx)
y_np = y_np.astype(env.acc_dtype)
for b in range(o):
for i in range(m):
for j in range(n):
y_np[b,i,:] += np.dot(x_np[b,j,:].astype(env.acc_dtype),
w_np[i,j].T.astype(env.acc_dtype))
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(env.INP_WIDTH-1))-1).astype(y.dtype)
if env.TARGET == "sim":
simulator.clear_stats()
f(x_nd, w_nd, y_nd)
print(simulator.stats())
else:
f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
def test_schedule1():
# default schedule with no smt
s = tvm.create_schedule(y.op)
# set the scope of the SRAM buffers
s[x_buf].set_scope(env.inp_scope)
s[w_buf].set_scope(env.wgt_scope)
s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(env.acc_scope)
# set pragmas for DMA transfer and ALU ops
s[x_buf].compute_at(s[y_gem], ko)
s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[y].pragma(s[y].op.axis[0], env.dma_copy)
# tensorization
s[y_gem].reorder(
ko,
s[y_gem].op.axis[0],
s[y_gem].op.axis[1],
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
verify(s)
def test_smt():
# test smt schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(env.inp_scope)
s[w_buf].set_scope(env.wgt_scope)
s[y_gem].set_scope(env.acc_scope)
s[y_shf].set_scope(env.acc_scope)
s[y_max].set_scope(env.acc_scope)
s[y_min].set_scope(env.acc_scope)
abo, aco, abi, aci = s[y].op.axis
abo1, abo2 = s[y].split(abo, nparts=2)
s[y].bind(abo1, tvm.thread_axis("cthread"))
s[y_gem].compute_at(s[y], abo1)
s[y_shf].compute_at(s[y], abo1)
s[y_max].compute_at(s[y], abo1)
s[y_min].compute_at(s[y], abo1)
s[y_gem].reorder(
ko,
s[y_gem].op.axis[0],
s[y_gem].op.axis[1],
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], env.gemm)
s[y_shf].pragma(s[y_shf].op.axis[0], env.alu)
s[y_max].pragma(s[y_max].op.axis[0], env.alu)
s[y_min].pragma(s[y_min].op.axis[0], env.alu)
s[x_buf].compute_at(s[y_gem], ko)
s[x_buf].pragma(s[x_buf].op.axis[0], env.dma_copy)
s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], env.dma_copy)
s[y].pragma(abo2, env.dma_copy)
verify(s)
test_schedule1()
test_smt()
vta.testing.run(_run)
def test_alu():
def _run(env, remote):
def check_alu(tvm_op, np_op=None, use_imm=False):
"""Test ALU"""
m = 8
n = 8
imm = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="a",
dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") #DRAM->SRAM
if use_imm:
mod = vta.build(s, [a, res], "ext_dev", target)
res_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), imm),
"res_buf") #compute
else:
mod = vta.build(s, [a, b, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
if use_imm:
res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
else:
b_np = np.random.randint(
-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if use_imm:
f(a_nd, res_nd)
else:
b_nd = tvm.nd.array(b_np, ctx)
f(a_nd, b_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
b = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="b",
dtype=env.acc_dtype)
b_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: b(*i),
"b_buf") #DRAM->SRAM
res_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
"res_buf") #compute5B
res = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_buf(*i).astype(env.inp_dtype),
"res") #SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_buf].set_scope(env.acc_scope) # SRAM
s[res_buf].pragma(res_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if not use_imm:
s[b_buf].set_scope(env.acc_scope) # SRAM
s[b_buf].pragma(b_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
if not remote:
return
# build
with vta.build_config():
if use_imm:
mod = vta.build(s, [a, res], "ext_dev", env.target_host)
else:
mod = vta.build(s, [a, b, res], "ext_dev", env.target_host)
temp = util.tempdir()
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(a.dtype)
if use_imm:
res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
else:
b_np = np.random.randint(
-16, 16, size=(m, n, env.BATCH, env.BLOCK_OUT)).astype(b.dtype)
res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
if use_imm:
f(a_nd, res_nd)
else:
b_nd = tvm.nd.array(b_np, ctx)
f(a_nd, b_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
check_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
check_alu(tvm.max, np.maximum, use_imm=True)
check_alu(tvm.max, np.maximum)
check_alu(lambda x, y: x + y, use_imm=True)
check_alu(lambda x, y: x + y)
check_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
vta.testing.run(_run)
if do_verify:
verify()
def test_relu():
"""Test RELU on ALU"""
env = vta.get_env()
m = 8
n = 8
# compute
a = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="a",
dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
max_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(a_buf(*i), 0),
"res_buf") # relu
min_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
"max_buf") # relu
res = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: min_buf(*i).astype(env.inp_dtype),
"min_buf") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[max_buf].set_scope(env.acc_scope) # SRAM
s[min_buf].set_scope(env.acc_scope) # SRAM
s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if print_ir:
print(vta.lower(s, [a, res], simple_mode=True))
def verify():
def _run(env, remote):
m = 8
n = 10
# compute
a = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="a",
dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
max_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.max(a_buf(*i), 0),
"res_buf") # relu
min_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: tvm.min(max_buf(*i), (1<<(env.INP_WIDTH-1))-1),
"max_buf") # relu
res = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: min_buf(*i).astype(env.inp_dtype),
"min_buf") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[max_buf].set_scope(env.acc_scope) # SRAM
s[min_buf].set_scope(env.acc_scope) # SRAM
s[max_buf].pragma(max_buf.op.axis[0], env.alu) # compute
s[min_buf].pragma(min_buf.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
# build
with vta.build_config(env.DEBUG_DUMP_INSN):
mod = vta.build(s, [a, res], "ext_dev", target)
with vta.build_config():
mod = vta.build(s, [a, res], "ext_dev", env.target_host)
if not remote:
return
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
......@@ -410,56 +406,51 @@ def test_relu():
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
vta.testing.run(_run)
def test_shift_and_scale():
"""Test shift and scale on ALU"""
env = vta.get_env()
m = 8
n = 8
imm_shift = np.random.randint(-10,10)
imm_scale = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="a", dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
res_shift = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a_buf(*i)+imm_shift,
"res_shift") # compute
res_scale = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_shift(*i)>>imm_scale,
"res_scale") # compute
res = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_scale(*i).astype(env.inp_dtype),
"res") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(env.acc_scope) # SRAM
s[res_shift].set_scope(env.acc_scope) # SRAM
s[res_scale].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
if print_ir:
print(vta.lower(s, [a, res], simple_mode=True))
def verify():
def _run(env, remote):
m = 2
n = 8
imm_shift = np.random.randint(0,8)
imm_scale = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, env.BATCH, env.BLOCK_OUT),
name="a", dtype=env.acc_dtype)
a_buf = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
res_shift = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: a_buf(*i)+imm_shift,
"res_shift") # compute
res_scale = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_shift(*i)>>imm_scale,
"res_scale") # compute
res = tvm.compute(
(m, n, env.BATCH, env.BLOCK_OUT),
lambda *i: res_scale(*i).astype(env.inp_dtype),
"res") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(env.acc_scope) # SRAM
s[res_shift].set_scope(env.acc_scope) # SRAM
s[res_scale].set_scope(env.acc_scope) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy) # DRAM->SRAM
s[res_shift].pragma(res_shift.op.axis[0], env.alu) # compute
s[res_scale].pragma(res_scale.op.axis[0], env.alu) # compute
s[res].pragma(res.op.axis[0], env.dma_copy) # SRAM->DRAM
# build
mod = vta.build(s, [a, res], "ext_dev", target)
mod = vta.build(s, [a, res], "ext_dev", env.target_host)
if not remote:
return
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
......@@ -474,31 +465,18 @@ def test_shift_and_scale():
np.zeros((m, n, env.BATCH, env.BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
vta.testing.run(_run)
if __name__ == "__main__":
print("Load/store test")
test_save_load_out()
print("Padded load test")
test_padded_load()
# print("GEMM test")
# test_gemm()
print("Max immediate")
test_alu(tvm.max, np.maximum, use_imm=True)
print("Max")
test_alu(tvm.max, np.maximum)
print("Add immediate")
test_alu(lambda x, y: x + y, use_imm=True)
print("Add")
test_alu(lambda x, y: x + y)
print("Shift right immediate")
test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
print("Shift left immediate")
test_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
print("Relu")
#test_padded_load()
print("GEMM test")
test_gemm()
test_alu()
print("ALU test")
test_relu()
# print("Shift and scale")
# test_shift_and_scale()
print("Shift and scale")
test_shift_and_scale()