Commit 96488c11 by Thierry Moreau, committed by Tianqi Chen

[PYTHON, TVM] Python TVM library, unit tests and end to end example

* VTA python library
* Python unit tests
* End to end example with Resnet18
* README instructions
* Bug fixes
parent 56a0dea8
@@ -55,10 +55,10 @@ endif
all: lib/libvta.$(SHARED_LIBRARY_SUFFIX)
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
ifeq ($(TARGET), PYNQ_TARGET)
ifeq ($(TARGET), VTA_PYNQ_TARGET)
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
LDFLAGS += -L/usr/lib -lsds_lib
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ -l:libdma.so
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/ -l:libdma.so
endif
VTA_LIB_OBJ = $(patsubst %.cc, build/%.o, $(VTA_LIB_SRC))
@@ -79,7 +79,7 @@ cpplint:
python nnvm/dmlc-core/scripts/lint.py vta cpp include src hardware tests
pylint:
pylint python/vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
pylint python/tvm_vta --rcfile=$(ROOTDIR)/tests/lint/pylintrc
doc:
doxygen docs/Doxyfile
# PYNQ RPC Server for VTA
This guide describes how to set up a Pynq-based RPC server to accelerate deep learning workloads with VTA.
## Pynq Setup
Follow the getting started tutorial for the [Pynq board](http://pynq.readthedocs.io/en/latest/getting_started.html).
* For this RPC setup make sure to go with the *Connect to a Computer* Ethernet setup.
Make sure that you can ssh into your Pynq board successfully:
```bash
ssh xilinx@192.168.2.99
```
When you ssh onto the board, the default password for the `xilinx` account is `xilinx`.
For convenience, let's mount the Pynq board's file system so it's easy to access and maintain:
```bash
sshfs xilinx@192.168.2.99:/home/xilinx <mountpoint>
```
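When you're done, you can unmount the board's file system (a convenience note; `fusermount -u` is the standard FUSE unmount command on Linux):
```bash
fusermount -u <mountpoint>
```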
## Pynq TVM & VTA installation
On your **host PC**, go to the `<mountpoint>` directory of your Pynq board file system.
```bash
cd <mountpoint>
```
From there, clone the VTA repository:
```bash
git clone git@github.com:uwsaml/vta.git --recursive
```
Next, clone the TVM repository:
```bash
git clone git@github.com:dmlc/tvm.git --recursive
```
TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints.
As of now, the TVM checkpoint `e4c2af9abdcb3c7aabafba8084414d7739c17c4c` is known to work with VTA.
```bash
cd tvm
git checkout e4c2af9abdcb3c7aabafba8084414d7739c17c4c
```
Now, ssh into your **Pynq board** to build the TVM runtime with the following commands:
```bash
ssh xilinx@192.168.2.99 # ssh if you haven't done so
cd ~/tvm
cp make/config.mk .
echo USE_RPC=1 >> config.mk
make runtime -j2
```
## Pynq RPC server setup
We're now ready to build the Pynq RPC server on the Pynq board.
```bash
ssh xilinx@192.168.2.99 # ssh if you haven't done so
cd ~/vta
export TVM_PATH=/home/xilinx/tvm
make
```
The last stage will build the `192.168.2.99:/home/xilinx/vta/lib/libvta.so` library file. We are now ready to launch the RPC server on the Pynq. In order to enable the FPGA drivers, we need to run the RPC server with administrator privileges (using `su`, account: `xilinx`, password: `xilinx`).
```bash
ssh xilinx@192.168.2.99 # ssh if you haven't done so
cd ~/vta
su
./apps/pynq_rpc/start_rpc_server.sh
```
You should see the following being displayed when starting the RPC server:
```
INFO:root:Load additional library /home/xilinx/vta/lib/libvta.so
INFO:root:RPCServer: bind to 0.0.0.0:9091
```
Note that it should be listening on port `9091`.
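To sanity-check the connection from your **host PC**, you can try opening an RPC session with TVM's Python client (a quick check, assuming your `PYTHONPATH` includes TVM and the board is reachable at `192.168.2.99`):
```bash
python -c "from tvm.contrib import rpc; rpc.connect('192.168.2.99', 9091); print('RPC connection OK')"
```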
To kill the RPC server, just hit `Ctrl + C`.
\ No newline at end of file
#!/bin/bash
export PYTHONPATH=${PYTHONPATH}:/home/xilinx/tvm/python
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
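# --load-library preloads the VTA runtime so that RPC clients can call its
# packed functions (e.g. tvm.contrib.vta.init, used to program the FPGA)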
python -m tvm.exec.rpc_server --load-library /home/xilinx/vta/lib/libvta.so
quantize_graph.json
quantize_params.pkl
synset.txt
*.jpg
vta.bit
\ No newline at end of file
# Resnet-18 Example on Pynq-based VTA Design
In order to run this example you'll need to have:
* VTA installed
* TVM installed
* NNVM installed
* A Pynq-based RPC server running
## VTA installation
Clone the VTA repository in the directory of your choosing:
```bash
git clone git@github.com:uwsaml/vta.git --recursive
```
Update your `~/.bashrc` file to include the VTA python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!):
```bash
export PYTHONPATH=<vta root>/python:${PYTHONPATH}
```
## TVM installation
Clone the TVM repository in the directory of your choosing:
```bash
git clone git@github.com:dmlc/tvm.git --recursive
```
TVM is rapidly changing, and to ensure stability, we keep track of working TVM checkpoints.
As of now, the TVM checkpoint `e4c2af9abdcb3c7aabafba8084414d7739c17c4c` is known to work with VTA.
```bash
cd <tvm root>
git checkout e4c2af9abdcb3c7aabafba8084414d7739c17c4c
```
Before building TVM, copy the `make/config.mk` file into the root TVM directory:
```bash
cd <tvm root>
cp make/config.mk .
```
In the `config.mk` file, make sure that (see the example below):
* `LLVM_CONFIG` points to the `llvm-config` executable (e.g. `LLVM_CONFIG = /usr/bin/llvm-config-4.0`). You'll need to have LLVM 4.0 or later installed.
* `USE_RPC` should be set to 1
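For example, you can apply both settings from the shell (a sketch; adjust the `llvm-config` path to match your installation):
```bash
cd <tvm root>
echo 'LLVM_CONFIG = /usr/bin/llvm-config-4.0' >> config.mk
echo 'USE_RPC = 1' >> config.mk
```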
Launch the compilation; this takes about 5 minutes.
```bash
cd <tvm root>
make -j4
```
Finally update your `~/.bashrc` file to include the TVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!):
```bash
export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:${PYTHONPATH}
```
## NNVM installation
Clone the NNVM repository from `tqchen` in the directory of your choosing:
```bash
git clone git@github.com:tqchen/nnvm.git --recursive
```
To run this example, we rely on a special branch of NNVM: `qt`:
```bash
cd <nnvm root>
git checkout qt
```
Launch the compilation; this takes less than a minute.
```bash
cd <nnvm root>
make -j4
```
Finally update your `~/.bashrc` file to include the NNVM python libraries in your `PYTHONPATH` (don't forget to source the newly modified `.bashrc` file!):
```bash
export PYTHONPATH=<nnvm root>/python:${PYTHONPATH}
```
## Pynq RPC Server Setup
Follow the [Pynq RPC Server Guide](https://github.com/uwsaml/vta/tree/master/apps/pynq_rpc/README.md).
## Running the example
Simply run the following python script:
```bash
python imagenet_predict.py
```
This runs ImageNet classification using the ResNet-18 architecture on a VTA design that performs 8-bit integer inference, classifying the cat image `cat.jpg`.
The script reports runtime measured on the Pynq board, and the top-1 result category:
```
('x', (1, 3, 224, 224))
Build complete...
('TVM prediction top-1:', 281, 'tabby, tabby cat')
t-cost=0.41906
```
\ No newline at end of file
# some standard imports
import nnvm
import tvm
from nnvm.compiler import graph_attr
import vta
import os
import numpy as np
from PIL import Image
import pickle
import json
import logging
import wget
from tvm.contrib import graph_runtime, rpc, util
factor = 16
host = "pynq"
port = 9091
verbose = False
# only run fpga component, mark non-conv ops as nop
debug_fpga_only = False
# Obtain model and hardware files (they're too large to check in)
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
TEST_FILE = 'cat.jpg'
CATEG_FILE = 'synset.txt'
RESNET_GRAPH_FILE = 'quantize_graph.json'
RESNET_PARAMS_FILE = 'quantize_params.pkl'
BITSTREAM_FILE = 'vta.bit'
for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE, BITSTREAM_FILE]:
if not os.path.isfile(file):
print "Downloading {}".format(file)
wget.download(url+file)
# Program the FPGA remotely
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
remote.upload(BITSTREAM_FILE, BITSTREAM_FILE)
fprogram = remote.get_function("tvm.contrib.vta.init")
fprogram(BITSTREAM_FILE)
if verbose:
logging.basicConfig(level=logging.INFO)
# Change to -device=tcpu to run CPU-only inference.
target = "llvm -device=vta"
synset = eval(open(os.path.join(CATEG_FILE)).read())
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
def transform_image(image):
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image
def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
"""Helper function to mark certain op as nop
Useful to debug performance issues.
"""
jgraph = json.loads(graph.json())
counter = 0
for nid, node in enumerate(jgraph["nodes"]):
op_name = node["op"]
if op_name != "tvm_op":
continue
attrs = node["attrs"]
node_name = node["name"]
func_name = attrs["func_name"]
if func_name.find("quantized_conv2d") != -1:
if conv_layer >= 0:
if counter != conv_layer:
attrs["func_name"] = "__nop"
if counter in skip_conv_layer:
attrs["func_name"] = "__nop"
counter += 1
else:
if conv_layer >= 0:
attrs["func_name"] = "__nop"
attrs["func_name"] = "__nop"
if attrs["func_name"] != "__nop":
print("Run function %s"% func_name)
graph = nnvm.graph.load_json(json.dumps(jgraph))
return graph
x = transform_image(image)
print('x', x.shape)
######################################################################
# now compile the graph
import nnvm.compiler
np.random.seed(0)
sym = nnvm.graph.load_json(
open(os.path.join(RESNET_GRAPH_FILE)).read())
params = pickle.load(
open(os.path.join(RESNET_PARAMS_FILE), "rb"))  # binary mode for pickle
shape_dict = {"data": x.shape}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
graph = nnvm.graph.create(sym)
graph_attr.set_shape_inputs(sym, shape_dict)
graph_attr.set_dtype_inputs(sym, dtype_dict)
graph = graph.apply("InferShape").apply("InferType")
dtype = "float32"
sym = vta.graph.remove_stochastic(sym)
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
if "vta" in target:
sym = vta.graph.pack(sym, shape_dict, factor)
graph_attr.set_shape_inputs(sym, shape_dict)
sym = sym.apply("InferShape")
graph_attr.set_dtype_inputs(sym, dtype_dict)
sym = sym.apply("InferType")
with nnvm.compiler.build_config(opt_level=3):
bdict = {}
if "vta" not in target:
bdict = {"add_lower_pass": []}
else:
bdict = {"add_lower_pass": vta.debug_mode(0)}
with tvm.build_config(**bdict):
graph, lib, params = nnvm.compiler.build(
sym, target, shape_dict, dtype_dict,
params=params)
remote = rpc.connect(host, port)
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0)
print("Build complete...")
def run_e2e(graph):
"""Running end to end example
"""
if debug_fpga_only:
graph = mark_nop(graph, skip_conv_layer=(0,))
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('data', tvm.nd.array(x.astype("float32")))
m.set_input(**params)
# execute
timer = m.module.time_evaluator("run", ctx, number=10)
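# time_evaluator runs "run" 10 times; tcost.mean reports the per-run average in seconds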
tcost = timer()
# get outputs
tvm_output = m.get_output(
0, tvm.nd.empty((1000,), dtype, remote.cpu(0)))
top1 = np.argmax(tvm_output.asnumpy())
print('TVM prediction top-1:', top1, synset[top1])
print("t-cost=%g" % tcost.mean)
def run_layer(old_graph):
"""Run a certain layer."""
for layer_id in range(1, 2):
graph = mark_nop(old_graph, layer_id)
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('data', tvm.nd.array(x.astype("float32")))
m.set_input(**params)
# execute
timer = m.module.time_evaluator("run", ctx, number=10)
tcost = timer()
print("resnet[%d]: %g\n"% (layer_id, tcost.mean))
run_e2e(graph)
# Directories
ROOTDIR = $(CURDIR)
BUILD_DIR = $(ROOTDIR)/../../build/hardware/vivado
BUILD_DIR = $(ROOTDIR)/../../build/hardware/xilinx
SCRIPT_DIR = $(ROOTDIR)/scripts
SRC_DIR = $(ROOTDIR)/src
SIM_DIR = $(ROOTDIR)/sim
@@ -64,7 +64,7 @@ bit: ip
cd $(HW_BUILD_PATH) && \
$(VIVADO) -mode tcl -source $(SCRIPT_DIR)/vivado.tcl \
-tclargs $(IP_BUILD_PATH) $(VTA_HW_COMP_THREADS) $(VTA_HW_COMP_CLOCK_FREQ) \
$(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(OUT_WIDTH) \
$(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \
$(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \
$(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE)
# Hardware Compilation Guide
**This guide explains how to generate VTA bitstreams with the Xilinx Vivado toolchains.**
As of writing, we recommend using `Vivado 2017.1`, since our scripts have been tested on this version of the Xilinx toolchains.
# Vivado Toolchains Installation for Pynq Board
## Ubuntu instructions
You’ll need to install Xilinx’s FPGA compilation toolchain, [Vivado HL WebPACK 2017.1](https://www.xilinx.com/products/design-tools/vivado.html), which is a license-free version of the Vivado HLx toolchain.
### Obtaining and launching the installation binary
1. Go to the [download webpage](https://www.xilinx.com/support/download.html), and download the Linux Self Extracting Web Installer for Vivado HL 2017.1 WebPACK and Editions.
2. You’ll have to sign in with a Xilinx account. Creating a Xilinx account takes about 2 minutes if you don’t already have one.
3. Complete the Name and Address Verification by clicking “Next”, and you will get the opportunity to download a binary file, called `Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin`.
4. Now that the file is downloaded, go to your `Downloads` directory, and change the file permissions so it can be executed:
```bash
chmod u+x Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin
```
5. Now you can execute the binary:
```bash
./Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin
```
### Installation Steps
At this point you've launched the Vivado 2017.1 Installer GUI program.
1. Click “Next” on the **Welcome** screen.
2. Enter your Xilinx user credentials under “User Authentication” and select “Download and Install Now” before clicking “Next” on the **Select Install Type** screen.
3. Accept all terms before clicking on “Next” on the **Accept License Agreements** screen.
4. Select the “Vivado HL WebPACK” before clicking on “Next” on the **Select Edition to Install** screen.
5. Under the **Vivado HL WebPACK** screen, before hitting “Next”, check the following options (the rest should be unchecked):
* Design Tools -> Vivado Design Suite -> Vivado
* Design Tools -> Vivado Design Suite -> Vivado High Level Synthesis
* Devices -> Production Services -> SoCs -> Zynq-7000 Series
6. Your total download size should be about 3GB, and the required disk space about 13GB.
7. Set the installation directory before clicking “Next” on the **Select Destination Directory** screen. It might highlight some paths in red - that’s because the installer doesn’t have permission to write to those directories. In that case select a path that doesn’t require special write permissions (e.g. in your home directory).
8. Hit “Install” under the **Installation Summary** screen.
9. An **Installation Progress Window** will pop-up to track progress of the download and the installation.
10. This process will take about 20-30 minutes depending on your connection speed.
11. A pop-up window will inform you that the installation completed successfully. Click "OK".
12. Finally the **Vivado License Manager** will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process.
### Environment Setup
The last step is to update your `~/.bashrc` with the following line:
```bash
# Xilinx Vivado 2017.1 environment
source <install_path>/Vivado/2017.1/settings64.sh
```
This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line.
Note that this will overwrite the paths to the GCC binaries required to build TVM and NNVM. Therefore, when attempting to build TVM or NNVM, comment this line out of your `~/.bashrc` before re-sourcing it.
# Bitstream compilation
High-level parameters are listed under `<vta root>/make/config.mk` and can be customized by the user.
Bitstream generation is driven by a makefile. All it takes is to enter the following command:
```bash
make
```
The local `Makefile` contains several variables that can be tweaked by the user (see the override example after this list):
* `VTA_HW_COMP_THREADS`: determines the number of threads used for the Vivado compilation job (default 8 threads).
* `VTA_HW_COMP_CLOCK_FREQ`: determines the target frequency of the VTA design (default 100MHz). It can only be set to 100, 142, 167 or 200MHz.
* `VTA_HW_COMP_TIMING_COMP`: determines how much additional slack must be provided to close timing (default 0ns). Generally when utilization is high for an FPGA design, setting this parameter to 1, 2 or 3 can help close timing.
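For instance, you can override these variables on the `make` command line (a sketch; command-line assignments take precedence over the `Makefile` defaults unless they are marked `override`):
```bash
make VTA_HW_COMP_THREADS=4 VTA_HW_COMP_CLOCK_FREQ=142
```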
Once the compilation completes, the generated bitstream can be found under `<vta root>/build/hardware/xilinx/vivado/<design name>/export/vta.bit`.
\ No newline at end of file
@@ -40,6 +40,8 @@ int main(void) {
status |= alu_test(VTA_ALU_OPCODE_ADD, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, false);
// Run ALU test (vector-vector operators)
status |= alu_test(VTA_ALU_OPCODE_MIN, false, 16, 128, true);
@@ -13,9 +13,9 @@
void fetch(
uint32_t insn_count,
volatile insn_T *insns,
hls::stream<insn_T> *load_queue,
hls::stream<insn_T> *gemm_queue,
hls::stream<insn_T> *store_queue) {
hls::stream<insn_T> &load_queue,
hls::stream<insn_T> &gemm_queue,
hls::stream<insn_T> &store_queue) {
#pragma HLS INTERFACE s_axilite port = insn_count bundle = CONTROL_BUS
#pragma HLS INTERFACE m_axi port = insns offset = slave bundle = ins_port
#pragma HLS INTERFACE axis port = load_queue
@@ -32,12 +32,12 @@ void fetch(
memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0);
// Push to appropriate instruction queue
if (opcode == VTA_OPCODE_STORE) {
store_queue->write(insn);
store_queue.write(insn);
} else if (opcode == VTA_OPCODE_LOAD &&
(memory_type == VTA_MEM_ID_INP || memory_type == VTA_MEM_ID_WGT)) {
load_queue->write(insn);
load_queue.write(insn);
} else {
gemm_queue->write(insn);
gemm_queue.write(insn);
}
}
}
@@ -45,9 +45,9 @@ void fetch(
void load(
volatile inp_vec_T *inputs,
volatile wgt_vec_T *weights,
hls::stream<insn_T> *load_queue,
hls::stream<bool> *g2l_dep_queue,
hls::stream<bool> *l2g_dep_queue,
hls::stream<insn_T> &load_queue,
hls::stream<bool> &g2l_dep_queue,
hls::stream<bool> &l2g_dep_queue,
inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT]
) {
@@ -61,7 +61,7 @@ void load(
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS
// Pop load instruction
insn_T insn = load_queue->read();
insn_T insn = load_queue.read();
// Decode instruction
bool pop_prev_dependence = insn[VTA_INSN_MEM_1];
@@ -81,7 +81,7 @@ void load(
// Pop dependence token if instructed
if (pop_next_dependence) {
g2l_dep_queue->read();
g2l_dep_queue.read();
}
// Initialize indices
@@ -170,19 +170,19 @@ void load(
// Push dependence token if instructed
if (push_next_dependence) {
l2g_dep_queue->write(1);
l2g_dep_queue.write(1);
}
}
void compute(
volatile uint32_t *done,
volatile uint32_t &done,
volatile uop_T *uops,
volatile acc_vec_T *biases,
hls::stream<insn_T> *gemm_queue,
hls::stream<bool> *l2g_dep_queue,
hls::stream<bool> *s2g_dep_queue,
hls::stream<bool> *g2l_dep_queue,
hls::stream<bool> *g2s_dep_queue,
hls::stream<insn_T> &gemm_queue,
hls::stream<bool> &l2g_dep_queue,
hls::stream<bool> &s2g_dep_queue,
hls::stream<bool> &g2l_dep_queue,
hls::stream<bool> &g2s_dep_queue,
out_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT],
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]
@@ -210,7 +210,7 @@ void compute(
#pragma HLS ARRAY_PARTITION variable = acc_mem complete dim = 2
// Pop GEMM instruction
insn_T insn = gemm_queue->read();
insn_T insn = gemm_queue.read();
// Decode
opcode_T opcode = insn.range(VTA_INSN_MEM_0_1, VTA_INSN_MEM_0_0);
@@ -221,19 +221,19 @@ void compute(
// Pop dependence token if instructed
if (pop_prev_dependence) {
l2g_dep_queue->read();
l2g_dep_queue.read();
}
if (pop_next_dependence) {
s2g_dep_queue->read();
s2g_dep_queue.read();
}
// Perform action based on opcode
if (opcode == VTA_OPCODE_FINISH) {
// Set done flag if we reach a FINISH instruction
*done = 1;
done = 1;
} else if (opcode == VTA_OPCODE_LOAD || opcode == VTA_OPCODE_STORE) {
// Set done value
*done = 0;
done = 0;
// Decode instruction
memop_id_T memory_type = insn.range(VTA_INSN_MEM_5_1, VTA_INSN_MEM_5_0);
@@ -283,7 +283,7 @@ void compute(
}
} else if (opcode == VTA_OPCODE_GEMM || opcode == VTA_OPCODE_ALU) {
// Set done value
*done = 0;
done = 0;
// Decode
uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_5_1, VTA_INSN_GEM_5_0);
@@ -383,6 +383,7 @@ void compute(
} else if (opcode == VTA_OPCODE_ALU) {
// Iterate over micro op
READ_ALU_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) {
#pragma HLS PIPELINE II = 2 rewind
// Read micro-op fields
uop_T uop = uop_mem[upc];
@@ -405,14 +406,15 @@ void compute(
// Result matrices
acc_vec_T cmp_res[VTA_BATCH];
acc_vec_T add_res[VTA_BATCH];
acc_vec_T shr_res[VTA_BATCH];
acc_vec_T rshr_res[VTA_BATCH];
acc_vec_T lshr_res[VTA_BATCH];
out_vec_T short_cmp_res[VTA_BATCH];
out_vec_T short_add_res[VTA_BATCH];
out_vec_T short_shr_res[VTA_BATCH];
out_vec_T short_rshr_res[VTA_BATCH];
out_vec_T short_lshr_res[VTA_BATCH];
// Perform ALU op over matrix elements
for (int i = 0; i < VTA_BATCH; i++) {
#pragma HLS PIPELINE II = 1 rewind
// Results vector
acc_vec_T res_vec = 0;
for (int b = 0; b < VTA_BLOCK_OUT; b++) {
@@ -434,12 +436,18 @@ void compute(
add_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val;
short_add_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) add_val.range(VTA_OUT_WIDTH - 1, 0);
// Compute Shift
acc_T shr_val =
// Compute Right Shift
acc_T rshr_val =
src_0 >> (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
shr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = shr_val;
short_shr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) shr_val.range(VTA_OUT_WIDTH-1, 0);
rshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = rshr_val;
short_rshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) rshr_val.range(VTA_OUT_WIDTH-1, 0);
// Compute Left Shift
acc_T lshr_val =
src_0 << (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
lshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = lshr_val;
short_lshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) lshr_val.range(VTA_OUT_WIDTH-1, 0);
}
// Store to accum memory/store buffer
@@ -451,8 +459,11 @@ void compute(
acc_mem[dst_idx][i] = add_res[i];
out_mem[dst_idx][i] = short_add_res[i];
} else if (alu_opcode == VTA_ALU_OPCODE_SHR) {
acc_mem[dst_idx][i] = shr_res[i];
out_mem[dst_idx][i] = short_shr_res[i];
acc_mem[dst_idx][i] = rshr_res[i];
out_mem[dst_idx][i] = short_rshr_res[i];
} else if (alu_opcode == VTA_ALU_OPCODE_SHL) {
acc_mem[dst_idx][i] = lshr_res[i];
out_mem[dst_idx][i] = short_lshr_res[i];
}
}
}
@@ -473,18 +484,18 @@ void compute(
// Push dependence token if instructed
if (push_prev_dependence) {
g2l_dep_queue->write(1);
g2l_dep_queue.write(1);
}
if (push_next_dependence) {
g2s_dep_queue->write(1);
g2s_dep_queue.write(1);
}
}
void store(
volatile out_vec_T *outputs,
hls::stream<insn_T> *store_queue,
hls::stream<bool> *g2s_dep_queue,
hls::stream<bool> *s2g_dep_queue,
hls::stream<insn_T> &store_queue,
hls::stream<bool> &g2s_dep_queue,
hls::stream<bool> &s2g_dep_queue,
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]
) {
#pragma HLS INTERFACE m_axi port = outputs offset = slave bundle = data_port
@@ -495,7 +506,7 @@ void store(
#pragma HLS INTERFACE s_axilite port = return bundle = CONTROL_BUS
// Load buffer
insn_T insn = store_queue->read();
insn_T insn = store_queue.read();
// Decode
bool pop_prev_dependence = insn[VTA_INSN_MEM_1];
@@ -515,7 +526,7 @@ void store(
// Pop dependence token if instructed
if (pop_prev_dependence) {
g2s_dep_queue->read();
g2s_dep_queue.read();
}
// Initialize indices
@@ -546,7 +557,7 @@ void store(
// Push dependence token if instructed
if (push_prev_dependence) {
s2g_dep_queue->write(1);
s2g_dep_queue.write(1);
}
}
@@ -589,7 +600,7 @@ void vta(
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH];
// Push all instructions into the queues
fetch(insn_count, insns, &tmp_load_queue, &tmp_gemm_queue, &tmp_store_queue);
fetch(insn_count, insns, tmp_load_queue, tmp_gemm_queue, tmp_store_queue);
// Global done indicator
uint32_t done = 0;
@@ -621,7 +632,7 @@ void vta(
// Push the instruction in the load queue
load_queue.write(tmp_load);
tmp_load_popped = false;
load(inputs, weights, &load_queue, &g2l_dep_queue, &l2g_dep_queue, inp_mem, wgt_mem);
load(inputs, weights, load_queue, g2l_dep_queue, l2g_dep_queue, inp_mem, wgt_mem);
} else {
// Execution of load stage pending on completion of other stages, so break here...
break;
@@ -649,8 +660,8 @@ void vta(
// Push the instruction in the load queue
gemm_queue.write(tmp_gemv);
tmp_gemm_popped = false;
compute(&done, uops, biases, &gemm_queue, &l2g_dep_queue, &s2g_dep_queue,
&g2l_dep_queue, &g2s_dep_queue, inp_mem, wgt_mem, out_mem);
compute(done, uops, biases, gemm_queue, l2g_dep_queue, s2g_dep_queue,
g2l_dep_queue, g2s_dep_queue, inp_mem, wgt_mem, out_mem);
} else {
// Execution of load stage pending on completion of other stages,
// so break here...
@@ -671,7 +682,7 @@ void vta(
// Push the instruction in the load queue
store_queue.write(tmp_store);
tmp_store_popped = false;
store(outputs, &store_queue, &g2s_dep_queue, &s2g_dep_queue, out_mem);
store(outputs, store_queue, g2s_dep_queue, s2g_dep_queue, out_mem);
} else {
// Execution of load stage pending on completion of other stages, so break here...
break;
@@ -107,9 +107,9 @@ typedef ap_uint<VTA_LOG_ACC_WIDTH> aluop_sh_imm_T;
void fetch(
uint32_t insn_count,
volatile insn_T *insns,
hls::stream<insn_T> *load_queue,
hls::stream<insn_T> *gemm_queue,
hls::stream<insn_T> *store_queue);
hls::stream<insn_T> &load_queue,
hls::stream<insn_T> &gemm_queue,
hls::stream<insn_T> &store_queue);
/*!
* \brief Load module.
@@ -129,9 +129,9 @@ void fetch(
void load(
volatile inp_vec_T *inputs,
volatile wgt_vec_T *weights,
hls::stream<insn_T> *load_queue,
hls::stream<bool> *g2l_dep_queue,
hls::stream<bool> *l2g_dep_queue,
hls::stream<insn_T> &load_queue,
hls::stream<bool> &g2l_dep_queue,
hls::stream<bool> &l2g_dep_queue,
inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT]);
@@ -159,14 +159,14 @@ void load(
* \param out_mem Local output SRAM buffer. Write only single port BRAM.
*/
void compute(
volatile uint32_t *done,
volatile uint32_t &done,
volatile uop_T *uops,
volatile acc_vec_T *biases,
hls::stream<insn_T> *gemm_queue,
hls::stream<bool> *l2g_dep_queue,
hls::stream<bool> *s2g_dep_queue,
hls::stream<bool> *g2l_dep_queue,
hls::stream<bool> *g2s_dep_queue,
hls::stream<insn_T> &gemm_queue,
hls::stream<bool> &l2g_dep_queue,
hls::stream<bool> &s2g_dep_queue,
hls::stream<bool> &g2l_dep_queue,
hls::stream<bool> &g2s_dep_queue,
out_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT],
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]);
@@ -186,9 +186,9 @@ void compute(
*/
void store(
volatile out_vec_T *outputs,
hls::stream<insn_T> *store_queue,
hls::stream<bool> *g2s_dep_queue,
hls::stream<bool> *s2g_dep_queue,
hls::stream<insn_T> &store_queue,
hls::stream<bool> &g2s_dep_queue,
hls::stream<bool> &s2g_dep_queue,
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]);
/*!
@@ -84,7 +84,7 @@ VTA_ACC_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_ACC_BUFF_SIZE) ))" )
VTA_LOG_OUT_BUFF_SIZE = \
$(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_ACC_WIDTH) ))" )
# Out buffer size in Bytes
VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(LOG_OUT_BUFF_SIZE) ))" )
VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
# Update ADD_CFLAGS
ADD_CFLAGS += \
"""VTA Python package backed by TVM"""
"""TVM VTA runtime"""
from __future__ import absolute_import as _abs
from .hw_spec import *
# version of this package
__version__ = "0.1.0"
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
from .intrin import GEVM, GEMM
from .build import debug_mode
from . import mock, ir_pass
from . import arm_conv2d, vta_conv2d
from . import graph
# pylint: disable=invalid-name,unused-variable
"""Conv2D schedule ported from RASP
Used for CPU conv2d
"""
from __future__ import absolute_import as _abs
import tvm
from topi import tag
from topi.nn.conv2d import conv2d, _get_schedule
from topi.nn.conv2d import SpatialPack, Im2ColPack, Workload
from topi.nn.conv2d import _SCH_TO_DECL_FUNC
from topi.nn.conv2d import _get_workload
from topi.nn.util import infer_pad, infer_stride
from topi import generic
_WORKLOADS = [
Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
]
_SCHEDULES = [
# float32 imagenet
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
]
@_get_schedule.register(["tcpu", "vta"])
def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
sch = _SCHEDULES[idx]
return sch
@conv2d.register(["tcpu", "vta"])
def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
assert layout == 'NCHW', "only support NCHW convolution on tcpu"
assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu"
wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl)
return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
conv_out, output, last):
# no stride and padding info here
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI, CO = wkl.in_filter, wkl.out_filter
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = HK-1, WK-1
DOPAD = (HPAD != 0 and WPAD != 0)
VH = sch.vh
VW = sch.vw
VC = sch.vc
UNROLL = sch.unroll
A, B, C = data, kernel, last
A0, A1 = data_pad, data_vec
B0 = kernel_vec
C0, C1 = conv_out, output
CC = s.cache_write(C0, "global")
_, co, oh, ow, vh, vw, vc = s[C0].op.axis
if UNROLL:
s[C0].unroll(vw)
s[C0].vectorize(vc)
s[CC].compute_at(s[C0], ow)
_, co, oh, ow, vh, vw, vc = s[CC].op.axis
ci, dh, dw = s[CC].op.reduce_axis
s[CC].reorder(ci, dh, vh, dw, vw, vc)
if UNROLL:
s[CC].unroll(vw)
s[CC].vectorize(vc)
##### Schedule A
if DOPAD:
s[A0].compute_inline()
_, h, _, _, _, _ = s[A1].op.axis
if sch.ba == 1:
oaxis = h
paxis = h
else:
oh, ih = s[A1].split(h, sch.ba)
oaxis = oh
paxis = ih
s[A1].parallel(paxis)
s[A1].pragma(oaxis, "parallel_launch_point")
s[A1].pragma(paxis, "parallel_stride_pattern")
s[A1].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
co, _, _, _, _ = s[B0].op.axis
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[B0].split(co, sch.bc)
oaxis = oco
paxis = ico
s[B0].parallel(paxis)
s[B0].pragma(oaxis, "parallel_launch_point")
s[B0].pragma(paxis, "parallel_stride_pattern")
s[B0].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule C
n, co, h, w = s[C].op.axis
co, vc = s[C].split(co, VC)
oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
s[C].reorder(n, co, oh, ow, vh, vw, vc)
if C != C1:
s[C1].compute_inline()
s[C0].compute_at(s[C], ow)
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[C].split(co, sch.bc)
oaxis = oco
paxis = ico
s[C].parallel(paxis)
s[C].pragma(oaxis, "parallel_launch_point")
s[C].pragma(paxis, "parallel_stride_pattern")
s[C].pragma(oaxis, "parallel_barrier_when_finish")
return s
def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
kernel, kernel_vec,
conv_out, output, last):
# no stride and padding info here
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
sch = _get_schedule(wkl)
H, W = wkl.height, wkl.width
CI = wkl.in_filter
CO = wkl.out_filter
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
HCAT, WCAT = HK-1, WK-1
DOPAD = (HPAD != 0 and WPAD != 0)
P = sch.vp
Q = sch.vq
UNROLL = sch.unroll
A, B, C = data, kernel, last
A0, A1, A2 = data_pad, data_col, data_vec
B0 = kernel_vec
C0, C1 = conv_out, output
CC = s.cache_write(C0, "global")
AA = s.cache_read(A2, "global", [CC])
BB = s.cache_read(B0, "global", [CC])
##### Schedule CC
_, co, im, vim, vco = s[C0].op.axis
s[C0].unroll(vim)
s[C0].vectorize(vco)
s[CC].compute_at(s[C0], im)
_, co, im, vim, vco = s[CC].op.axis
ci, hk, wk = s[CC].op.reduce_axis
s[CC].reorder(ci, hk, wk, vim, vco)
s[CC].unroll(vim)
s[CC].vectorize(vco)
# s[CC].unroll(ccr)
### Schedule C
_, co, h, w = s[C].op.axis
im = s[C].fuse(h, w)
im, vim = s[C].split(im, P)
co, vco = s[C].split(co, Q)
s[C].reorder(co, im, vim, vco)
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[C].split(co, sch.bc)
oaxis = oco
paxis = ico
s[C].parallel(paxis)
s[C].pragma(oaxis, "parallel_launch_point")
s[C].pragma(paxis, "parallel_stride_pattern")
s[C].pragma(oaxis, "parallel_barrier_when_finish")
if C1 != C:
s[C1].compute_inline()
s[C0].compute_at(s[C], paxis)
##### Schedule A
if DOPAD:
s[A0].compute_inline()
s[A1].compute_inline()
s[AA].compute_at(s[CC], wk)
s[AA].unroll(AA.op.axis[4])
_, im, _, _, _, _ = s[A2].op.axis
if sch.ba == 1:
oaxis = im
paxis = im
else:
oim, iim = s[A2].split(im, sch.ba)
oaxis = oim
paxis = iim
s[A2].parallel(paxis)
s[A2].pragma(oaxis, "parallel_launch_point")
s[A2].pragma(paxis, "parallel_stride_pattern")
s[A2].pragma(oaxis, "parallel_barrier_when_finish")
##### Schedule B
s[BB].compute_at(s[CC], wk)
s[BB].vectorize(BB.op.axis[4])
co, _, _, _, _ = s[B0].op.axis
if sch.bc == 1:
oaxis = co
paxis = co
else:
oco, ico = s[B0].split(co, sch.bc)
oaxis = oco
paxis = ico
s[B0].parallel(paxis)
s[B0].pragma(oaxis, "parallel_launch_point")
s[B0].pragma(paxis, "parallel_stride_pattern")
s[B0].pragma(oaxis, "parallel_barrier_when_finish")
return s
@generic.schedule_conv2d_nchw.register(["tcpu", "vta"])
def schedule_conv2d(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
if 'spatial_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data = data_vec.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule_spatial_conv2d(s, data, data_pad, data_vec,
kernel, kernel_vec,
conv_out, output, outs[0])
if 'im2col_conv_output' in op.tag:
output = op.output(0)
conv_out = op.input_tensors[0]
kernel_vec = conv_out.op.input_tensors[1]
kernel = kernel_vec.op.input_tensors[0]
data_vec = conv_out.op.input_tensors[0]
data_col = data_vec.op.input_tensors[0]
data = data_col.op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
kernel, kernel_vec,
conv_out, output, outs[0])
traverse(outs[0].op)
return s
"""Runtime function related hooks"""
from __future__ import absolute_import as _abs
import tvm
from tvm import build_module
from . runtime import CB_HANDLE
from . import ir_pass
def lift_coproc_scope(x):
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x
def early_rewrite(stmt):
try:
return tvm.ir_pass.StorageRewrite(stmt)
except tvm.TVMError:
return stmt
def debug_mode(debug_flag):
"""Pass to enable vta debug mode.
Parameters
----------
debug_flag : int
The debug flag to be passed.
Returns
-------
pass_list: list of function
The pass list to use with build_config(add_lower_pass=vta.debug_mode(debug_flag))
"""
def add_debug(stmt):
debug = tvm.call_extern(
"int32", "VTASetDebugMode", CB_HANDLE, debug_flag)
return tvm.make.stmt_seq(debug, stmt)
pass_list = [(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
(1, ir_pass.annotate_alu_coproc_scope),
(1, lambda x: tvm.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
(1, lift_coproc_scope),
(1, ir_pass.inject_coproc_sync),
(1, early_rewrite)]
if debug_flag:
pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite))
return pass_list
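# Example usage (a sketch mirroring the end-to-end example script):
#   with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
#       graph, lib, params = nnvm.compiler.build(...)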
# Add a lower pass to sync uop
build_module.BuildConfig.current.add_lower_pass = debug_mode(0)
"""Graph transformation specific to accelerator.
This module provides NNVM graph transformations that turn
a generic NNVM graph into a version that can be executed
on the accelerator.
"""
import nnvm
from nnvm.compiler import graph_attr, graph_util
def _pack_channel(data, dshape, factor):
"""Pack the data channel dimension.
"""
assert dshape[1] % factor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0], dshape[1] // factor,
factor, dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 1, 3, 4, 2))
return data
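# Example (hypothetical shapes): with dshape = (1, 64, 56, 56) and factor = 16,
# the reshape above yields (1, 4, 16, 56, 56) and the transpose yields
# (1, 4, 56, 56, 16), i.e. channels are split and packed into the innermost axis.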
def _unpack_channel(data, old_shape):
"""Unpack the data channel dimension.
"""
data = nnvm.sym.transpose(data, axes=(0, 1, 4, 2, 3))
data = nnvm.sym.reshape(data, shape=old_shape)
return data
def _pack_weight(data, dshape, factor):
"""Pack the weight into packed format.
"""
assert len(dshape) == 4
assert dshape[0] % factor == 0
assert dshape[1] % factor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // factor, factor,
dshape[1] // factor, factor,
dshape[2], dshape[3]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 4, 5, 1, 3))
return data
def _pack_bias(data, dshape, factor):
"""Pack the bias parameter.
"""
assert len(dshape) == 3
assert dshape[0] % factor == 0
data = nnvm.sym.reshape(data,
shape=(dshape[0] // factor,
factor, dshape[1], dshape[2]))
data = nnvm.sym.transpose(
data, axes=(0, 2, 3, 1))
return data
def _get_shape(sym, shape_dict):
"""Get the shape of a node.
"""
return graph_util.infer_shape(
nnvm.graph.create(sym), **shape_dict)[1][0]
def remove_stochastic(graph):
"""
Replace stochastic rounding and shift with deterministic versions.
Parameters
----------
graph : Graph
The input graph
Returns
-------
replaced_graph : Graph
The final replaced graph.
"""
gidx = graph.index
node_map = {}
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
elif op_name == "stochastic_round":
new_node = children[0]
elif op_name == "noise_lshift":
new_node = nnvm.symbol.left_shift(
children[0], **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
ret = nnvm.graph.create(ret)
return ret
def clean_conv_fuse(graph):
"""Cleanup the convolution's later fuse stages
Parameters
----------
graph : Graph
Input graph
Returns
-------
graph : Graph
Optimized graph
"""
def _clean_entry(entry):
node, flag = entry
if flag:
node = nnvm.symbol.clip(node, a_max=127, a_min=-127)
node = nnvm.symbol.cast(node, dtype="int8")
# Use identity as a hint to block conv2d schedules
node = nnvm.symbol.identity(node)
flag = False
return node, flag
gidx = graph.index
ref_count = {}
# count reference of each node
for nid, node in enumerate(gidx.nodes):
ref_count[nid] = 0
for elem in node["inputs"]:
ref_count[elem[0]] += 1
# construct remap: entry_id -> (new_node, conv_fuse)
# the conv_fuse flag marks entries that still need clip/cast cleanup
# before they can feed a non-conv consumer
node_map = {}
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
new_entry = None
if op_name == "null":
new_entry = (nnvm.symbol.Variable(node_name), False)
elif op_name in ("cast", "clip"):
if children[0][1]:
new_entry = children[0]
else:
new_entry = (
get_clone([children[0][0]], op_name, node_name, attrs),
False)
elif op_name == "quantized_conv2d":
data, weight = children
data = _clean_entry(data)
new_node = nnvm.sym.quantized_conv2d(
data[0], weight[0], name=node_name, **attrs)
new_entry = (new_node, True)
elif op_name in ("left_shift", "right_shift", "relu"):
new_entry = (
get_clone([children[0][0]], op_name, node_name, attrs),
children[0][1])
elif op_name in ("broadcast_add", "broadcast_mul"):
rhs = children[1][0]
lhs, _ = _clean_entry(children[0])
lhs = nnvm.sym.cast(lhs, dtype="int32")
rhs = nnvm.sym.cast(rhs, dtype="int32")
new_entry = (
get_clone([lhs, rhs], op_name, node_name, attrs),
False)
if new_entry is None:
inputs = [_clean_entry(x) for x in children]
new_entry = (
get_clone([x[0] for x in inputs], op_name, node_name, attrs),
False)
if ref_count[nid] > 1:
new_entry = _clean_entry(new_entry)
node_map[nid] = new_entry
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]][0]
ret = nnvm.graph.create(ret)
return ret
def clean_cast(graph):
"""
Move the casts to early part of graph,
remove unnecessary clip operations when possible.
"""
gidx = graph.index
node_map = {}
def _clean_cast(node, target_type):
op_name = node.attr("op_name")
if op_name == "cast":
return _clean_cast(node.get_children(), target_type)
elif op_name == "relu":
data, has_clip = _clean_cast(
node.get_children(), target_type)
data = nnvm.sym.relu(data)
return data, has_clip
return nnvm.sym.cast(node, dtype=target_type), False
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
elif op_name == "cast":
dtype = attrs["dtype"]
new_node, _ = _clean_cast(children[0], dtype)
elif op_name == "quantized_conv2d":
data, weight = children
data, _ = _clean_cast(data, "int8")
weight, _ = _clean_cast(weight, "int8")
new_node = nnvm.sym.quantized_conv2d(
data, weight, name=node_name, **attrs)
elif op_name == "elemwise_add":
lhs, rhs = children
rhs = nnvm.sym.cast(rhs, dtype="int8")
new_node = nnvm.sym.elemwise_add(lhs, rhs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
ret = nnvm.graph.create(ret)
return ret
def pack(graph, shape_dict, factor, start_name=None):
"""Pack the graph into channel packed format.
Parameters
----------
graph : Graph
The input graph.
shape_dict : dict of str to shapex
The input shape.
factor : int
The packing factor.
start_name: str, optional
The name of a known node at which to start packing.
Returns
-------
graph : Graph
The transformed graph.
"""
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
shape = graph.json_attr("shape")
gidx = graph.index
node_map = {}
dset = set()
counter = 0
start_pack = False
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
ishape = [shape[gidx.entry_id(e)] for e in node["inputs"]]
oshape = shape[gidx.entry_id(nid, 0)]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
if start_name and node_name == start_name:
start_pack = True
new_node = _pack_channel(new_node, oshape, factor)
elif op_name == "max_pool2d":
assert not start_pack
start_pack = True
new_node = get_clone(children, op_name, node_name, attrs)
new_node = _pack_channel(new_node, oshape, factor)
elif op_name == "global_avg_pool2d":
if start_pack:
start_pack = False
children[0] = _unpack_channel(children[0], ishape[0])
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "quantized_conv2d":
if start_pack:
attrs["pack_channel"] = str(factor)
data, weight = children
weight = _pack_weight(weight, ishape[1], factor)
new_node = nnvm.sym.quantized_conv2d(
data, weight, name=node_name, **attrs)
elif counter == 1:
attrs["pack_channel"] = str(factor)
data, weight = children
data = _pack_channel(data, ishape[0], factor)
weight = _pack_weight(weight, ishape[1], factor)
new_node = nnvm.sym.quantized_conv2d(
data, weight, name=node_name, **attrs)
new_node = _unpack_channel(new_node, oshape)
counter = counter + 1
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("broadcast"):
if start_pack:
assert len(ishape[1]) == 3
children[1] = _pack_bias(children[1], ishape[1], factor)
new_node = getattr(nnvm.symbol, op_name)(
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name.startswith("elementwise_add"):
new_node = get_clone(children, op_name, node_name, attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
dset.add(op_name)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
if start_pack:
oshape = shape[graph.index.output_entries[0][0]]
ret = _unpack_channel(ret, oshape)
graph = nnvm.graph.create(ret)
graph = graph_attr.set_shape_inputs(graph, shape_dict)
graph = graph.apply("InferShape")
return graph
"""VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs
# The Constants
VTA_WGT_WIDTH = 8
VTA_INP_WIDTH = VTA_WGT_WIDTH
VTA_OUT_WIDTH = 32
# Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1
VTA_BLOCK_IN = 16
VTA_BLOCK_OUT = 16
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = 15
# log-2 On-chip input buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = 15
# log-2 On-chip output buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = 17
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
# Input buffer size
VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE
# Output buffer size.
VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE
# Number of bytes per buffer
VTA_INP_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_IN*VTA_INP_WIDTH//8)
VTA_WGT_ELEM_BYTES = (VTA_BLOCK_OUT*VTA_BLOCK_IN*VTA_WGT_WIDTH//8)
VTA_OUT_ELEM_BYTES = (VTA_BATCH*VTA_BLOCK_OUT*VTA_OUT_WIDTH//8)
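# With the defaults above: INP = 1*16*8/8 = 16 B, WGT = 16*16*8/8 = 256 B,
# OUT = 1*16*32/8 = 64 B per element.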
# Maximum external buffer size in bytes
VTA_MAX_XFER = 1 << 22
# Number of elements
VTA_INP_BUFF_DEPTH = VTA_INP_BUFF_SIZE//VTA_INP_ELEM_BYTES
VTA_WGT_BUFF_DEPTH = VTA_WGT_BUFF_SIZE//VTA_WGT_ELEM_BYTES
VTA_OUT_BUFF_DEPTH = VTA_OUT_BUFF_SIZE//VTA_OUT_ELEM_BYTES
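# With the defaults above: 32768/16 = 2048 input entries, 32768/256 = 128 weight
# entries, and 131072/64 = 2048 output entries.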
# Memory id for DMA
VTA_MEM_ID_UOP = 0
VTA_MEM_ID_WGT = 1
VTA_MEM_ID_INP = 2
VTA_MEM_ID_ACC = 3
VTA_MEM_ID_OUT = 4
# VTA ALU Opcodes
VTA_ALU_OPCODE_MIN = 0
VTA_ALU_OPCODE_MAX = 1
VTA_ALU_OPCODE_ADD = 2
VTA_ALU_OPCODE_SUB = 3
VTA_ALU_OPCODE_MUL = 4
VTA_ALU_OPCODE_SHL = 5
VTA_ALU_OPCODE_SHR = 6
VTA_ALU_OPCODE_UNSET = 7
# Task queue id (pipeline stage)
VTA_QID_LOAD_INP = 1
VTA_QID_LOAD_WGT = 1
VTA_QID_LOAD_OUT = 2
VTA_QID_STORE_OUT = 3
VTA_QID_COMPUTE = 2
VTA_QID_STORE_INP = 3
# Debug flags
DEBUG_DUMP_INSN = (1 << 1)
DEBUG_DUMP_UOP = (1 << 2)
DEBUG_SKIP_READ_BARRIER = (1 << 3)
DEBUG_SKIP_WRITE_BARRIER = (1 << 4)
\ No newline at end of file
"""VTA related intrinsics"""
from __future__ import absolute_import as _abs
import tvm
from . import hw_spec as spec
from .runtime import VTA_AXIS, VTA_PUSH_UOP, get_task_qid
from .runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT
# The memory information for the compiler
@tvm.register_func("tvm.info.mem.%s" % SCOPE_INP)
def mem_info_inp_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_INP_ELEM_BYTES * 8,
max_num_bits=spec.VTA_INP_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % SCOPE_WGT)
def mem_info_wgt_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_WGT_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_WGT_ELEM_BYTES * 8,
max_num_bits=spec.VTA_WGT_BUFF_SIZE * 8,
head_address=None)
@tvm.register_func("tvm.info.mem.%s" % SCOPE_OUT)
def mem_info_out_buffer():
return tvm.make.node("MemoryInfo",
unit_bits=spec.VTA_OUT_ELEM_BYTES * 8,
max_simd_bits=spec.VTA_OUT_ELEM_BYTES * 8,
max_num_bits=spec.VTA_OUT_BUFF_SIZE * 8,
head_address=None)
def intrin_gevm(mock=False):
"""Vector-matrix multiply intrinsic"""
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH,
name=SCOPE_WGT)
inp = tvm.placeholder((wgt_shape[1], ),
dtype="int%d" % spec.VTA_INP_WIDTH,
name=SCOPE_INP)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH
out = tvm.compute((wgt_shape[0],),
lambda i: tvm.sum(inp[k].astype(out_dtype) *
wgt[i, k].astype(out_dtype),
axis=[k]),
name="out")
wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs):
"""Vector-matrix multiply intrinsic function"""
dinp, dwgt = ins
dout = outs[0]
def instr(index):
"""Generate vector-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE))
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP)
if index == 0 or index == 2:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
else:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 1,
dout.access_ptr("rw", "int32"),
0,
0,
0, 0, 0))
return irb.get()
# return a triple of normal-set, reset, update
nop = tvm.make.Evaluate(0)
if mock:
return (nop, nop, nop)
return (instr(0), instr(1), instr(2))
return tvm.decl_tensor_intrin(out.op, intrin_func,
name="GEVM",
binds={inp: inp_layout,
wgt: wgt_layout,
out: out_layout})
def intrin_gemm(mock=False):
"""Matrix-matrix multiply intrinsic"""
wgt_lanes = spec.VTA_WGT_ELEM_BYTES * 8 // spec.VTA_WGT_WIDTH
assert wgt_lanes == spec.VTA_BLOCK_OUT * spec.VTA_BLOCK_IN
wgt_shape = (spec.VTA_BLOCK_OUT, spec.VTA_BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = spec.VTA_INP_ELEM_BYTES * 8 // spec.VTA_INP_WIDTH
assert inp_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_IN
inp_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_IN)
assert inp_shape[0] * inp_shape[1] == inp_lanes
out_lanes = spec.VTA_OUT_ELEM_BYTES * 8 // spec.VTA_OUT_WIDTH
assert out_lanes == spec.VTA_BATCH * spec.VTA_BLOCK_OUT
out_shape = (spec.VTA_BATCH, spec.VTA_BLOCK_OUT)
assert out_shape[0] * out_shape[1] == out_lanes
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % spec.VTA_WGT_WIDTH,
name=SCOPE_WGT)
inp = tvm.placeholder((inp_shape[0], inp_shape[1]),
dtype="int%d" % spec.VTA_INP_WIDTH,
name=SCOPE_INP)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % spec.VTA_OUT_WIDTH
out = tvm.compute((out_shape[0], out_shape[1]),
lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) *
wgt[j, k].astype(out_dtype),
axis=[k]),
name="out")
wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, SCOPE_WGT,
scope=SCOPE_WGT, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, SCOPE_INP,
scope=SCOPE_INP, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer(
out.shape, out.dtype, SCOPE_OUT,
scope=SCOPE_OUT, offset_factor=out_lanes, data_alignment=out_lanes)
def intrin_func(ins, outs):
"""Matrix-matrix multiply intrinsic function"""
dinp, dwgt = ins
dout = outs[0]
def instr(index):
"""Generate matrix-matrix multiply VTA instruction"""
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE))
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", VTA_PUSH_UOP)
if index == 0 or index == 2:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
else:
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
0, 1,
dout.access_ptr("rw", "int32"),
0,
0,
0, 0, 0))
return irb.get()
# return a triple of normal-set, reset, update
nop = tvm.make.Evaluate(0)
if mock:
return (nop, nop, nop)
return (instr(0), instr(1), instr(2))
return tvm.decl_tensor_intrin(out.op, intrin_func,
name="GEMM",
binds={inp: inp_layout,
wgt: wgt_layout,
out: out_layout})
GEMM = intrin_gemm()
GEVM = intrin_gevm()
"""Additional IR Pass for VTA"""
from __future__ import absolute_import as _abs
import tvm
from topi import util as util
from . import hw_spec as spec
from . runtime import CB_HANDLE, VTA_AXIS, VTA_PUSH_UOP
from . runtime import SCOPE_OUT, SCOPE_INP, SCOPE_WGT, DMA_COPY, get_task_qid
def fold_uop_loop(stmt_in):
"""Pass to fold uop loops"""
def _fold_outermost_loop(body):
stmt = body
while not isinstance(stmt, tvm.stmt.For):
if isinstance(stmt, (tvm.stmt.ProducerConsumer, )):
stmt = stmt.body
else:
return None, body, None
loop_var = stmt.loop_var
gemm_offsets = [None, None, None]
fail = [False]
def _post_order(op):
assert isinstance(op, tvm.expr.Call)
base_args = 2
if op.name == "VTAUopPush":
args = []
args += op.args[:base_args]
for i in range(3):
m = tvm.arith.DetectLinearEquation(op.args[i + base_args], [loop_var])
if not m:
fail[0] = True
return op
if gemm_offsets[i] is not None:
if not tvm.ir_pass.Equal(m[0], gemm_offsets[i]):
fail[0] = True
return op
args.append(m[1])
else:
gemm_offsets[i] = m[0]
args.append(m[1])
args += op.args[base_args+3:]
return tvm.call_extern("int32", "VTAUopPush", *args)
else:
if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
raise RuntimeError("unexpected op %s" % op)
return op
ret = tvm.ir_pass.IRTransform(
stmt.body, None, _post_order, ["Call"])
if not fail[0] and all(x is not None for x in gemm_offsets):
def _visit(op):
if op.same_as(loop_var):
fail[0] = True
tvm.ir_pass.PostOrderVisit(ret, _visit)
if not fail[0]:
begin = tvm.call_extern(
"int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
end = tvm.call_extern(
"int32", "VTAUopLoopEnd", stmt.extent, *gemm_offsets)
return [begin, ret, end]
raise ValueError("Failed to fold the GEMM instructions..")
def _do_fold(stmt):
if (stmt.attr_key == "coproc_uop_scope" and
isinstance(stmt.value, tvm.expr.StringImm) and
stmt.value.value == VTA_PUSH_UOP.value):
body = stmt.body
begins = []
ends = []
try:
begin, body, end = _fold_outermost_loop(body)
if begin is not None:
begins.append(begin)
if end is not None:
ends.append(end)
begin, body, end = _fold_outermost_loop(body)
if begin is not None:
begins.append(begin)
if end is not None:
ends.append(end)
except ValueError:
pass
if body == stmt.body:
return stmt
ends = list(reversed(ends))
body = tvm.make.stmt_seq(*(begins + [body] + ends))
return tvm.make.AttrStmt(
stmt.node, stmt.attr_key, stmt.value, body)
return None
out = tvm.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return out
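# Illustration of the fold (a sketch, not literal pass output): a software
# loop whose VTAUopPush address arguments are affine in the loop variable,
#
#   for i in range(extent):
#       VTAUopPush(0, 0, dst + i*dst_off, inp + i*inp_off, wgt + i*wgt_off, 0, 0, 0)
#
# is detected via DetectLinearEquation and folded into a hardware micro-op
# loop with the strides hoisted out:
#
#   VTAUopLoopBegin(extent, dst_off, inp_off, wgt_off)
#   VTAUopPush(0, 0, dst, inp, wgt, 0, 0, 0)
#   VTAUopLoopEnd(extent, dst_off, inp_off, wgt_off)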
def cpu_access_rewrite(stmt_in):
"""Rewrite the code when there is CPU access happening.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
rw_info = {}
def _post_order(op):
if isinstance(op, tvm.stmt.Allocate):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
return None
new_var = rw_info[buffer_var]
let_stmt = tvm.make.LetStmt(
new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr", CB_HANDLE,
buffer_var), op.body)
alloc = tvm.make.Allocate(
buffer_var, op.dtype, op.extents,
op.condition, let_stmt)
del rw_info[buffer_var]
return alloc
elif isinstance(op, tvm.expr.Load):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.make.Load(op.dtype, new_var, op.index)
elif isinstance(op, tvm.stmt.Store):
buffer_var = op.buffer_var
if not buffer_var in rw_info:
rw_info[buffer_var] = tvm.var(
buffer_var.name + "_ptr", "handle")
new_var = rw_info[buffer_var]
return tvm.make.Store(new_var, op.value, op.index)
else:
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.make.LetStmt(
new_var, tvm.call_extern(
"handle", "VTABufferCPUPtr", CB_HANDLE,
buffer_var), stmt)
return stmt
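# Sketch of the rewrite: host-side loads/stores on an accelerator-managed
# buffer,
#
#   A = allocate(...)   # VTA-managed buffer
#   A[i] = x; y = A[i]
#
# become accesses through a CPU-visible pointer fetched once per scope:
#
#   A_ptr = VTABufferCPUPtr(VTATLSCommandHandle(), A)
#   A_ptr[i] = x; y = A_ptr[i]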
def lift_alloc_to_scope_begin(stmt_in):
"""Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
lift_stmt = [[]]
def _merge_block(slist, body):
for op in slist:
if op.body == body:
body = op
elif isinstance(op, tvm.stmt.Allocate):
body = tvm.make.Allocate(
op.buffer_var, op.dtype,
op.extents, op.condition, body)
elif isinstance(op, tvm.stmt.AttrStmt):
body = tvm.make.AttrStmt(
op.node, op.attr_key, op.value, body)
elif isinstance(op, tvm.stmt.For):
body = tvm.make.For(
op.loop_var, op.min, op.extent, op.for_type,
op.device_api, body)
else:
raise RuntimeError("unexpected op")
del slist[:]
return body
def _pre_order(op):
if isinstance(op, tvm.stmt.For):
lift_stmt.append([])
elif isinstance(op, tvm.stmt.AttrStmt):
if op.attr_key == "virtual_thread":
lift_stmt.append([])
return None
def _post_order(op):
if isinstance(op, tvm.stmt.Allocate):
lift_stmt[-1].append(op)
return op.body
elif isinstance(op, tvm.stmt.AttrStmt):
if op.attr_key == "storage_scope":
lift_stmt[-1].append(op)
return op.body
elif op.attr_key == "virtual_thread":
return _merge_block(lift_stmt.pop() + [op], op.body)
return op
elif isinstance(op, tvm.stmt.For):
return _merge_block(lift_stmt.pop() + [op], op.body)
else:
raise RuntimeError("not reached")
stmt = tvm.ir_pass.IRTransform(
stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
assert len(lift_stmt) == 1
return _merge_block(lift_stmt[0], stmt)
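# Sketch of the lift: an Allocate nested under a loop,
#
#   for i { allocate buf; body }
#
# is hoisted to the head of the enclosing scope (a For or a virtual_thread
# AttrStmt), becoming
#
#   allocate buf { for i { body } }
#
# so each VTA buffer is declared once per scope rather than per iteration.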
def inject_skip_copy(stmt_in):
"""Pass to inject skip copy stmt, used in debug.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
def _do_fold(stmt):
if stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy":
return tvm.make.Evaluate(0)
return None
return tvm.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
def inject_coproc_sync(stmt_in):
"""Pass to inject skip copy stmt, used in debug.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
success = [False]
def _do_fold(stmt):
if stmt.attr_key == "pragma_scope" and stmt.value.value == "coproc_sync":
success[0] = True
sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
elif stmt.attr_key == "pragma_scope" and stmt.value.value == "trim_loop":
op = stmt.body
assert isinstance(op, tvm.stmt.For)
return tvm.make.For(
op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
return None
stmt = tvm.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
stmt = tvm.ir_pass.CoProcSync(stmt)
return stmt
def inject_dma_intrin(stmt_in):
"""Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
def _check_compact(buf):
ndim = len(buf.shape)
size = tvm.const(1, buf.shape[0].dtype)
for i in reversed(range(ndim)):
if not util.equal_const_int(size - buf.strides[i], 0):
raise RuntimeError(
"Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
size = size * buf.shape[i]
def _fold_buffer_dim(buf, scope, elem_block):
ndim = len(buf.shape)
x_size = 1
base = 0
for i in range(1, ndim + 1):
if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
raise RuntimeError("scope %s need need to have block=%d" % (scope, elem_block))
x_size = x_size * buf.shape[ndim - i]
if util.equal_const_int(x_size - elem_block, 0):
base = i + 1
break
if base == 0:
raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
scope, elem_block, buf.shape))
shape = [elem_block]
strides = [1]
if (base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block)):
shape.append(1)
strides.append(elem_block)
while base < ndim + 1:
x_size = 1
x_stride = buf.strides[ndim - base]
next_base = base
if not util.equal_const_int(x_stride % elem_block, 0):
raise RuntimeError("scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides))
for i in range(base, ndim + 1):
k = ndim - i
if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
break
x_size = x_size * buf.shape[k]
next_base = i + 1
shape.append(tvm.ir_pass.Simplify(x_size))
strides.append(x_stride)
assert next_base != base
base = next_base
strides = list(reversed(strides))
shape = list(reversed(shape))
return shape, strides
def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
elem_block = elem_bytes * 8 // elem_width
if buf.dtype != dtype:
raise RuntimeError("Expect buffer type to be %s instead of %s" %
(dtype, buf.dtype))
shape, strides = buf.shape, buf.strides
if not util.equal_const_int(buf.elem_offset % elem_block, 0):
raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
if allow_fold:
shape, strides = _fold_buffer_dim(buf, scope, elem_block)
else:
shape = list(shape)
strides = list(strides)
def raise_error():
"""Internal function to raise error """
raise RuntimeError(
("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
" shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))
ndim = len(shape)
# Check if the inner-tensor is already flat
flat = util.equal_const_int(shape[-1], elem_block)
if flat:
if not util.equal_const_int(strides[-1], 1):
raise_error()
if ndim == 1:
x_size = 1
x_stride = 1
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-2] - elem_block, 0):
raise_error()
if ndim == 2:
x_size = shape[-2]
x_stride = shape[-2]
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-3] % elem_block, 0):
raise_error()
if ndim == 3:
x_size = shape[-2]
x_stride = strides[-3] / elem_block
y_size = shape[-3]
return x_size, y_size, x_stride, buf.elem_offset / elem_block
else:
if not util.equal_const_int(strides[-1], 1):
raise_error()
if not util.equal_const_int(strides[-2] - shape[-1], 0):
raise_error()
if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
raise_error()
if ndim == 2:
x_size = 1
x_stride = 1
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-3], elem_block):
raise_error()
if ndim == 3:
x_size = shape[-3]
x_stride = shape[-3]
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-4] % elem_block, 0):
raise_error()
if ndim == 4:
x_size = shape[-3]
x_stride = strides[-4] / elem_block
y_size = shape[-4]
return x_size, y_size, x_stride, buf.elem_offset / elem_block
raise_error()
def _inject_copy(src, dst, pad_before, pad_after, pad_value):
# FIXME: pad_value is ignored...
if dst.scope == "global":
# Store
if pad_before or pad_after:
raise RuntimeError("Do not support copy into DRAM with pad")
if src.scope == SCOPE_OUT:
elem_width = spec.VTA_INP_WIDTH # output compression to inp type
elem_bytes = spec.VTA_INP_ELEM_BYTES # output compression to inp type
mem_type = spec.VTA_MEM_ID_OUT
data_type = "int%d" % spec.VTA_INP_WIDTH
task_qid = spec.VTA_QID_STORE_OUT
else:
raise RuntimeError("Do not support copy %s->dram" % (src.scope))
_check_compact(src)
x_size, y_size, x_stride, offset = _get_2d_pattern(
dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid))
irb.emit(tvm.call_extern(
"int32", "VTAStoreBuffer2D", CB_HANDLE,
src.access_ptr("r", "int32"),
mem_type, dst.data, offset, x_size, y_size, x_stride))
return irb.get()
elif src.scope == "global":
if dst.scope == SCOPE_OUT:
elem_width = spec.VTA_OUT_WIDTH
elem_bytes = spec.VTA_OUT_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_ACC
data_type = "int%d" % spec.VTA_OUT_WIDTH
task_qid = spec.VTA_QID_LOAD_OUT
elif dst.scope == SCOPE_INP:
elem_width = spec.VTA_INP_WIDTH
elem_bytes = spec.VTA_INP_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_INP
data_type = "int%d" % spec.VTA_INP_WIDTH
task_qid = spec.VTA_QID_LOAD_INP
elif dst.scope == SCOPE_WGT:
elem_width = spec.VTA_WGT_WIDTH
elem_bytes = spec.VTA_WGT_ELEM_BYTES
mem_type = spec.VTA_MEM_ID_WGT
data_type = "int%d" % spec.VTA_WGT_WIDTH
task_qid = spec.VTA_QID_LOAD_WGT
else:
raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
# collect pad statistics
if pad_before:
assert pad_after
ndim = len(pad_before)
if ndim <= 2 or ndim > 4:
raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
if ndim > 2:
if not util.equal_const_int(pad_before[ndim - 1], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[ndim - 1], 0):
raise ValueError("Do not support pad on the innermost block")
if ndim > 3:
if not util.equal_const_int(pad_before[ndim - 2], 0):
raise ValueError("Do not support pad on the innermost block")
if not util.equal_const_int(pad_after[ndim - 2], 0):
raise ValueError("Do not support pad on the innermost block")
y_pad_before = pad_before[0]
x_pad_before = pad_before[1]
y_pad_after = pad_after[0]
x_pad_after = pad_after[1]
allow_fold = False
else:
x_pad_before = 0
y_pad_before = 0
x_pad_after = 0
y_pad_after = 0
allow_fold = True
_check_compact(dst)
x_size, y_size, x_stride, offset = _get_2d_pattern(
src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold)
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(task_qid))
irb.emit(tvm.call_extern(
"int32", "VTALoadBuffer2D", CB_HANDLE,
src.data, offset, x_size, y_size, x_stride,
x_pad_before, y_pad_before,
x_pad_after, y_pad_after,
dst.access_ptr("r", "int32"), mem_type))
return irb.get()
else:
raise RuntimeError("Donot support copy %s->%s" % (src.scope, dst.scope))
return tvm.ir_pass.InjectCopyIntrin(stmt_in, DMA_COPY, _inject_copy)
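# Worked example of the 2D pattern detection (illustrative numbers): with
# elem_block = 16, a compact buffer of shape (8, 4, 16) and strides
# (64, 16, 1) already has a flat inner block, so _get_2d_pattern returns
# x_size = 4, y_size = 8, x_stride = 64 / 16 = 4 and
# offset = elem_offset / 16 -- i.e. 8 rows of 4 contiguous 16-element
# blocks, which map onto a single VTALoadBuffer2D call.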
def annotate_alu_coproc_scope(stmt_in):
"""Pass to insert ALU instruction.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
irb = tvm.ir_builder.create()
irb.scope_attr(VTA_AXIS, "coproc_scope", get_task_qid(spec.VTA_QID_COMPUTE))
irb.scope_attr(VTA_AXIS, "coproc_uop_scope", tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"):
return tvm.make.Evaluate(0)
return stmt
stmt_out = tvm.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
def inject_alu_intrin(stmt_in):
"""Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Transformed statement
"""
def _do_fold(stmt):
def _equal(x, y):
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
def _flatten_loop(src_coeff, dst_coeff, extents):
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
extents = list(extents)
rev_src_coeff = [src_coeff.pop()]
rev_dst_coeff = [dst_coeff.pop()]
rev_extents = []
assert src_coeff
vsrc = src_coeff.pop()
vdst = dst_coeff.pop()
vext = extents.pop()
while src_coeff:
next_src = src_coeff.pop()
next_dst = dst_coeff.pop()
next_ext = extents.pop()
if (_equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext)):
vext = tvm.ir_pass.Simplify(vext * next_ext)
else:
rev_src_coeff.append(vsrc)
rev_dst_coeff.append(vdst)
rev_extents.append(vext)
vsrc = next_src
vdst = next_dst
vext = next_ext
rev_src_coeff.append(vsrc)
rev_dst_coeff.append(vdst)
rev_extents.append(vext)
rev_src_coeff.reverse()
rev_dst_coeff.reverse()
rev_extents.reverse()
return rev_src_coeff, rev_dst_coeff, rev_extents
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
# Get to the innermost loop body
loop_body = stmt.body
nest_size = 0
while isinstance(loop_body, tvm.stmt.For):
loop_body = loop_body.body
nest_size += 1
# Get the src/dst arguments
dst_var = loop_body.buffer_var
dst_idx = loop_body.index
# Derive loop variables and extents
tmp_body = stmt.body
indices = []
extents = []
for _ in range(nest_size):
indices.append(tmp_body.loop_var)
extents.append(tmp_body.extent)
tmp_body = tmp_body.body
# Derive opcode
alu_opcode = spec.VTA_ALU_OPCODE_UNSET
if isinstance(loop_body.value, tvm.expr.Add):
alu_opcode = spec.VTA_ALU_OPCODE_ADD
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Sub):
alu_opcode = spec.VTA_ALU_OPCODE_SUB
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Mul):
alu_opcode = spec.VTA_ALU_OPCODE_MUL
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Min):
alu_opcode = spec.VTA_ALU_OPCODE_MIN
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Max):
alu_opcode = spec.VTA_ALU_OPCODE_MAX
lhs = loop_body.value.a
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Call):
if loop_body.value.name == 'shift_left':
alu_opcode = spec.VTA_ALU_OPCODE_SHL
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
elif loop_body.value.name == 'shift_right':
alu_opcode = spec.VTA_ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
else:
raise RuntimeError(
"Function call not recognized %s" % (loop_body.value.name))
else:
raise RuntimeError("Expression not recognized %s" % (type(loop_body.value)))
# Derive array index coefficients
dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
# Check if lhs/rhs is immediate
use_imm = False
imm_val = None
if isinstance(rhs, tvm.expr.IntImm):
assert lhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
use_imm = True
imm_val = rhs
if isinstance(lhs, tvm.expr.IntImm):
assert rhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
use_imm = True
imm_val = lhs
if imm_val is None:
imm_val = 0
assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
# Determine which side has the same coefficients
lhs_equal = True
rhs_equal = True
for i, coef in enumerate(dst_coeff):
if not tvm.ir_pass.Equal(coef, src_lhs_coeff[i]):
lhs_equal = False
if not tvm.ir_pass.Equal(coef, src_rhs_coeff[i]):
rhs_equal = False
# Make sure at least one of the source is identical to the
# destination (in-place computation)
assert lhs_equal or rhs_equal
# Assign the source coefficients
if lhs_equal:
src_coeff = src_rhs_coeff
else:
src_coeff = src_lhs_coeff
# Ensure that we have the proper tensor dimensions in the
# innermost loop (pattern match)
src_coeff = list(src_coeff)
dst_coeff = list(dst_coeff)
extents = list(extents)
assert len(src_coeff) > 1
assert len(dst_coeff) > 1
assert len(extents) > 0
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
src_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
dst_coeff[-1]%(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
if spec.VTA_BATCH > 1:
assert len(src_coeff) > 2
assert len(dst_coeff) > 2
assert len(extents) > 1
assert tvm.ir_pass.Equal(src_coeff[-3], spec.VTA_BLOCK_OUT)
assert tvm.ir_pass.Equal(dst_coeff[-3], spec.VTA_BLOCK_OUT)
# Apply tensorization of the loop coefficients
src_offset = src_coeff[-1]
dst_offset = dst_coeff[-1]
if spec.VTA_BATCH == 1:
src_coeff = src_coeff[:-2]
dst_coeff = dst_coeff[:-2]
extents = extents[:-1]
else:
src_coeff = src_coeff[:-3]
dst_coeff = dst_coeff[:-3]
extents = extents[:-2]
src_coeff.append(src_offset)
dst_coeff.append(dst_offset)
src_coeff = [
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in src_coeff]
dst_coeff = [
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
# Insert ALU micro-ops
irb = tvm.ir_builder.create()
for idx, extent in enumerate(extents):
irb.emit(tvm.call_extern(
"int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx]))
irb.emit(tvm.call_extern(
"int32", "VTAUopPush",
1, 0,
dst_coeff[len(dst_coeff)-1],
src_coeff[len(src_coeff)-1],
0,
alu_opcode, use_imm, imm_val))
for extent in extents:
irb.emit(tvm.call_extern(
"int32", "VTAUopLoopEnd"))
return irb.get()
return stmt
stmt_out = tvm.ir_pass.IRTransform(
stmt_in, None, _do_fold, ["AttrStmt"])
return stmt_out
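# Worked example (assuming VTA_BATCH = 1, VTA_BLOCK_OUT = 16): the tagged loop
#
#   for i in range(8):
#       for j in range(16):
#           dst[i*16 + j] = dst[i*16 + j] + 1
#
# yields index coefficients [16, 1, 0] from DetectLinearEquation; the inner
# 16-lane block is absorbed by the hardware vector, the coefficients are
# divided by VTA_BATCH*VTA_BLOCK_OUT, and the pass emits
#
#   VTAUopLoopBegin(8, 1, 1); VTAUopPush(1, 0, 0, 0, 0, ADD, 1, 1); VTAUopLoopEnd()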
def debug_print(stmt):
"""Debug pass: print the statement and return it unchanged."""
print(stmt)
return stmt
"""Mock interface for skip part of compute """
from .intrin import intrin_gevm, intrin_gemm
GEMM = intrin_gemm(True)
GEVM = intrin_gevm(True)
DMA_COPY = "skip_dma_copy"
ALU = "skip_alu"
"""Runtime function related hooks"""
from __future__ import absolute_import as _abs
import tvm
def thread_local_command_buffer():
"""Get thread local command buffer"""
ctx = tvm.call_extern("handle", "VTATLSCommandHandle")
return tvm.make.Call(
"handle", "tvm_thread_context", [ctx], tvm.expr.Call.Intrinsic, None, 0)
CB_HANDLE = thread_local_command_buffer()
VTA_AXIS = tvm.thread_axis("vta")
VTA_PUSH_UOP = tvm.make.StringImm("VTAPushGEMMOp")
SCOPE_INP = "local.inp_buffer"
SCOPE_OUT = "local.out_buffer"
SCOPE_WGT = "local.wgt_buffer"
DMA_COPY = "dma_copy"
ALU = "alu"
DEBUG_NO_SYNC = False
def get_task_qid(qid):
"""Get transformed queue index."""
return 1 if DEBUG_NO_SYNC else qid
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_sync")
def coproc_sync(op):
return tvm.call_extern(
"int32", "VTASynchronize", CB_HANDLE, 1<<31)
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush", CB_HANDLE, op.args[0], op.args[1])
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
"int32", "VTADepPop", CB_HANDLE, op.args[0], op.args[1])
"""Namespace for supporting packed_conv2d + ewise variant of nnvm."""
from collections import namedtuple
import logging
import tvm
import topi
from nnvm.top import registry as reg, OpPattern
from . import intrin, runtime as vta
from .intrin import GEVM
TARGET_BOARD = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
def packed_conv2d(data,
kernel,
padding,
strides,
out_dtype="int32"):
""" Packed conv2d function.
"""
if padding[0]:
pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0], name="pad_data")
else:
pad_data = data
assert len(data.shape) == 5
assert len(kernel.shape) == 6
oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1)
owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1)
oshape = (data.shape[0], kernel.shape[0], oheight, owidth, kernel.shape[4])
ishape = topi.util.get_const_tuple(data.shape)
kshape = topi.util.get_const_tuple(kernel.shape)
assert data.dtype == "int8", data.dtype
assert kernel.dtype == "int8", kernel.dtype
d_i = tvm.reduce_axis((0, kshape[2]), name='d_i')
d_j = tvm.reduce_axis((0, kshape[3]), name='d_j')
k_o = tvm.reduce_axis((0, ishape[1]), name='k_o')
k_i = tvm.reduce_axis((0, ishape[-1]), name='k_i')
hstride, wstride = strides
res = tvm.compute(
oshape,
lambda b, co, i, j, ci: tvm.sum(
pad_data[b, k_o, i*hstride+d_i, j*wstride+d_j, k_i].astype(out_dtype) *
kernel[co, k_o, d_i, d_j, ci, k_i].astype(out_dtype),
axis=[k_o, d_i, d_j, k_i]),
name="res", tag="packed_conv2d")
return res
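# Usage sketch (hypothetical shapes, assuming VTA_BLOCK_IN = VTA_BLOCK_OUT = 16):
# data is packed NCHWc and kernel is packed OIHWoi, so a 56x56 conv with
# 64 -> 64 channels, 3x3 kernel, pad 1 and stride 1 looks like
#
#   data = tvm.placeholder((1, 4, 56, 56, 16), dtype="int8", name="data")
#   kernel = tvm.placeholder((4, 4, 3, 3, 16, 16), dtype="int8", name="kernel")
#   res = packed_conv2d(data, kernel, padding=(1, 1), strides=(1, 1))
#
# giving a (1, 4, 56, 56, 16) int32 result tagged "packed_conv2d".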
@tvm.register_func("nnvm.compiler.build_target", override=True)
def _build(funcs, target, target_host):
tvm_t = tvm.target.create(target)
if tvm_t.device_name == "vta":
return tvm.build(funcs, target="ext_dev",
target_host=TARGET_BOARD)
elif tvm_t.device_name == "rasp" or tvm_t.device_name == "tcpu":
return tvm.build(funcs, target=TARGET_BOARD)
return tvm.build(funcs, target=target)
@tvm.register_func("nnvm.compiler.lower", override=True)
def _lower(sch, inputs, func_name, graph):
import traceback
# pylint: disable=broad-except
try:
f = tvm.lower(sch, inputs, name=func_name)
if "quantized_conv2d" in func_name:
logging.info(graph.ir(join_entry_attrs=["shape"]))
except Exception:
msg = traceback.format_exc()
msg += "Error during compile graph\n"
msg += "--------------------------\n"
msg += graph.ir(join_entry_attrs=["shape"])
raise RuntimeError(msg)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]
@reg.register_compute("clip", level=11)
def compute_clip(attrs, inputs, _):
""" Clip operator.
"""
x = inputs[0]
a_min = attrs.get_float("a_min")
a_max = attrs.get_float("a_max")
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
with tvm.tag_scope(topi.tag.ELEMWISE):
x = tvm.compute(
x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
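# Splitting clip into separate min and max stages matters for VTA: each stage
# is a single immediate ALU operation (MIN then MAX), which the alu pragma and
# inject_alu_intrin can map one-to-one onto VTAUopPush instructions.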
reg.register_pattern("identity", OpPattern.INJECTIVE, level=11)
@reg.register_compute("quantized_conv2d", level=11)
def compute_quantized_conv2d(attrs, inputs, out):
""" 2D convolution algorithm.
"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs["layout"]
out_dtype = attrs['out_type']
cmp_dtype = 'int32' # compute data type
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert attrs.get_bool("use_bias") is False
pack_channel = attrs.get_int("pack_channel")
if pack_channel != 0:
assert groups == 1
return packed_conv2d(inputs[0], inputs[1],
padding, strides)
if groups == 1:
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype)
elif groups == topi.util.get_const_int(inputs[0].shape[1]) and groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype)
else:
raise ValueError("not support arbitrary group number for now")
assert out_dtype == cmp_dtype
return out
@reg.register_schedule("quantized_conv2d", level=11)
def schedule_quantized_conv2d(attrs, outs, target):
""" 2D convolution schedule.
"""
channels = attrs.get_int("channels")
pack_channel = attrs.get_int("pack_channel")
if channels != 0 and pack_channel:
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d(outs)
elif target.startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
with tvm.target.create(target):
return topi.generic.schedule_conv2d_nchw(outs)
def _get_workload(data, pad_data, kernel, output):
""" Get the workload structure.
"""
o_shape = topi.util.get_const_tuple(output.shape)
d_shape = topi.util.get_const_tuple(data.shape)
k_shape = topi.util.get_const_tuple(kernel.shape)
o_b, o_c, o_h, o_w, o_blk = o_shape
i_b, i_c, i_h, i_w, i_blk = d_shape
k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape
# For now we need to assume that input channel blocking is the same
# as the output channel blocking
assert o_blk == i_blk
# Make sure that dimensions match
assert o_b == i_b
assert o_blk == ko_blk
assert i_blk == ki_blk
assert k_o == o_c
assert k_i == i_c
# Scale the channel size
i_c *= i_blk
o_c *= o_blk
if pad_data is not None:
p_shape = topi.util.get_const_tuple(pad_data.shape)
h_pad = (p_shape[2] - d_shape[2]) // 2
w_pad = (p_shape[3] - d_shape[3]) // 2
else:
h_pad, w_pad = 0, 0
h_str = (i_h + h_pad*2 - k_h) // (o_h - 1)
w_str = (i_w + w_pad*2 - k_w) // (o_w - 1)
return Workload(i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str)
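# Worked example (illustrative): for RESNET[1] below -- 56x56 input, 64 -> 64
# channels, 3x3 kernel, pad 1, stride 1 -- a packed output of shape
# (1, 4, 56, 56, 16) and padded data of shape (1, 4, 58, 58, 16) give
# h_pad = (58 - 56) // 2 = 1 and h_str = (56 + 2 - 3) // (56 - 1) = 1,
# recovering Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1).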
_WL2PLAN = {}
def schedule_packed_conv2d(outs):
""" Schedule the packed conv2d.
"""
assert len(outs) == 1
output = outs[0]
ewise_inputs = []
ewise_ops = []
conv2d_res = []
assert output.dtype == "int8"
assert output.op.input_tensors[0].dtype == "int32"
def _traverse(op):
if topi.tag.is_broadcast(op.tag):
if not op.same_as(output.op):
ewise_ops.append(op)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.PlaceholderOp):
ewise_inputs.append((op, tensor))
else:
_traverse(tensor.op)
else:
assert op.tag == "packed_conv2d"
conv2d_res.append(op)
_traverse(output.op)
assert len(conv2d_res) == 1
conv2d_stage = conv2d_res[0].output(0)
data, kernel = conv2d_stage.op.input_tensors
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
temp = data.op.input_tensors[0]
pad_data = data
data = temp
else:
pad_data = None
wrkld = _get_workload(data, pad_data, kernel, output)
plan = _WL2PLAN[wrkld]
load_inp = load_wgt = load_out = store_out = "dma_copy"
alu = "alu"
gevm = GEVM
# schedule1
oshape = topi.util.get_const_tuple(output.shape)
s = tvm.create_schedule(output.op)
# setup pad
if pad_data is not None:
cdata = pad_data
s[pad_data].set_scope(vta.SCOPE_INP)
else:
cdata = s.cache_read(data, vta.SCOPE_INP, [conv2d_stage])
ckernel = s.cache_read(kernel, vta.SCOPE_WGT, [conv2d_stage])
s[conv2d_stage].set_scope(vta.SCOPE_OUT)
# cache read input
cache_read_ewise = []
for consumer, tensor in ewise_inputs:
cache_read_ewise.append(
s.cache_read(tensor, vta.SCOPE_OUT, [consumer]))
# set ewise scope
for op in ewise_ops:
s[op].set_scope(vta.SCOPE_OUT)
s[op].pragma(s[op].op.axis[0], alu)
# tile
oc_factor = (plan.oc_factor if plan.oc_factor
else wrkld.out_filter // vta.VTA_BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3])
x_b, x_oc, x_i, x_j, x_ic = s[output].op.axis
x_oc0, x_oc1 = s[output].split(x_oc, factor=oc_factor)
x_i0, x_i1 = s[output].split(x_i, factor=h_factor)
x_j0, x_j1 = s[output].split(x_j, factor=w_factor)
s[output].reorder(x_b, x_oc0, x_i0, x_j0, x_oc1, x_i1, x_j1, x_ic)
store_pt = x_j0
# set all compute scopes
s[conv2d_stage].compute_at(s[output], store_pt)
for op in ewise_ops:
s[op].compute_at(s[output], store_pt)
for tensor in cache_read_ewise:
s[tensor].compute_at(s[output], store_pt)
s[tensor].pragma(s[tensor].op.axis[0], load_out)
# virtual threading along output channel axes
if plan.oc_nthread:
_, v_t = s[output].split(x_oc0, factor=plan.oc_nthread)
s[output].reorder(v_t, x_b)
s[output].bind(v_t, tvm.thread_axis("cthread"))
# virtual threading along spatial rows
if plan.h_nthread:
_, v_t = s[output].split(x_i0, factor=plan.h_nthread)
s[output].reorder(v_t, x_b)
s[output].bind(v_t, tvm.thread_axis("cthread"))
x_b, x_oc, x_i, x_j, x_ic = s[conv2d_stage].op.axis
k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis
s[conv2d_stage].reorder(k_o, x_j, d_j, d_i, x_oc, x_i, x_ic, k_i)
if plan.ko_factor:
k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ko_factor)
s[cdata].compute_at(s[conv2d_stage], k_o)
s[ckernel].compute_at(s[conv2d_stage], k_o)
# Use VTA instructions
s[cdata].pragma(s[cdata].op.axis[0], load_inp)
s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt)
s[conv2d_stage].tensorize(x_ic, gevm)
s[output].pragma(x_oc1, store_out)
return s
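# The schedule above follows the standard VTA recipe: cache-read data and
# kernel into the on-chip inp/wgt scopes, keep the conv and element-wise ops
# in the out scope, tile so one (oc, h, w) block fits on chip, optionally
# bind a "cthread" virtual thread for latency hiding, then tensorize the
# innermost channel axis with GEVM and tag the copies with dma_copy pragmas.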
class Conv2DSchedule(object):
""" 2D convolution schedule object.
"""
def __init__(self,
oc_factor,
ko_factor=1,
h_factor=1,
w_factor=0,
oc_nthread=0,
h_nthread=0):
self.oc_factor = oc_factor
self.ko_factor = ko_factor
self.h_factor = h_factor
self.w_factor = w_factor
self.oc_nthread = oc_nthread
self.h_nthread = h_nthread
Schedule = Conv2DSchedule
# ResNet18 workloads
RESNET = {
# Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
# Serial schedule
RESNET_SERIAL = {
RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56),
RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0),
RESNET[2]: Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0),
RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0),
RESNET[4]: Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0),
RESNET[5]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
RESNET[6]: Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0),
RESNET[7]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0),
RESNET[8]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
RESNET[9]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
RESNET[10]: Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0),
RESNET[11]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0),
}
# Latency hiding schedule
RESNET_OPT = {
RESNET[0]: Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56),
RESNET[1]: Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2),
RESNET[2]: Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2),
RESNET[3]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2),
RESNET[4]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2),
RESNET[5]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2),
RESNET[6]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
RESNET[7]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
RESNET[8]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
RESNET[9]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
RESNET[10]: Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
RESNET[11]: Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2),
}
_WL2PLAN = RESNET_OPT
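# Note: schedule_packed_conv2d looks plans up by exact Workload key, so only
# the ResNet18 shapes above are schedulable; swap in RESNET_SERIAL here to
# profile without latency hiding.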
......@@ -1043,9 +1043,9 @@ class CommandQueue {
VTAWriteMappedReg(vta_load_handle_, 0x10, 0);
// LOAD @ 0x18 : Data signal of weight_V
VTAWriteMappedReg(vta_load_handle_, 0x18, 0);
// COMPUTE @ 0x10 : Data signal of uops_V
// COMPUTE @ 0x20 : Data signal of uops_V
VTAWriteMappedReg(vta_compute_handle_, 0x20, 0);
// COMPUTE @ 0x18 : Data signal of biases_V
// COMPUTE @ 0x28 : Data signal of biases_V
VTAWriteMappedReg(vta_compute_handle_, 0x28, 0);
// STORE @ 0x10 : Data signal of outputs_V
VTAWriteMappedReg(vta_store_handle_, 0x10, 0);
......
......@@ -39,7 +39,7 @@ uint64_t vta(
#else // NO_SIM
#include "../../../hardware/vivado/src/vta.h"
#include "../../../hardware/xilinx/src/vta.h"
#endif // NO_SIM
......
CC ?= g++
CFLAGS = -Wall -O3 -std=c++11 -I/usr/include
LDFLAGS = -L/usr/lib -L/home/xilinx/pynq/drivers
LDFLAGS = -L/usr/lib -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
LIBS = -l:libsds_lib.so -l:libdma.so
INCLUDE_DIR = ../../../include
DRIVER_DIR = ../../../src/pynq
......
import os
import tvm
import mxnet as mx
import vta
import numpy as np
import topi
from collections import namedtuple
from tvm.contrib import rpc, util
import pandas as pd
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
class Conv2DSchedule(object):
def __init__(self,
oc_factor,
ko_factor=1,
h_factor=1,
w_factor=0,
oc_nthread=0,
h_nthread=0,
debug_sync=False):
self.oc_factor = oc_factor
self.ko_factor = ko_factor
self.h_factor = h_factor
self.w_factor = w_factor
self.oc_nthread = oc_nthread
self.h_nthread = h_nthread
self.debug_sync = debug_sync
Schedule = Conv2DSchedule
def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
assert batch_size % vta.VTA_BATCH == 0
assert wl.in_filter % vta.VTA_BLOCK_IN == 0
assert wl.out_filter % vta.VTA_BLOCK_OUT == 0
data_shape = (batch_size//vta.VTA_BATCH, wl.in_filter//vta.VTA_BLOCK_IN,
wl.height, wl.width, vta.VTA_BATCH, vta.VTA_BLOCK_IN)
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN,
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
res_shape = (batch_size//vta.VTA_BATCH, wl.out_filter//vta.VTA_BLOCK_OUT,
fout_height, fout_width, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype)
if wl.hpad or wl.wpad:
data_buf = topi.nn.pad(data, [0, 0, wl.hpad, wl.wpad, 0, 0], name="data_buf")
else:
data_buf = tvm.compute(data_shape, lambda *i: data(*i), "data_buf")
kernel_buf = tvm.compute(kernel_shape, lambda *i: kernel(*i), "kernel_buf")
di = tvm.reduce_axis((0, wl.hkernel), name='di')
dj = tvm.reduce_axis((0, wl.wkernel), name='dj')
ko = tvm.reduce_axis((0, wl.in_filter//vta.VTA_BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki')
res_cnv = tvm.compute(
res_shape,
lambda bo, co, i, j, bi, ci: tvm.sum(
data_buf[bo, ko, i*wl.hstride+di, j*wl.wstride+dj, bi, ki].astype(out_dtype) *
kernel_buf[co, ko, di, dj, ci, ki].astype(out_dtype),
axis=[ko, di, dj, ki]),
name="res_cnv")
res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res")
num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o")
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig = np.random.randint(
-128, 128, size=(batch_size, wl.in_filter, wl.height, wl.width)).astype(data.dtype)
kernel_orig = np.random.randint(
-128, 128, size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)).astype(kernel.dtype)
data_packed = data_orig.reshape(
batch_size//vta.VTA_BATCH, vta.VTA_BATCH,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3))
kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
kernel_arr = tvm.nd.array(kernel_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=10)
cost = time_f(data_arr, kernel_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose(
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter,
no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype)
res_ref = np.right_shift(res_ref, 8).astype(res.dtype)
np.testing.assert_allclose(res_unpack, res_ref)
print("Correctness check pass...")
return cost
def run_schedule(load_inp, load_wgt, gemm, alu, store_out,
print_ir, check_correctness):
# schedule1
s = tvm.create_schedule(res.op)
s[data_buf].set_scope(vta.SCOPE_INP)
s[kernel_buf].set_scope(vta.SCOPE_WGT)
s[res_cnv].set_scope(vta.SCOPE_OUT)
s[res_shf].set_scope(vta.SCOPE_OUT)
# tile
oc_factor = (plan.oc_factor if plan.oc_factor
else wl.out_filter // vta.VTA_BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else fout_height)
w_factor = (plan.w_factor if plan.w_factor else fout_width)
xbo, xco, xi, xj, xbi, xci = s[res].op.axis
xco0, xco1 = s[res].split(xco, factor=oc_factor)
xi0, xi1 = s[res].split(xi, factor=h_factor)
xj0, xj1 = s[res].split(xj, factor=w_factor)
s[res].reorder(xbo, xi0, xco0, xj0, xco1, xi1, xj1, xbi, xci)
s[res_cnv].compute_at(s[res], xj0)
s[res_shf].compute_at(s[res], xj0)
if plan.oc_nthread:
_, tx = s[res].split(xco0, factor=plan.oc_nthread)
s[res].reorder(tx, xbo)
s[res].bind(tx, tvm.thread_axis("cthread"))
if plan.h_nthread:
xo, tx = s[res].split(xi0, factor=plan.h_nthread)
s[res].reorder(tx, xbo)
s[res].bind(tx, tvm.thread_axis("cthread"))
xbo, xco, xi, xj, xbi, xci = s[res_cnv].op.axis
s[res_cnv].reorder(xbo, ko, xj, dj, di, xco, xi, xbi, xci, ki)
if plan.ko_factor:
ko0, ko1 = s[res_cnv].split(ko, factor=plan.ko_factor)
s[data_buf].compute_at(s[res_cnv], ko0)
s[kernel_buf].compute_at(s[res_cnv], ko0)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[kernel_buf].pragma(s[kernel_buf].op.axis[0], load_wgt)
s[res_cnv].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res].pragma(xco1, store_out)
if plan.debug_sync:
s[res].pragma(xco0, "coproc_sync")
if print_ir:
print(tvm.lower(s, [data, kernel, res], simple_mode=True))
return verify(s, check_correctness)
def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, vta.ALU, vta.DMA_COPY,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["key"].append(key)
log_frame["total-gops"].append(gops)
log_frame["total-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir, True)
def skip_alu_unittest(print_ir):
mock = vta.mock
print("----- Skip ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, mock.ALU, vta.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["skip-alu-gops"].append(gops)
log_frame["skip-alu-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def gemm_unittest(print_ir):
mock = vta.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY,
vta.GEMM, mock.ALU, mock.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["gemm-gops"].append(gops)
log_frame["gemm-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def alu_unittest(print_ir):
mock = vta.mock
print("----- ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY,
mock.GEMM, vta.ALU, mock.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["alu-gops"].append(gops)
log_frame["alu-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def load_inp_unittest(print_ir):
mock = vta.mock
print("----- LoadInp Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, mock.DMA_COPY,
mock.GEMM, mock.ALU, mock.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwidth = (batch_size * wl.in_filter * wl.height *
wl.width * vta.VTA_INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwidth))
log_frame["ld-inp-gbits"].append(bandwidth)
log_frame["ld-inp-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def load_wgt_unittest(print_ir):
mock = vta.mock
print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, vta.DMA_COPY,
mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir,
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwidth = (wl.out_filter * wl.in_filter * wl.hkernel *
wl.wkernel * vta.VTA_WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwidth))
log_frame["ld-wgt-gbits"].append(bandwidth)
log_frame["ld-wgt-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def store_out_unittest(print_ir):
mock = vta.mock
print("----- StoreOut Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY,
mock.GEMM, mock.ALU, vta.DMA_COPY, print_ir,
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwidth = (batch_size * wl.out_filter * fout_height *
fout_width * vta.VTA_OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwidth))
log_frame["st-out-gbits"].append(bandwidth)
log_frame["st-out-cost"].append(cost.mean)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def manual_unittest(print_ir):
# Manual section used to tweak the components
mock = vta.mock
print("----- Manual Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, vta.ALU, mock.DMA_COPY, print_ir,
False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (
cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
print("=================================")
print("key=%s" % key)
print(wl)
conv_normal(False)
if not profile:
return
skip_alu_unittest(False)
gemm_unittest(False)
alu_unittest(False)
load_inp_unittest(False)
load_wgt_unittest(False)
store_out_unittest(False)
# ResNet18 workloads
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
# List of simple benchmarks
simple = [
Workload(height=22, width=22, in_filter=256, out_filter=64,
hkernel=3, wkernel=3, hpad=1, wpad=1, hstride=1, wstride=1)
]
# Serial schedule
resnet_serial = [
[None, None],
[resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[2], Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0)],
[resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[4], Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0)],
[resnet[5], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[6], Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[7], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[8], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[9], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[10], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[11], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
]
# SMT schedule
resnet_smt = [
[resnet[0], Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56)],
[resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2)],
[resnet[2], Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2)],
[resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)],
[resnet[4], Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2)],
[resnet[5], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)],
[resnet[6], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[7], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[8], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[9], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[10], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[11], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
]
# Whether to perform profiling
profile = False
# Whether to use SMT
use_smt = True
# Data set batch size
batch_size = 1
resnet_schedule = resnet_smt if use_smt else resnet_serial
begin = 0
end = len(resnet_schedule)
keys = ["key", "total-gops", "total-cost",
"skip-alu-gops", "skip-alu-cost",
"gemm-gops", "gemm-cost", "alu-gops", "alu-cost",
"ld-inp-cost", "ld-wgt-cost", "st-out-cost",
"ld-inp-gbits", "ld-wgt-gbits", "st-out-gbits",]
log_frame = {
k : [] for k in keys
}
for i, x in enumerate(resnet_schedule[begin:end]):
wl, plan = x
if not wl:
continue
key = "resnet-cfg[%d]" % i
test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile)
if profile:
pd.set_option('expand_frame_repr', False)
log_df = pd.DataFrame()
for k in keys:
log_df[k] = log_frame[k]
print(log_df)
import os
import tvm
import vta
import numpy as np
import time
from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
def test_gemm_packed(batch_size, channel, block):
data_shape = (batch_size//vta.VTA_BATCH,
channel//vta.VTA_BLOCK_IN,
vta.VTA_BATCH,
vta.VTA_BLOCK_IN)
weight_shape = (channel//vta.VTA_BLOCK_OUT,
channel//vta.VTA_BLOCK_IN,
vta.VTA_BLOCK_OUT,
vta.VTA_BLOCK_IN)
res_shape = (batch_size//vta.VTA_BATCH,
channel//vta.VTA_BLOCK_OUT,
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT)
num_ops = channel * channel * batch_size
ko = tvm.reduce_axis((0, channel//vta.VTA_BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name='ki')
data = tvm.placeholder(data_shape,
name="data",
dtype=inp_dtype)
weight = tvm.placeholder(weight_shape,
name="weight",
dtype=wgt_dtype)
data_buf = tvm.compute(data_shape,
lambda *i: data(*i),
"data_buf")
weight_buf = tvm.compute(weight_shape,
lambda *i: weight(*i),
"weight_buf")
res_gem = tvm.compute(res_shape,
lambda bo, co, bi, ci: tvm.sum(
data_buf[bo, ko, bi, ki].astype(out_dtype) *
weight_buf[co, ko, ci, ki].astype(out_dtype),
axis=[ko, ki]),
name="res_gem")
res_shf = tvm.compute(res_shape,
lambda *i: res_gem(*i)>>8,
name="res_shf")
res_max = tvm.compute(res_shape,
lambda *i: tvm.max(res_shf(*i), 0),
"res_max") # relu: clamp below at 0
res_min = tvm.compute(res_shape,
lambda *i: tvm.min(res_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1),
"res_min") # saturate to the int8 input range
res = tvm.compute(res_shape,
lambda *i: res_min(*i).astype(inp_dtype),
name="res")
def verify(s, check_correctness=True):
mod = tvm.build(s, [data, weight, res], "ext_dev", target, name="gemm")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig = np.random.randint(
-128, 128, size=(batch_size, channel)).astype(data.dtype)
weight_orig = np.random.randint(
-128, 128, size=(channel, channel)).astype(weight.dtype)
data_packed = data_orig.reshape(
batch_size//vta.VTA_BATCH, vta.VTA_BATCH,
channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_orig.reshape(
channel//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT,
channel//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN).transpose((0, 2, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
weight_arr = tvm.nd.array(weight_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
res_ref = np.zeros(res_shape).astype(out_dtype)
for b in range(batch_size//vta.VTA_BATCH):
for i in range(channel//vta.VTA_BLOCK_OUT):
for j in range(channel//vta.VTA_BLOCK_IN):
res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(out_dtype),
weight_packed[i,j].T.astype(out_dtype))
res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20)
cost = time_f(data_arr, weight_arr, res_arr)
res_unpack = res_arr.asnumpy().reshape(batch_size//vta.VTA_BATCH,
channel//vta.VTA_BLOCK_OUT,
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT)
if check_correctness:
np.testing.assert_allclose(res_unpack, res_ref)
return cost
def run_schedule(load_inp,
load_wgt,
gemm,
alu,
store_out,
print_ir,
check_correctness):
s = tvm.create_schedule(res.op)
s[data_buf].set_scope(vta.SCOPE_INP)
s[weight_buf].set_scope(vta.SCOPE_WGT)
s[res_gem].set_scope(vta.SCOPE_OUT)
s[res_shf].set_scope(vta.SCOPE_OUT)
s[res_min].set_scope(vta.SCOPE_OUT)
s[res_max].set_scope(vta.SCOPE_OUT)
if block:
bblock = block // vta.VTA_BATCH
iblock = block // vta.VTA_BLOCK_IN
oblock = block // vta.VTA_BLOCK_OUT
xbo, xco, xbi, xci = s[res].op.axis
xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
store_pt = xb2
s[res_gem].compute_at(s[res], xco1)
s[res_shf].compute_at(s[res], xco1)
s[res_min].compute_at(s[res], xco1)
s[res_max].compute_at(s[res], xco1)
xbo, xco, xbi, xci = s[res_gem].op.axis
# Compute one line at a time
ko1, ko2 = s[res_gem].split(ko, iblock)
s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)
s[data_buf].compute_at(s[res_gem], ko1)
s[weight_buf].compute_at(s[res_gem], ko1)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
s[res_gem].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res_min].pragma(s[res_min].op.axis[0], alu)
s[res_max].pragma(s[res_max].op.axis[0], alu)
s[res].pragma(store_pt, store_out)
else:
xbo, xco, xbi, xci = s[res_gem].op.axis
s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
s[res_gem].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res_min].pragma(s[res_min].op.axis[0], alu)
s[res_max].pragma(s[res_max].op.axis[0], alu)
s[res].pragma(s[res].op.axis[0], store_out)
if print_ir:
print(tvm.lower(s, [data, weight, res], simple_mode=True))
return verify(s, check_correctness)
def gemm_normal(print_ir):
mock = vta.mock
print("----- GEMM GFLOPS End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, vta.GEMM, vta.ALU, vta.DMA_COPY,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
run_test("NORMAL", print_ir, True)
print("")
def gevm_unittest(print_ir):
mock = vta.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, vta.GEMM, mock.ALU, mock.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def alu_unittest(print_ir):
mock = vta.mock
print("----- ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, vta.ALU, mock.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def load_inp_unittest(print_ir):
mock = vta.mock
print("----- LoadInp Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
vta.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwidth = (batch_size * channel * vta.VTA_INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwidth))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def load_wgt_unittest(print_ir):
mock = vta.mock
print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, vta.DMA_COPY, mock.GEMM, mock.ALU, mock.DMA_COPY, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwidth = (channel * channel * vta.VTA_WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwidth))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
def store_out_unittest(print_ir):
mock = vta.mock
print("----- StoreOut Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.DMA_COPY, mock.DMA_COPY, mock.GEMM, mock.ALU, vta.DMA_COPY,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwidth = (batch_size * channel * vta.VTA_OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwidth))
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir)
print("")
gemm_normal(False)
gevm_unittest(False)
alu_unittest(False)
# FIXME: reported time is too short to be reliable
# load_inp_unittest(False)
# load_wgt_unittest(False)
# store_out_unittest(False)
print("========GEMM 128=========")
test_gemm_packed(128, 128, 128)
# FIXME: hanging run
# print("========GEMM 1024========")
# test_gemm_packed(1024, 1024, 128)
"""Testing if we can generate code in topi style"""
import topi
import tvm
from tvm.contrib import util, rpc
import vta
from vta import vta_conv2d
import numpy as np
import mxnet as mx
Workload = vta_conv2d.Workload
@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
"""Unlike topi's current clip, put min and max into two stages."""
const_min = tvm.const(a_min, x.dtype)
const_max = tvm.const(a_max, x.dtype)
x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
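# Splitting the clip presumably lets each stage map onto a single VTA ALU
# instruction (a MIN against const_max, then a MAX against const_min), whereas
# a fused two-bound clip would not match the single-operand ALU pattern.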
host = "pynq"
port = 9091
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
print_ir = False
def test_vta_conv2d(key, batch_size, wl, profile=True):
data_shape = (batch_size, wl.in_filter//vta.VTA_BLOCK_IN,
wl.height, wl.width, vta.VTA_BLOCK_IN)
kernel_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, wl.in_filter//vta.VTA_BLOCK_IN,
wl.hkernel, wl.wkernel, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)
bias_shape = (wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT)
fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
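# Standard conv output-size formula: out = (in + 2*pad - kernel) // stride + 1.
# Worked example for the 3x3/pad=1/stride=1 ResNet workloads below:
#   fout_height = (56 + 2*1 - 3) // 1 + 1 = 56 (spatial size is preserved).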
data = tvm.placeholder(data_shape, name="data", dtype=inp_dtype)
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=wgt_dtype)
bias = tvm.placeholder(bias_shape, name="kernel", dtype=out_dtype)
res_conv = vta_conv2d.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
res = topi.right_shift(res_conv, 8)
res = topi.broadcast_add(res, bias)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
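# Quantization pipeline applied after the conv: right-shift by 8 requantizes
# the wide accumulator, the bias is added in the shifted domain, and the
# result is clipped to [0, 127] before the narrowing cast to int8.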
num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter  # each multiply-add counts as two ops
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("conv2d.o"))
remote.upload(temp.relpath("conv2d.o"))
f = remote.load_module("conv2d.o")
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig = (np.random.uniform(
size=(batch_size, wl.in_filter, wl.height, wl.width)) * 4).astype(data.dtype)
kernel_orig = (np.random.uniform(
size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)) * 4).astype(kernel.dtype)
bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
data_orig = np.abs(data_orig)
kernel_orig = np.abs(kernel_orig)
bias_orig = np.abs(bias_orig)
data_packed = data_orig.reshape(
batch_size, wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
wl.height, wl.width).transpose((0, 1, 3, 4, 2))
kernel_packed = kernel_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_OUT,
wl.in_filter//vta.VTA_BLOCK_IN, vta.VTA_BLOCK_IN,
wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
bias_packed = bias_orig.reshape(
wl.out_filter//vta.VTA_BLOCK_OUT, 1, 1, vta.VTA_BLOCK_OUT)
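# The packing above splits the channel dimensions into VTA-sized blocks:
# data NCHW -> (N, C//BLOCK_IN, H, W, BLOCK_IN) and kernel OIHW ->
# (O//BLOCK_OUT, I//BLOCK_IN, H, W, BLOCK_OUT, BLOCK_IN), matching the
# placeholder shapes declared at the top of test_vta_conv2d.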
res_shape = topi.util.get_const_tuple(res.shape)
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
kernel_arr = tvm.nd.array(kernel_packed, ctx)
bias_arr = tvm.nd.array(bias_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=10)
cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose(
(0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
if check_correctness:
res_ref = mx.nd.Convolution(
mx.nd.array(data_orig.astype(out_dtype), mx.cpu(0)),
mx.nd.array(kernel_orig.astype(out_dtype), mx.cpu(0)),
stride=(wl.hstride, wl.wstride),
kernel=(wl.hkernel, wl.wkernel),
num_filter=wl.out_filter,
no_bias=True,
pad=(wl.hpad, wl.wpad)).asnumpy().astype(out_dtype)
res_ref = res_ref >> 8
res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
res_ref = np.clip(res_ref, 0, 127).astype("int8")
np.testing.assert_allclose(res_unpack, res_ref)
print("Correctness check pass...")
return cost
def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------")
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
s = vta_conv2d.schedule_packed_conv2d([res])
if print_ir:
print(tvm.lower(s, [data, kernel, bias, res], simple_mode=True))
cost = verify(s, True)
gops = (num_ops / cost.mean) / float(10 ** 9)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
conv_normal(print_ir)
# ResNet18 workloads
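# Assumed Workload field order, inferred from the attribute accesses in
# test_vta_conv2d: Workload(height, width, in_filter, out_filter,
# hkernel, wkernel, hpad, wpad, hstride, wstride).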
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
batch_size = 1
for i in range(0, len(resnet)):
wl = resnet[i]
key = "resnet-cfg[%d]" % i
print "key=%s" % key
print wl
test_vta_conv2d(key, batch_size, wl)
import tvm
import vta
import os
from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
bit = "vta.bit"
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
bitstream = os.path.join(curr_path, "./", bit)
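# Assumes a pre-built vta.bit bitstream sits next to this script; the test
# below fails at upload time if the file is missing.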
def test_program_rpc():
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
remote.upload(bitstream, bit)
fprogram = remote.get_function("tvm.contrib.vta.init")
fprogram(bit)
test_program_rpc()
"""Unit test TPU's instructions """
import tvm
import vta
import mxnet as mx
import numpy as np
import topi
from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
out_dtype = "int%d" % vta.VTA_OUT_WIDTH
inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
do_verify = True
print_ir = False
def test_save_load_out():
"""Test save/store output command"""
n = 4
x = tvm.placeholder(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="x",
dtype=out_dtype)
x_buf = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: x(*i),
"x_buf")
# insert no-op that won't be optimized away
y_buf = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: x_buf(*i)>>0,
"y_buf")
y = tvm.compute(
(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: y_buf(*i).astype(inp_dtype),
"y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_OUT)
s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY)
s[y_buf].set_scope(vta.SCOPE_OUT)
s[y_buf].pragma(y_buf.op.axis[0], vta.ALU)
s[y].pragma(y.op.axis[0], vta.DMA_COPY)
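# Scheduling pattern used throughout these tests: set_scope places a buffer
# in one of VTA's on-chip SRAM scopes, and the pragma marks the copy/compute
# loop so the VTA lowering pass can emit a DMA load/store or ALU instruction
# for it.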
def verify():
# build
m = tvm.build(s, [x, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
m.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(
1, 10, size=(n, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype)
y_np = x_np.astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
def test_padded_load():
"""Test padded load."""
# declare
n = 21
m = 20
pad_before = [0, 1, 0, 0]
pad_after = [1, 3, 0, 0]
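# With these parameters the padded shape works out to
# (21 + 0 + 1, 20 + 1 + 3, VTA_BATCH, VTA_BLOCK_OUT) = (22, 24, ...).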
x = tvm.placeholder(
(n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="x",
dtype=out_dtype)
x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
# insert no-op that won't be optimized away
y_buf = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT), lambda *i: x_buf(*i)>>0, "y_buf")
y = tvm.compute((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT), lambda *i: y_buf(*i).astype(inp_dtype), "y")
# schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_OUT)
s[x_buf].pragma(x_buf.op.axis[0], vta.DMA_COPY)
s[y_buf].set_scope(vta.SCOPE_OUT)
s[y_buf].pragma(y_buf.op.axis[0], vta.ALU)
s[y].pragma(y.op.axis[0], vta.DMA_COPY)
def verify():
# build
mod = tvm.build(s, [x, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("padded_load.o"))
remote.upload(temp.relpath("padded_load.o"))
f = remote.load_module("padded_load.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(1, 2, size=(
n, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(x.dtype)
y_np = np.zeros((n + pad_before[0] + pad_after[0],
m + pad_before[1] + pad_after[1],
vta.VTA_BATCH,
vta.VTA_BLOCK_OUT)).astype(y.dtype)
y_np[pad_before[0]:pad_before[0] + n,
pad_before[1]:pad_before[1] + m,
:] = x_np
x_nd = tvm.nd.array(x_np, ctx)
y_nd = tvm.nd.empty(y_np.shape, ctx=ctx, dtype=y_np.dtype)
f(x_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
if print_ir:
print(tvm.lower(s, [x, y], simple_mode=True))
if do_verify:
with tvm.build_config(add_lower_pass=vta.debug_mode(
vta.DEBUG_DUMP_INSN)):
verify()
def test_gemm():
"""Test GEMM."""
# declare
o = 4
n = 4
m = 4
x = tvm.placeholder((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), name="x", dtype=inp_dtype)
w = tvm.placeholder((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), name="w", dtype=wgt_dtype)
x_buf = tvm.compute((o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN), lambda *i: x(*i), "x_buf")
w_buf = tvm.compute((m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN), lambda *i: w(*i), "w_buf")
ko = tvm.reduce_axis((0, n), name="ko")
ki = tvm.reduce_axis((0, vta.VTA_BLOCK_IN), name="ki")
y_gem = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda bo, co, bi, ci:
tvm.sum(x_buf[bo, ko, bi, ki].astype(out_dtype) *
w_buf[co, ko, ci, ki].astype(out_dtype),
axis=[ko, ki]),
name="y_gem")
y_shf = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: y_gem(*i)>>8,
name="y_shf")
y_max = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm.max(y_shf(*i), 0),
"y_max") #relu
y_min = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm.min(y_max(*i), (1<<(vta.VTA_INP_WIDTH-1))-1),
"y_min") #relu
y = tvm.compute(
(o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: y_min(*i).astype(inp_dtype),
name="y")
def verify(s):
mod = tvm.build(s, [x, w, y], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")
# verify
ctx = remote.ext_dev(0)
x_np = np.random.randint(
-128, 128, size=(o, n, vta.VTA_BATCH, vta.VTA_BLOCK_IN)).astype(x.dtype)
w_np = np.random.randint(
-128, 128, size=(m, n, vta.VTA_BLOCK_OUT, vta.VTA_BLOCK_IN)).astype(w.dtype)
y_np = np.zeros((o, m, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(y.dtype)
x_nd = tvm.nd.array(x_np, ctx)
w_nd = tvm.nd.array(w_np, ctx)
y_nd = tvm.nd.array(y_np, ctx)
y_np = y_np.astype(out_dtype)
for b in range(o):
for i in range(m):
for j in range(n):
y_np[b,i,:] += np.dot(x_np[b,j,:].astype(out_dtype),
w_np[i,j].T.astype(out_dtype))
y_np = np.right_shift(y_np, 8)
y_np = np.clip(y_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(y.dtype)
f(x_nd, w_nd, y_nd)
np.testing.assert_equal(y_np, y_nd.asnumpy())
print("\tFinished verification...")
def test_schedule1():
# default schedule with no smt
s = tvm.create_schedule(y.op)
# set the scope of the SRAM buffers
s[x_buf].set_scope(vta.SCOPE_INP)
s[w_buf].set_scope(vta.SCOPE_WGT)
s[y_gem].set_scope(vta.SCOPE_OUT)
s[y_shf].set_scope(vta.SCOPE_OUT)
s[y_max].set_scope(vta.SCOPE_OUT)
s[y_min].set_scope(vta.SCOPE_OUT)
# set pragmas for DMA transfer and ALU ops
s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY)
s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY)
s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU)
s[y_max].pragma(s[y_max].op.axis[0], vta.ALU)
s[y_min].pragma(s[y_min].op.axis[0], vta.ALU)
s[y].pragma(s[y].op.axis[0], vta.DMA_COPY)
# tensorization
s[y_gem].reorder(
ko,
s[y_gem].op.axis[0],
s[y_gem].op.axis[1],
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM)
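# The reorder puts the ko reduction outermost and ki innermost so the loop
# body matches the shape expected by the GEMM intrinsic; tensorize then
# replaces everything from the batch axis inward with the VTA GEMM intrinsic.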
if print_ir:
print(tvm.lower(s, [x, w, y], simple_mode=True))
if do_verify:
with tvm.build_config(
add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
verify(s)
def test_smt():
# test smt schedule
s = tvm.create_schedule(y.op)
s[x_buf].set_scope(vta.SCOPE_INP)
s[w_buf].set_scope(vta.SCOPE_WGT)
s[y_gem].set_scope(vta.SCOPE_OUT)
s[y_shf].set_scope(vta.SCOPE_OUT)
s[y_max].set_scope(vta.SCOPE_OUT)
s[y_min].set_scope(vta.SCOPE_OUT)
abo, aco, abi, aci = s[y].op.axis
abo1, abo2 = s[y].split(abo, nparts=2)
s[y].bind(abo1, tvm.thread_axis("cthread"))
s[y_gem].compute_at(s[y], abo1)
s[y_shf].compute_at(s[y], abo1)
s[y_max].compute_at(s[y], abo1)
s[y_min].compute_at(s[y], abo1)
s[y_gem].reorder(
ko,
s[y_gem].op.axis[0],
s[y_gem].op.axis[1],
s[y_gem].op.axis[2],
s[y_gem].op.axis[3],
ki)
s[y_gem].tensorize(s[y_gem].op.axis[2], vta.GEMM)
s[y_shf].pragma(s[y_shf].op.axis[0], vta.ALU)
s[y_max].pragma(s[y_max].op.axis[0], vta.ALU)
s[y_min].pragma(s[y_min].op.axis[0], vta.ALU)
s[x_buf].compute_at(s[y_gem], ko)
s[x_buf].pragma(s[x_buf].op.axis[0], vta.DMA_COPY)
s[w_buf].compute_at(s[y_gem], ko)
s[w_buf].pragma(s[w_buf].op.axis[0], vta.DMA_COPY)
s[y].pragma(abo2, vta.DMA_COPY)
if print_ir:
print(tvm.lower(s, [x, w, y], simple_mode=True))
if do_verify:
with tvm.build_config(
add_lower_pass=vta.debug_mode(vta.DEBUG_DUMP_INSN)):
verify(s)
test_schedule1()
test_smt()
def test_alu(tvm_op, np_op=None, use_imm=False):
"""Test ALU"""
m = 8
n = 8
imm = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="a",
dtype=out_dtype)
a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: a(*i),
"a_buf") #DRAM->SRAM
if use_imm:
res_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), imm),
"res_buf") #compute
else:
b = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="b",
dtype=out_dtype)
b_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: b(*i),
"b_buf") #DRAM->SRAM
res_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm_op(a_buf(*i), b_buf(*i)),
"res_buf") #compute
res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: res_buf(*i).astype(inp_dtype),
"res") #SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[res_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[res_buf].pragma(res_buf.op.axis[0], vta.ALU) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM
if use_imm:
if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True))
else:
s[b_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[b_buf].pragma(b_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
if print_ir:
print(tvm.lower(s, [a, b, res], simple_mode=True))
def verify():
# build
if use_imm:
mod = tvm.build(s, [a, res], "ext_dev", target)
else:
mod = tvm.build(s, [a, b, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype)
if use_imm:
res_np = np_op(a_np, imm) if np_op else tvm_op(a_np, imm)
else:
b_np = np.random.randint(
-16, 16, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(b.dtype)
res_np = np_op(a_np, b_np) if np_op else tvm_op(a_np, b_np)
res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx)
if use_imm:
f(a_nd, res_nd)
else:
b_nd = tvm.nd.array(b_np, ctx)
f(a_nd, b_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
def test_relu():
"""Test RELU on ALU"""
m = 8
n = 8
# compute
a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="a",
dtype=out_dtype)
a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
max_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm.max(a_buf(*i), 0),
"res_buf") # relu
min_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: tvm.min(max_buf(*i), (1<<(vta.VTA_INP_WIDTH-1))-1),
"max_buf") # relu
res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: min_buf(*i).astype(inp_dtype),
"min_buf") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[max_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[min_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[max_buf].pragma(max_buf.op.axis[0], vta.ALU) # compute
s[min_buf].pragma(min_buf.op.axis[0], vta.ALU) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM
if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True))
def verify():
# build
mod = tvm.build(s, [a, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-256, 256, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype)
res_np = np.clip(a_np, 0, (1<<(vta.VTA_INP_WIDTH-1))-1).astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
def test_shift_and_scale():
"""Test shift and scale on ALU"""
m = 8
n = 8
imm_shift = np.random.randint(-10,10)
imm_scale = np.random.randint(1,5)
# compute
a = tvm.placeholder(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
name="a", dtype=out_dtype)
a_buf = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: a(*i),
"a_buf") # DRAM->SRAM
res_shift = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: a_buf(*i)+imm_shift,
"res_shift") # compute
res_scale = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: res_shift(*i)>>imm_scale,
"res_scale") # compute
res = tvm.compute(
(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT),
lambda *i: res_scale(*i).astype(inp_dtype),
"res") # SRAM->DRAM
# schedule
s = tvm.create_schedule(res.op)
s[a_buf].set_scope(vta.SCOPE_OUT) # SRAM
s[res_shift].set_scope(vta.SCOPE_OUT) # SRAM
s[res_scale].set_scope(vta.SCOPE_OUT) # SRAM
s[a_buf].pragma(a_buf.op.axis[0], vta.DMA_COPY) # DRAM->SRAM
s[res_shift].pragma(res_shift.op.axis[0], vta.ALU) # compute
s[res_scale].pragma(res_scale.op.axis[0], vta.ALU) # compute
s[res].pragma(res.op.axis[0], vta.DMA_COPY) # SRAM->DRAM
if print_ir:
print(tvm.lower(s, [a, res], simple_mode=True))
def verify():
# build
mod = tvm.build(s, [a, res], "ext_dev", target)
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("load_act.o"))
remote.upload(temp.relpath("load_act.o"))
f = remote.load_module("load_act.o")
# verify
ctx = remote.ext_dev(0)
a_np = np.random.randint(
-10, 10, size=(m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(a.dtype)
res_np = np.right_shift((a_np + imm_shift), imm_scale)
res_np = res_np.astype(res.dtype)
a_nd = tvm.nd.array(a_np, ctx)
res_nd = tvm.nd.array(
np.zeros((m, n, vta.VTA_BATCH, vta.VTA_BLOCK_OUT)).astype(res.dtype), ctx)
f(a_nd, res_nd)
np.testing.assert_equal(res_np, res_nd.asnumpy())
print("\tFinished verification...")
if do_verify:
verify()
if __name__ == "__main__":
print("Padded load test")
test_padded_load()
print("Load/store test")
test_save_load_out()
print("GEMM test")
test_gemm()
print("Max immediate")
test_alu(tvm.max, np.maximum, use_imm=True)
print("Max")
test_alu(tvm.max, np.maximum)
print("Add immediate")
test_alu(lambda x, y: x + y, use_imm=True)
print("Add")
test_alu(lambda x, y: x + y)
print("Shift right immediate")
test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
print("Relu")
test_relu()
# print("Shift and scale")
# test_shift_and_scale()