Commit f5960ba6 by Thierry Moreau Committed by Tianqi Chen

[SCHEDULER, HW] Auto scheduler for conv2d, hardware generation (#20)

* Hardware generation fixes/sweep, auto scheduling for VTA conv2d

* Hardware generation fixes/sweep, auto scheduling for VTA conv2d

* derive hw spec from config file

* up to date hardware spec
parent 33a309b2
# Directories
ROOTDIR = $(CURDIR)
# BUILD_NAME is user-overridable (e.g. `make BUILD_NAME=build_2018_01_01`)
# so sweep runs can target distinct build trees.
BUILD_NAME = build
BUILD_DIR = $(ROOTDIR)/../../$(BUILD_NAME)/hardware/xilinx
SCRIPT_DIR = $(ROOTDIR)/scripts
SRC_DIR = $(ROOTDIR)/src
SIM_DIR = $(ROOTDIR)/sim
......@@ -12,6 +13,15 @@ VIVADO_HLS = vivado_hls
VIVADO = vivado
HSI = hsi
# HLS Mode
MODE = all
# SLURM
SLURM = false
# Prevent generation of DSP
NO_DSP = false
# Prevent generation of ALU
NO_ALU = false
# Include top-level config file
ifndef config
ifneq ("$(wildcard ../../config.mk)", "")
......@@ -41,13 +51,18 @@ $(shell echo "$$(( (1000 + $(VTA_HW_COMP_CLOCK_FREQ) - 1) / $(VTA_HW_COMP_CLOCK_
# Derive config name: encodes GEMM shape, operand widths, all four buffer
# sizes (log2, bytes) and the target clock/period so each design point gets
# a unique build directory.
CONF = \
$(VTA_BATCH)x$(VTA_IN_BLOCK)x$(VTA_OUT_BLOCK)_$(VTA_INP_WIDTH)bx$(VTA_WGT_WIDTH)b_$(VTA_LOG_UOP_BUFF_SIZE)_$(VTA_LOG_INP_BUFF_SIZE)_$(VTA_LOG_WGT_BUFF_SIZE)_$(VTA_LOG_ACC_BUFF_SIZE)_$(VTA_HW_COMP_CLOCK_FREQ)MHz_$(TARGET_PER)ns
# Default build locations for HLS IP and Vivado projects
IP_BUILD_PATH = $(BUILD_DIR)/hls/$(CONF)
HW_BUILD_PATH = $(BUILD_DIR)/vivado/$(CONF)
# Under SLURM, build on node-local scratch; results are moved back into
# $(BUILD_DIR) by the ip/bit recipes below.
ifeq ($(SLURM), true)
IP_BUILD_PATH = /scratch/hls/$(CONF)
HW_BUILD_PATH = /scratch/vivado/$(CONF)
endif

.PHONY: all ip bit driver clean clean_all

# Default goal: build up to the bitstream (driver generation is opt-in)
all: bit
ip:
mkdir -p $(IP_BUILD_PATH)
......@@ -57,20 +72,32 @@ ip:
$(VTA_LOG_INP_WIDTH) $(VTA_LOG_WGT_WIDTH) $(VTA_LOG_ACC_WIDTH) $(VTA_LOG_OUT_WIDTH) \
$(VTA_LOG_BATCH) $(VTA_LOG_BLOCK_OUT) $(VTA_LOG_BLOCK_IN) \
$(VTA_LOG_UOP_BUFF_SIZE) $(VTA_LOG_INP_BUFF_SIZE) $(VTA_LOG_WGT_BUFF_SIZE) \
$(VTA_LOG_ACC_BUFF_SIZE) $(VTA_LOG_OUT_BUFF_SIZE)
$(VTA_LOG_ACC_BUFF_SIZE) $(VTA_LOG_OUT_BUFF_SIZE) \
$(MODE) $(NO_DSP) $(NO_ALU)
ifeq ($(SLURM), true)
mkdir -p $(BUILD_DIR)/hls
mv $(IP_BUILD_PATH) $(BUILD_DIR)/hls/.
endif
# Run Vivado synthesis/implementation to produce the bitstream.
# The IP is always read from the persistent $(BUILD_DIR)/hls location,
# which is where the ip rule leaves it in both local and SLURM modes.
bit: ip
	mkdir -p $(HW_BUILD_PATH)
	cd $(HW_BUILD_PATH) && \
	$(VIVADO) -mode tcl -source $(SCRIPT_DIR)/vivado.tcl \
	-tclargs $(BUILD_DIR)/hls/$(CONF) $(VTA_HW_COMP_THREADS) $(VTA_HW_COMP_CLOCK_FREQ) \
	$(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \
	$(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \
	$(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE)
# When building on scratch (SLURM), move results back to the shared tree
ifeq ($(SLURM), true)
	mkdir -p $(BUILD_DIR)/vivado
	mv $(HW_BUILD_PATH) $(BUILD_DIR)/vivado/.
endif
# Generate the board support package and driver from the implemented design.
# Use $(MAKE) (not bare `make`) for the recursive invocation so -j/jobserver
# and command-line flags propagate to the sub-make.
driver: bit
	cd $(HW_BUILD_PATH) && $(HSI) -mode tcl -source $(SCRIPT_DIR)/hsi.tcl -nojournal -nolog
	cd $(HW_BUILD_PATH)/bsp && $(MAKE)
# Remove local logs, SLURM job scripts and figures produced by sweep runs
clean:
	rm -rf *.out *.log *.sb figures
# Additionally wipe the whole generated build tree
clean_all: clean
	rm -rf $(BUILD_DIR)
#!/usr/bin/env python
import argparse
import datetime
import logging
import numpy as np
import os
import pandas as pd
import re
import time
from collections import namedtuple
from numpy import floor, ceil, log2, log10
from subprocess import call
# Physical BRAM budget of the target FPGA:
#   bram_w   - BRAM port width in bits
#   bram_d   - BRAM depth (number of words)
#   num_bram - total number of BRAMs available
FPGA = namedtuple("FPGAConstraints",
                  ['bram_w', 'bram_d', 'num_bram'])
# One candidate VTA hardware configuration: GEMM tile shape
# (batch x block_in x block_out) and operand bit-widths
# (input/weight/accumulator/output/micro-op).
Hardware = namedtuple("HWConfig",
                      ['batch', 'block_in', 'block_out',
                       'input_w', 'weight_w', 'accum_w', 'out_w', 'uop_w'])
def find_bram_confs(fpga, hw_conf, log_uop_sizeB):
    """Enumerate Pareto-optimal BRAM allocations for one hardware config.

    Explores power-of-two BRAM counts for the input, weight and accumulator
    buffers (uop buffer size is fixed by log_uop_sizeB), keeps allocations
    that fit in fpga.num_bram and whose index space fits in the micro-op
    width, then discards allocations dominated element-wise by another.

    Returns a list of [log_uop_sizeB, log_inp_sizeB, log_wgt_sizeB,
    log_acc_sizeB] entries (buffer sizes as log2 of bytes), in discovery
    order.

    NOTE(review): this code uses Python-2 `/` semantics throughout; under
    Python 3 these divisions produce floats (e.g. uop_bram and the
    out-buffer term of total_bram), which can change which configurations
    pass the capacity check — confirm the intended interpreter.
    """
    # Derive sizes: one tile of each operand, in bits
    input_elem_size_b = hw_conf.batch*hw_conf.block_in*hw_conf.input_w
    weight_elem_size_b = hw_conf.block_in*hw_conf.block_out*hw_conf.weight_w
    accum_elem_size_b = hw_conf.batch*hw_conf.block_out*hw_conf.accum_w
    # Minimum BRAMs needed so one tile fits across the BRAM port width
    # (ceiling division)
    input_min_bram = (input_elem_size_b+fpga.bram_w-1)/fpga.bram_w
    weight_min_bram = (weight_elem_size_b+fpga.bram_w-1)/fpga.bram_w
    accum_min_bram = (accum_elem_size_b+fpga.bram_w-1)/fpga.bram_w
    # Exploring all possible BRAM distributions (powers of two, up to the
    # total BRAM budget)
    bram_confs = []
    # BRAMs consumed by the fixed-size micro-op buffer
    uop_bram = pow(2, log_uop_sizeB) * 8 / (fpga.bram_w * fpga.bram_d)
    for log_i_bram in range(int(log2(input_min_bram)), int(ceil(log2(fpga.num_bram)))):
        i_bram = pow(2, log_i_bram)
        for log_w_bram in range(int(log2(weight_min_bram)), int(ceil(log2(fpga.num_bram)))):
            w_bram = pow(2, log_w_bram)
            for log_a_bram in range(int(log2(accum_min_bram)), int(ceil(log2(fpga.num_bram)))):
                a_bram = pow(2, log_a_bram)
                # Last term: the out buffer shadows the accum buffer depth
                # at out_w bits per element
                total_bram = uop_bram + i_bram + w_bram + a_bram + a_bram / hw_conf.accum_w * hw_conf.out_w
                if total_bram <= fpga.num_bram:
                    # Right now we need to restrict uop width: the three
                    # buffer indices must be addressable within one uop
                    input_elems = i_bram * fpga.bram_w * fpga.bram_d / input_elem_size_b
                    weight_elems = w_bram * fpga.bram_w * fpga.bram_d / weight_elem_size_b
                    accum_elems = a_bram * fpga.bram_w * fpga.bram_d / accum_elem_size_b
                    if log2(input_elems) + log2(weight_elems) + log2(accum_elems) <= hw_conf.uop_w:
                        # Convert BRAM counts to log2 buffer sizes in bytes
                        log_inp_sizeB = int(log2(i_bram*fpga.bram_d*fpga.bram_w/8))
                        log_wgt_sizeB = int(log2(w_bram*fpga.bram_d*fpga.bram_w/8))
                        log_acc_sizeB = int(log2(a_bram*fpga.bram_d*fpga.bram_w/8))
                        bram_confs.append([log_uop_sizeB, log_inp_sizeB, log_wgt_sizeB, log_acc_sizeB])
    # Filter out configs that are suboptimal: drop any config dominated
    # element-wise (<= in every buffer size) by another config
    suboptimal = [False] * len(bram_confs)
    for i in range(0, len(bram_confs)):
        for j in range(i + 1, len(bram_confs)):
            leq_list = [a <= b for a, b in zip(bram_confs[i], bram_confs[j])]
            geq_list = [a >= b for a, b in zip(bram_confs[i], bram_confs[j])]
            leq = all(leq_list)
            geq = all(geq_list)
            if leq:
                suboptimal[i] = True
            if geq:
                suboptimal[j] = True
    opt_bram_confs = [x[0] for x in zip(bram_confs, suboptimal) if not x[1]]
    return opt_bram_confs
def get_make_command(job, build_dir, hw_conf, bram_conf, mode, slurm=False):
    """Assemble the command line (or SLURM batch script) for one build job.

    Parameters
    ----------
    job : str
        Job name, used for SLURM job/output naming.
    build_dir : str
        Value passed to the Makefile's BUILD_NAME variable.
    hw_conf : Hardware
        Hardware configuration; widths/shape must be powers of two.
    bram_conf : list
        [log_uop, log_inp, log_wgt, log_acc] buffer sizes (log2, bytes).
    mode : str
        "hls" builds only the IP ("make ip"); anything else runs "make".
    slurm : bool
        When True, prepend an sbatch preamble and invoke through srun.

    Returns
    -------
    str
        The full command text, terminated by a newline.
    """
    # Optional SLURM batch-script preamble
    preamble = ""
    if slurm:
        preamble += "#!/bin/bash\n"
        preamble += "#SBATCH --job-name={}\n".format(job)
        preamble += "#SBATCH --output={}.out\n".format(job)
        preamble += "srun "
    # HLS-only runs stop at the `ip` target; everything else builds the bitstream
    target = "make ip" if mode == "hls" else "make"
    # Makefile variable overrides describing this design point
    overrides = [
        " SLURM={} MODE=skip_sim NO_DSP=false NO_ALU=false".format("true" if slurm else "false"),
        " BUILD_NAME={}".format(build_dir),
        " VTA_LOG_INP_WIDTH={}".format(int(log2(hw_conf.input_w))),
        " VTA_LOG_WGT_WIDTH={}".format(int(log2(hw_conf.weight_w))),
        " VTA_LOG_BATCH={}".format(int(log2(hw_conf.batch))),
        " VTA_LOG_BLOCK_IN={}".format(int(log2(hw_conf.block_in))),
        " VTA_LOG_BLOCK_OUT={}".format(int(log2(hw_conf.block_out))),
        " VTA_LOG_UOP_BUFF_SIZE={}".format(bram_conf[0]),
        " VTA_LOG_INP_BUFF_SIZE={}".format(bram_conf[1]),
        " VTA_LOG_WGT_BUFF_SIZE={}".format(bram_conf[2]),
        " VTA_LOG_ACC_BUFF_SIZE={}\n".format(bram_conf[3]),
    ]
    return preamble + target + "".join(overrides)
def cli():
    """Command-line driver: sweep VTA hardware design points and launch builds.

    Sweeps input/weight bit-widths, batch size and input/output channel
    counts (all as powers of two over user-given log2 ranges), derives the
    Pareto-optimal BRAM allocations for each point, and emits one build
    command per candidate — written to a .sb script and either submitted
    via sbatch (with -slurm) or executed locally.
    """
    parser = argparse.ArgumentParser(
        description='Analyze HLS experiments'
    )
    parser.add_argument(
        '-mode', dest='mode', action='store', type=str, required=True,
        choices=["hls", "vivado"], help='hls synthesis or full compilation'
    )
    # NOTE(review): base_dir is parsed but never referenced below — confirm
    # whether it was meant to feed BUILD_NAME/build_dir.
    parser.add_argument(
        '-base_dir', dest='base_dir', action='store', type=str, required=False,
        default="../../build/hardware/xilinx/", help='path to build directory'
    )
    parser.add_argument(
        '-min_ibw', dest='min_ibw', action='store', type=int, required=False,
        default=3, help='log2 of minimum input bit-width'
    )
    parser.add_argument(
        '-max_ibw', dest='max_ibw', action='store', type=int, required=False,
        default=3, help='log2 of maximum input bit-width'
    )
    parser.add_argument(
        '-min_wbw', dest='min_wbw', action='store', type=int, required=False,
        default=3, help='log2 of minimum weight bit-width'
    )
    parser.add_argument(
        '-max_wbw', dest='max_wbw', action='store', type=int, required=False,
        default=3, help='log2 of maximum weight bit-width'
    )
    parser.add_argument(
        '-acc_bw', dest='acc_bw', action='store', type=int, required=False,
        default=32, help='accumulator bit-width'
    )
    parser.add_argument(
        '-uop_bw', dest='uop_bw', action='store', type=int, required=False,
        default=32, help='micro-op bit-width'
    )
    parser.add_argument(
        '-min_batch', dest='min_batch', action='store', type=int, required=False,
        default=0, help='log2 of minimum batch size'
    )
    parser.add_argument(
        '-max_batch', dest='max_batch', action='store', type=int, required=False,
        default=8, help='log2 of maximum batch size'
    )
    parser.add_argument(
        '-min_ic', dest='min_ic', action='store', type=int, required=False,
        default=0, help='log2 of minimum input channels'
    )
    parser.add_argument(
        '-max_ic', dest='max_ic', action='store', type=int, required=False,
        default=8, help='log2 of maximum input channels'
    )
    parser.add_argument(
        '-min_oc', dest='min_oc', action='store', type=int, required=False,
        default=0, help='log2 of minimum output channels'
    )
    parser.add_argument(
        '-max_oc', dest='max_oc', action='store', type=int, required=False,
        default=8, help='log2 of maximum output channels'
    )
    parser.add_argument(
        '-uop_sizeB', dest='uop_sizeB', action='store', type=int, required=False,
        default=14, help='log2 of uop buffer in B'
    )
    parser.add_argument(
        '-bram_w', dest='bram_w', action='store', type=int, required=False,
        default=32, help='FPGA BRAM port width in b'
    )
    parser.add_argument(
        '-bram_d', dest='bram_d', action='store', type=int, required=False,
        default=1024, help='FPGA BRAM depth'
    )
    parser.add_argument(
        '-num_bram', dest='num_bram', action='store', type=int, required=False,
        default=124, help='FPGA total BRAM'
    )
    parser.add_argument(
        '-slurm', dest='slurm', action='store_true',
        help='Run on cluster using slurm'
    )
    args = parser.parse_args()
    # Logging
    logging.basicConfig(filename='compile_designs.log',level=logging.DEBUG)
    # FPGA config
    pynq = FPGA(args.bram_w, args.bram_d, args.num_bram)
    # Get timestamp — gives this sweep a unique BUILD_NAME
    timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y_%m_%d_%H_%M_%S')
    build_dir = "build_{}".format(timestamp)
    num_confs = 0
    # Sweep every dimension as powers of two over its log2 range (inclusive)
    for log_ibw in range(args.min_ibw, args.max_ibw+1):
        ibw = pow(2, log_ibw)
        for log_wbw in range(args.min_wbw, args.max_wbw+1):
            wbw = pow(2, log_wbw)
            for log_batch in range(args.min_batch, args.max_batch+1):
                batch = pow(2, log_batch)
                for log_ic in range(args.min_ic, args.max_ic+1):
                    ic = pow(2, log_ic)
                    for log_oc in range(args.min_oc, args.max_oc+1):
                        oc = pow(2, log_oc)
                        # Output width is tied to the input width (ibw passed
                        # for both input_w and out_w)
                        conf = Hardware(batch, ic, oc, ibw, wbw, args.acc_bw, ibw, args.uop_bw)
                        bram_confs = find_bram_confs(pynq, conf, args.uop_sizeB)
                        for b in bram_confs:
                            # Job name mirrors the Makefile's CONF encoding;
                            # clock/period are fixed at 100MHz/10ns here
                            job = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns".format(
                                batch, ic, oc, ibw, wbw, b[0], b[1], b[2], b[3])
                            num_confs += 1
                            cmd = get_make_command(job, build_dir, conf, b, args.mode, args.slurm)
                            # Persist the command as a .sb script (sbatch input)
                            sb_file = job+".sb"
                            file = open(sb_file,"w")
                            file.write(cmd)
                            file.close()
                            call(["echo", cmd])
                            if args.slurm:
                                call(["sbatch", sb_file])
                            else:
                                # NOTE(review): naive split(" ") — the trailing
                                # "\n" stays attached to the last argument;
                                # confirm make tolerates it.
                                call(cmd.split(" "))

# Script entry point
if __name__ == '__main__':
    cli()
......@@ -22,8 +22,11 @@
# Arg 15: wgt buffer size in B (log)
# Arg 16: acc buffer size in B (log)
# Arg 17: out buffer size in B (log)
# Arg 18: mode
# Arg 19: no_dsp
# Arg 20: no_alu
if { [llength $argv] eq 19 } {
if { [llength $argv] eq 22 } {
set src_dir [lindex $argv 2]
set sim_dir [lindex $argv 3]
set test_dir [lindex $argv 4]
......@@ -41,10 +44,13 @@ if { [llength $argv] eq 19 } {
set wgt_buff_size [lindex $argv 16]
set acc_buff_size [lindex $argv 17]
set out_buff_size [lindex $argv 18]
set mode [lindex $argv 19]
set no_dsp [lindex $argv 20]
set no_alu [lindex $argv 21]
} else {
set src_dir "../src/"
set sim_dir "../sim/"
set test_dir "../../src/test/"
set src_dir "../src"
set sim_dir "../sim"
set test_dir "../../src/test"
set include_dir "../../include"
set target_period 10
set inp_width 3
......@@ -59,17 +65,11 @@ if { [llength $argv] eq 19 } {
set wgt_buff_size 15
set acc_buff_size 17
set out_buff_size 15
set mode "all"
set no_dsp "true"
set no_alu "false"
}
# C define flags to pass to compiler
set cflags "-I $include_dir -I $src_dir -I $test_dir \
-DVTA_DEBUG=0 -DVTA_LOG_WGT_WIDTH=$wgt_width -DVTA_LOG_INP_WIDTH=$inp_width \
-DVTA_LOG_ACC_WIDTH=$acc_width -DVTA_LOG_OUT_WIDTH=$out_width \
-DVTA_LOG_BATCH=$batch -DVTA_LOG_BLOCK_OUT=$block_out -DVTA_LOG_BLOCK_IN=$block_in \
-DVTA_LOG_UOP_BUFF_SIZE=$uop_buff_size -DVTA_LOG_INP_BUFF_SIZE=$inp_buff_size \
-DVTA_LOG_WGT_BUFF_SIZE=$wgt_buff_size -DVTA_LOG_ACC_BUFF_SIZE=$acc_buff_size \
-DVTA_LOG_OUT_BUFF_SIZE=$out_buff_size"
# Initializes the HLS design and sets HLS pragmas for memory partitioning.
# This is necessary because of a Vivado restriction that doesn't allow for
# buses wider than 1024 bits.
......@@ -122,56 +122,89 @@ proc init_design {per inp_width wgt_width out_width batch block_in block_out} {
}
}
# C define flags to pass to compiler: include paths plus the VTA
# architecture parameters (widths, GEMM shape, buffer sizes) as -D macros
set cflags "-I $include_dir -I $src_dir -I $test_dir \
	-DVTA_DEBUG=0 -DVTA_LOG_WGT_WIDTH=$wgt_width -DVTA_LOG_INP_WIDTH=$inp_width \
	-DVTA_LOG_ACC_WIDTH=$acc_width -DVTA_LOG_OUT_WIDTH=$out_width \
	-DVTA_LOG_BATCH=$batch -DVTA_LOG_BLOCK_OUT=$block_out -DVTA_LOG_BLOCK_IN=$block_in \
	-DVTA_LOG_UOP_BUFF_SIZE=$uop_buff_size -DVTA_LOG_INP_BUFF_SIZE=$inp_buff_size \
	-DVTA_LOG_WGT_BUFF_SIZE=$wgt_buff_size -DVTA_LOG_ACC_BUFF_SIZE=$acc_buff_size \
	-DVTA_LOG_OUT_BUFF_SIZE=$out_buff_size"
# Optionally forbid DSP-slice multipliers (NO_DSP maps multiplies to LUTs)
if {$no_dsp=="true"} {
	append cflags " -DNO_DSP"
}
# Optionally compile out the ALU datapath entirely
if {$no_alu=="true"} {
	append cflags " -DNO_ALU"
}
# HLS behavioral sim (C simulation); run only when $mode requests it.
# The flattened text contained a second, unconditional copy of this section
# (stale diff residue) which would have run the simulation regardless of
# $mode — only the guarded version is kept.
if {$mode=="all" || $mode=="sim"} {
	open_project vta_sim
	set_top vta
	add_files $src_dir/vta.cc -cflags $cflags
	add_files -tb $sim_dir/vta_test.cc -cflags $cflags
	add_files -tb $test_dir/test_lib.cc -cflags $cflags
	open_solution "solution0"
	init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
	csim_design -clean
	close_project
}
# Generate fetch stage: synthesize, and export IP unless running a
# synthesis-only mode. (Stale unconditional duplicate removed.)
if {$mode=="all" || $mode=="skip_sim" || $mode=="fetch"} {
	open_project vta_fetch
	set_top fetch
	add_files $src_dir/vta.cc -cflags $cflags
	open_solution "solution0"
	init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
	csynth_design
	if {$mode=="all" || $mode=="skip_sim"} {
		export_design -format ip_catalog
	}
	close_project
}
# Generate load stage: synthesize, and export IP unless running a
# synthesis-only mode. (Stale unconditional duplicate removed.)
if {$mode=="all" || $mode=="skip_sim" || $mode=="load"} {
	open_project vta_load
	set_top load
	add_files $src_dir/vta.cc -cflags $cflags
	open_solution "solution0"
	init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
	csynth_design
	if {$mode=="all" || $mode=="skip_sim"} {
		export_design -format ip_catalog
	}
	close_project
}
# Generate compute stage: synthesize, and export IP unless running a
# synthesis-only mode. (Stale unconditional duplicate removed.)
if {$mode=="all" || $mode=="skip_sim" || $mode=="compute"} {
	open_project vta_compute
	set_top compute
	add_files $src_dir/vta.cc -cflags $cflags
	open_solution "solution0"
	init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
	csynth_design
	if {$mode=="all" || $mode=="skip_sim"} {
		export_design -format ip_catalog
	}
	close_project
}
# Generate store stage: synthesize, and export IP unless running a
# synthesis-only mode. (Stale unconditional duplicate removed.)
if {$mode=="all" || $mode=="skip_sim" || $mode=="store"} {
	open_project vta_store
	set_top store
	add_files $src_dir/vta.cc -cflags $cflags
	open_solution "solution0"
	init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
	csynth_design
	if {$mode=="all" || $mode=="skip_sim"} {
		export_design -format ip_catalog
	}
	close_project
}
exit
......@@ -59,44 +59,31 @@ if { [llength $argv] eq 12 } {
# Derive on-chip memory geometry for each buffer. Vivado restricts bus
# widths to 1024 bits, so wider logical rows are split into $*_part
# physical partitions; the depth formula must account for that partition
# count. Stale pre-partition depth assignments and leftover debug `puts`
# lines (diff residue) are removed — depth is computed once, after
# partitioning is known.
# Derive input mem parameters
set inp_mem_width [expr $inp_width * $batch * $in_block]
set inp_bus_width 1024
set inp_part [expr $inp_mem_width / $inp_bus_width]
if {[expr $inp_part == 0]} {
	set inp_part 1
	set inp_bus_width $inp_mem_width
}
set inp_mem_depth [expr $inp_mem_size * 8 / ($inp_mem_width * $inp_part)]
# Derive weight mem parameters
set wgt_mem_width [expr $wgt_width * $out_block * $in_block]
set wgt_bus_width 1024
set wgt_part [expr $wgt_mem_width / $wgt_bus_width]
if {[expr $wgt_part == 0]} {
	set wgt_part 1
	set wgt_bus_width $wgt_mem_width
}
set wgt_mem_depth [expr $wgt_mem_size * 8 / ($wgt_mem_width * $wgt_part)]
# Derive output mem parameters
set out_mem_width [expr $out_width * $batch * $out_block]
set out_bus_width 1024
set out_part [expr $out_mem_width / $out_bus_width]
if {[expr $out_part == 0]} {
	set out_part 1
	set out_bus_width $out_mem_width
}
set out_mem_depth [expr $out_mem_size * 8 / ($out_mem_width * $out_part)]
# User defined paths
set proj_name vta
......@@ -937,10 +924,12 @@ launch_runs impl_1 -to_step write_bitstream -jobs $num_threads
wait_on_run impl_1
# Export hardware description file and bitstream files to export/ dir —
# but only if implementation actually produced a bitstream, so a failed
# write_bitstream run does not abort the script on a missing file.
# (Stale unconditional mkdir/copy lines with a dangling continuation were
# diff residue and are removed.)
if {[file exist $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.bit]} {
	file mkdir $proj_path/export
	file copy -force $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.sysdef \
		$proj_path/export/vta.hdf
	file copy -force $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.bit \
		$proj_path/export/vta.bit
}
exit
......@@ -16,40 +16,36 @@ int main(void) {
printParameters();
#endif
// Buffer indexing
assert(VTA_LOG_ACC_BUFF_DEPTH >= VTA_LOG_INP_BUFF_DEPTH);
// Micro op bound
assert(VTA_UOP_GEM_3_1 < VTA_UOP_WIDTH);
assert(VTA_UOP_ALU_3_1 < VTA_UOP_WIDTH);
// Instruction alignment checks
assert(VTA_UOP_GEM_2_1 < VTA_UOP_WIDTH);
assert(VTA_UOP_ALU_1_1 < VTA_UOP_WIDTH);
// Make sure there is no misaligment
assert(VTA_INSN_GEM_9_1 < VTA_INSN_GEM_A_0);
assert(VTA_INSN_MEM_7_1 < VTA_INSN_MEM_8_0);
assert(VTA_INSN_GEM_8_1 < VTA_INSN_GEM_9_0);
// Instruction bounds
assert(VTA_INSN_MEM_E_1 < VTA_INS_WIDTH);
assert(VTA_INSN_GEM_E_1 < VTA_INS_WIDTH);
assert(VTA_INSN_ALU_F_1 < VTA_INS_WIDTH);
assert(VTA_INSN_GEM_F_1 < VTA_INS_WIDTH);
assert(VTA_INSN_ALU_G_1 < VTA_INS_WIDTH);
int status = 0;
// Run ALU test (vector-scalar operators)
status |= alu_test(VTA_ALU_OPCODE_MIN, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MIN, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MAX, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MAX, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_ADD, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_ADD, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MIN, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MIN, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MAX, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MAX, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_ADD, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_ADD, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, VTA_BLOCK_OUT, 128, false);
// Run ALU test (vector-vector operators)
status |= alu_test(VTA_ALU_OPCODE_MIN, false, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MIN, false, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MAX, false, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MAX, false, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_ADD, false, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_ADD, false, 16, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MIN, false, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MIN, false, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MAX, false, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MAX, false, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_ADD, false, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_ADD, false, VTA_BLOCK_OUT, 128, false);
// Run blocked GEMM test
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2);
......
......@@ -183,7 +183,7 @@ void compute(
hls::stream<bool> &s2g_dep_queue,
hls::stream<bool> &g2l_dep_queue,
hls::stream<bool> &g2s_dep_queue,
out_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT],
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]
) {
......@@ -286,23 +286,24 @@ void compute(
done = 0;
// Decode
uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_5_1, VTA_INSN_GEM_5_0);
uop_idx_T uop_end = insn.range(VTA_INSN_GEM_6_1, VTA_INSN_GEM_6_0);
loop_T iter_out = insn.range(VTA_INSN_GEM_7_1, VTA_INSN_GEM_7_0);
loop_T iter_in = insn.range(VTA_INSN_GEM_8_1, VTA_INSN_GEM_8_0);
acc_idx_T dst_factor_out = insn.range(VTA_INSN_GEM_9_1, VTA_INSN_GEM_9_0);
acc_idx_T dst_factor_in = insn.range(VTA_INSN_GEM_A_1, VTA_INSN_GEM_A_0);
inp_idx_T src_factor_out = insn.range(VTA_INSN_GEM_B_1, VTA_INSN_GEM_B_0);
inp_idx_T src_factor_in = insn.range(VTA_INSN_GEM_C_1, VTA_INSN_GEM_C_0);
bool reset_out = insn[VTA_INSN_GEM_5];
uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_6_1, VTA_INSN_GEM_6_0);
uop_idx_T uop_end = insn.range(VTA_INSN_GEM_7_1, VTA_INSN_GEM_7_0);
loop_T iter_out = insn.range(VTA_INSN_GEM_8_1, VTA_INSN_GEM_8_0);
loop_T iter_in = insn.range(VTA_INSN_GEM_9_1, VTA_INSN_GEM_9_0);
acc_idx_T dst_factor_out = insn.range(VTA_INSN_GEM_A_1, VTA_INSN_GEM_A_0);
acc_idx_T dst_factor_in = insn.range(VTA_INSN_GEM_B_1, VTA_INSN_GEM_B_0);
inp_idx_T src_factor_out = insn.range(VTA_INSN_GEM_C_1, VTA_INSN_GEM_C_0);
inp_idx_T src_factor_in = insn.range(VTA_INSN_GEM_D_1, VTA_INSN_GEM_D_0);
// GEMM-specific fields
wgt_idx_T wgt_factor_out = insn.range(VTA_INSN_GEM_D_1, VTA_INSN_GEM_D_0);
wgt_idx_T wgt_factor_in = insn.range(VTA_INSN_GEM_E_1, VTA_INSN_GEM_E_0);
wgt_idx_T wgt_factor_out = insn.range(VTA_INSN_GEM_E_1, VTA_INSN_GEM_E_0);
wgt_idx_T wgt_factor_in = insn.range(VTA_INSN_GEM_F_1, VTA_INSN_GEM_F_0);
// ALU-specific field
aluop_opcode_T alu_opcode = insn.range(VTA_INSN_ALU_D_1, VTA_INSN_ALU_D_0);
bool use_imm = insn[VTA_INSN_ALU_E];
aluop_imm_T imm = insn.range(VTA_INSN_ALU_F_1, VTA_INSN_ALU_F_0);
aluop_opcode_T alu_opcode = insn.range(VTA_INSN_ALU_E_1, VTA_INSN_ALU_E_0);
bool use_imm = insn[VTA_INSN_ALU_F];
aluop_imm_T imm = insn.range(VTA_INSN_ALU_G_1, VTA_INSN_ALU_G_0);
acc_idx_T dst_offset_out = 0;
inp_idx_T src_offset_out = 0;
wgt_idx_T wgt_offset_out = 0;
......@@ -326,13 +327,12 @@ void compute(
uop_T uop = uop_mem[upc];
// Decode indices
bool reset_out = uop[VTA_UOP_GEM_0];
acc_idx_T dst_idx =
uop.range(VTA_UOP_GEM_1_1, VTA_UOP_GEM_1_0) + dst_offset_in;
acc_idx_T src_idx =
uop.range(VTA_UOP_GEM_2_1, VTA_UOP_GEM_2_0) + src_offset_in;
uop.range(VTA_UOP_GEM_0_1, VTA_UOP_GEM_0_0) + dst_offset_in;
inp_idx_T src_idx =
uop.range(VTA_UOP_GEM_1_1, VTA_UOP_GEM_1_0) + src_offset_in;
wgt_idx_T wgt_idx =
uop.range(VTA_UOP_GEM_3_1, VTA_UOP_GEM_3_0) + wgt_offset_in;
uop.range(VTA_UOP_GEM_2_1, VTA_UOP_GEM_2_0) + wgt_offset_in;
// Read weight matrix
wgt_vec_T w_matrix[VTA_BLOCK_OUT];
......@@ -341,7 +341,7 @@ void compute(
}
// Read input matrix and accum matrix
acc_vec_T o_matrix[VTA_BATCH];
out_vec_T i_matrix[VTA_BATCH];
inp_vec_T i_matrix[VTA_BATCH];
for (int i = 0; i < VTA_BATCH; i++) {
o_matrix[i] = acc_mem[dst_idx][i];
i_matrix[i] = inp_mem[src_idx][i];
......@@ -365,6 +365,9 @@ void compute(
inp_T i_elem =
i_matrix[i].range((k + 1) * VTA_INP_WIDTH - 1, k * VTA_INP_WIDTH);
mul_T prod = i_elem * w_elem;
#ifdef NO_DSP
#pragma HLS RESOURCE variable = prod core = Mul_LUT
#endif // NO_DSP
tmp += (sum_T) prod;
}
// Update summation
......@@ -373,101 +376,85 @@ void compute(
acc_mem_val[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) =
reset_out ? (acc_T) 0 : accum;
st_buf_val[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) accum.range(VTA_OUT_WIDTH - 1, 0);
(out_T) accum.range(VTA_OUT_WIDTH - 1, 0);
}
// Write to buffers
acc_mem[dst_idx][i] = acc_mem_val[i];
out_mem[dst_idx][i] = st_buf_val[i];
}
}
} else if (opcode == VTA_OPCODE_ALU) {
}
#ifndef NO_ALU
else if (opcode == VTA_OPCODE_ALU) {
// Iterate over micro op
READ_ALU_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) {
#pragma HLS PIPELINE II = 2 rewind
// Read micro-op fields
uop_T uop = uop_mem[upc];
// Decode
bool reset_out = uop[VTA_UOP_ALU_0];
acc_idx_T dst_idx =
uop.range(VTA_UOP_ALU_1_1, VTA_UOP_ALU_1_0) + dst_offset_in;
uop.range(VTA_UOP_ALU_0_1, VTA_UOP_ALU_0_0) + dst_offset_in;
acc_idx_T src_idx =
uop.range(VTA_UOP_ALU_2_1, VTA_UOP_ALU_2_0) + src_offset_in;
// Read input matrix and accum matrix
acc_vec_T dst_matrix[VTA_BATCH];
acc_vec_T src_matrix[VTA_BATCH];
for (int i = 0; i < VTA_BATCH; i++) {
#pragma HLS UNROLL complete
dst_matrix[i] = acc_mem[dst_idx][i];
src_matrix[i] = acc_mem[src_idx][i];
}
// Result matrices
acc_vec_T cmp_res[VTA_BATCH];
acc_vec_T add_res[VTA_BATCH];
acc_vec_T rshr_res[VTA_BATCH];
acc_vec_T lshr_res[VTA_BATCH];
out_vec_T short_cmp_res[VTA_BATCH];
out_vec_T short_add_res[VTA_BATCH];
out_vec_T short_rshr_res[VTA_BATCH];
out_vec_T short_lshr_res[VTA_BATCH];
uop.range(VTA_UOP_ALU_1_1, VTA_UOP_ALU_1_0) + src_offset_in;
// Perform ALU op over matrix elements
for (int i = 0; i < VTA_BATCH; i++) {
// Read input matrix and accum matrix
acc_vec_T dst_vector = acc_mem[dst_idx][i];
acc_vec_T src_vector = acc_mem[src_idx][i];
// Result matrices
acc_vec_T cmp_res;
acc_vec_T add_res;
acc_vec_T shr_res;
out_vec_T short_cmp_res;
out_vec_T short_add_res;
out_vec_T short_shr_res;
// Results vector
acc_vec_T res_vec = 0;
for (int b = 0; b < VTA_BLOCK_OUT; b++) {
#pragma HLS PIPELINE II = 1 rewind
// Read in operands
acc_T src_0 = dst_matrix[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
acc_T src_0 = dst_vector.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
acc_T src_1 = use_imm ?
(acc_T) imm :
src_matrix[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
src_vector.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
// Compute Min/Max
acc_T mix_val = src_0 < src_1 ?
(alu_opcode == VTA_ALU_OPCODE_MIN ? src_0 : src_1) :
(alu_opcode == VTA_ALU_OPCODE_MIN ? src_1 : src_0);
cmp_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = mix_val;
short_cmp_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) mix_val.range(VTA_OUT_WIDTH - 1, 0);
cmp_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = mix_val;
short_cmp_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(out_T) mix_val.range(VTA_OUT_WIDTH - 1, 0);
// Compute Sum
acc_T add_val =
src_0.range(VTA_ACC_WIDTH - 1, 0) + src_1.range(VTA_ACC_WIDTH - 1, 0);
add_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val;
short_add_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) add_val.range(VTA_OUT_WIDTH - 1, 0);
// Compute Right Shift
acc_T rshr_val =
add_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val;
short_add_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(out_T) add_val.range(VTA_OUT_WIDTH - 1, 0);
// Compute Shift Right
acc_T shr_val =
src_0 >> (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
rshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = rshr_val;
short_rshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) rshr_val.range(VTA_OUT_WIDTH-1, 0);
// Compute Left Shift
acc_T lshr_val =
src_0 << (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
lshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = lshr_val;
short_lshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) lshr_val.range(VTA_OUT_WIDTH-1, 0);
shr_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = shr_val;
short_shr_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(out_T) shr_val.range(VTA_OUT_WIDTH-1, 0);
}
// Store to accum memory/store buffer
if (alu_opcode == VTA_ALU_OPCODE_MIN ||
alu_opcode == VTA_ALU_OPCODE_MAX) {
acc_mem[dst_idx][i] = cmp_res[i];
out_mem[dst_idx][i] = short_cmp_res[i];
acc_mem[dst_idx][i] = cmp_res;
out_mem[dst_idx][i] = short_cmp_res;
} else if (alu_opcode == VTA_ALU_OPCODE_ADD) {
acc_mem[dst_idx][i] = add_res[i];
out_mem[dst_idx][i] = short_add_res[i];
acc_mem[dst_idx][i] = add_res;
out_mem[dst_idx][i] = short_add_res;
} else if (alu_opcode == VTA_ALU_OPCODE_SHR) {
acc_mem[dst_idx][i] = rshr_res[i];
out_mem[dst_idx][i] = short_rshr_res[i];
} else if (alu_opcode == VTA_ALU_OPCODE_SHL) {
acc_mem[dst_idx][i] = lshr_res[i];
out_mem[dst_idx][i] = short_lshr_res[i];
acc_mem[dst_idx][i] = shr_res;
out_mem[dst_idx][i] = short_shr_res;
}
}
}
}
#endif // NO_ALU
// Update offsets
dst_offset_in += dst_factor_in;
......
......@@ -92,7 +92,7 @@ typedef ap_uint<VTA_ALU_OPCODE_BIT_WIDTH> aluop_opcode_T;
typedef ap_int<VTA_ALUOP_IMM_BIT_WIDTH> aluop_imm_T;
/* \typedef aluop_opcode_T ALU operation shift immediate datatype*/
typedef ap_uint<VTA_LOG_ACC_WIDTH> aluop_sh_imm_T;
typedef ap_int<VTA_LOG_ACC_WIDTH> aluop_sh_imm_T;
/*!
* \brief Fetch module.
......
......@@ -93,9 +93,7 @@ extern "C" {
/*! Instruction opcode field bitwidth */
#define VTA_OPCODE_BIT_WIDTH 3
/*! ALU opcode field bitwidth */
#define VTA_ALU_OPCODE_BIT_WIDTH 3
/*! ALU instruction reset mode bitwidth */
#define VTA_ALU_RESET_BIT_WIDTH 2
#define VTA_ALU_OPCODE_BIT_WIDTH 2
/*! Opcode: load encoding */
#define VTA_OPCODE_LOAD 0
......@@ -114,23 +112,8 @@ extern "C" {
#define VTA_ALU_OPCODE_MAX 1
/*! ALU opcode: binary add op */
#define VTA_ALU_OPCODE_ADD 2
/*! ALU opcode: binary sub op [NOT IMPLEMENTED] */
#define VTA_ALU_OPCODE_SUB 3
/*! ALU opcode: binary mul op [NOT IMPLEMENTED] */
#define VTA_ALU_OPCODE_MUL 4
/*! ALU opcode: shift left by immediate op */
#define VTA_ALU_OPCODE_SHL 5
/*! ALU opcode: shift right by immediate op [NOT IMPLEMENTED] */
#define VTA_ALU_OPCODE_SHR 6
/*! ALU instruction reset mode: set to min */
#define VTA_ALU_RESET_MIN 3
/*! ALU instruction reset mode: set to zero */
#define VTA_ALU_RESET_ZERO 0
/*! ALU instruction reset mode: no reset */
#define VTA_ALU_NO_RESET 2
/*! ALU instruction reset mode: set to max */
#define VTA_ALU_RESET_MAX 1
/*! ALU opcode: shift right by immediate op */
#define VTA_ALU_OPCODE_SHR 3
/*! Memory type field bitwidth */
#define VTA_MEMOP_ID_BIT_WIDTH 2
......@@ -149,7 +132,7 @@ extern "C" {
/*! ALU Instruction: immediate bitwidth*/
#define VTA_ALUOP_IMM_BIT_WIDTH 16
/*! GEMM/ALU Instruction: loop max iter bits */
#define VTA_LOOP_ITER_WIDTH 15
#define VTA_LOOP_ITER_WIDTH 14
/*! Mem ID constant: uop memory */
#define VTA_MEM_ID_UOP 0
......@@ -190,16 +173,17 @@ extern "C" {
// arg 2: pop_next_dependence | bool |
// arg 3: push_prev_dependence | bool |
// arg 4: push_next_dependence | bool |
// arg 5: uop_bgn | uop_idx_T |
// arg 6: uop_end | uop_idx_T |
// arg 7: iteration count ax0 | loop_T |
// arg 8: iteration count ax1 | loop_T |
// arg 9: accum idx factor ax0 | acc_idx_T |
// arg a: accum idx factor ax1 | acc_idx_T |
// arg b: input idx factor ax0 | acc_idx_T |
// arg c: input idx factor ax1 | acc_idx_T |
// arg d: weight idx factor ax0 | wgt_idx_T |
// arg e: weight idx factor ax1 | wgt_idx_T |
// arg 5: reset_reg | bool |
// arg 6: uop_bgn | uop_idx_T |
// arg 7: uop_end | uop_idx_T |
// arg 8: iteration count ax0 | loop_T |
// arg 9: iteration count ax1 | loop_T |
// arg a: accum idx factor ax0 | acc_idx_T |
// arg b: accum idx factor ax1 | acc_idx_T |
// arg c: input idx factor ax0 | inp_idx_T |
// arg d: input idx factor ax1 | inp_idx_T |
// arg e: weight idx factor ax0 | wgt_idx_T |
// arg f: weight idx factor ax1 | wgt_idx_T |
//
// ALU
// _____________________________|_type______________|
......@@ -208,17 +192,18 @@ extern "C" {
// arg 2: pop_next_dependence | bool |
// arg 3: push_prev_dependence | bool |
// arg 4: push_next_dependence | bool |
// arg 5: uop_bgn | uop_idx_T |
// arg 6: uop_end | uop_idx_T |
// arg 7: iteration count ax0 | loop_T |
// arg 8: iteration count ax1 | loop_T |
// arg 9: dst idx factor ax0 | acc_idx_T |
// arg a: dst idx factor ax1 | acc_idx_T |
// arg b: src idx factor ax0 | acc_idx_T |
// arg c: src idx factor ax1 | acc_idx_T |
// arg d: alu_opcode | aluop_opcode_T |
// arg e: use_imm | bool |
// arg f: imm | alu_imm_T |
// arg 5: reset_reg | bool |
// arg 6: uop_bgn | uop_idx_T |
// arg 7: uop_end | uop_idx_T |
// arg 8: iteration count ax0 | loop_T |
// arg 9: iteration count ax1 | loop_T |
// arg a: dst idx factor ax0 | acc_idx_T |
// arg b: dst idx factor ax1 | acc_idx_T |
// arg c: src idx factor ax0 | inp_idx_T |
// arg d: src idx factor ax1 | inp_idx_T |
// arg e: alu_opcode | aluop_opcode_T |
// arg f: use_imm | bool |
// arg g: imm | alu_imm_T |
/*! Load/Store instruction start position of the opcode field */
#define VTA_INSN_MEM_0_0 0
......@@ -285,88 +270,82 @@ extern "C" {
#define VTA_INSN_GEM_3 (VTA_INSN_GEM_2 + 1)
/*! GEMM instruction position of the push_next_dependence field */
#define VTA_INSN_GEM_4 (VTA_INSN_GEM_3 + 1)
/*! GEMM instruction position of the reset register bit */
#define VTA_INSN_GEM_5 (VTA_INSN_GEM_4 + 1)
/*! GEMM instruction start position of the uop_bgn field */
#define VTA_INSN_GEM_5_0 (VTA_INSN_GEM_4 + 1)
#define VTA_INSN_GEM_6_0 (VTA_INSN_GEM_5 + 1)
/*! GEMM instruction end position of the uop_bgn field */
#define VTA_INSN_GEM_5_1 (VTA_INSN_GEM_5_0 + VTA_LOG_UOP_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_6_1 (VTA_INSN_GEM_6_0 + VTA_LOG_UOP_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the uop_end field */
#define VTA_INSN_GEM_6_0 (VTA_INSN_GEM_5_1 + 1)
#define VTA_INSN_GEM_7_0 (VTA_INSN_GEM_6_1 + 1)
/*! GEMM instruction end position of the uop_end field */
#define VTA_INSN_GEM_6_1 (VTA_INSN_GEM_6_0 + VTA_LOG_UOP_BUFF_DEPTH + 1 - 1)
#define VTA_INSN_GEM_7_1 (VTA_INSN_GEM_7_0 + VTA_LOG_UOP_BUFF_DEPTH + 1 - 1)
/*! GEMM instruction start position of the iter_out field */
#define VTA_INSN_GEM_7_0 (VTA_INSN_GEM_6_1 + 1)
#define VTA_INSN_GEM_8_0 (VTA_INSN_GEM_7_1 + 1)
/*! GEMM instruction end position of the iter_out field */
#define VTA_INSN_GEM_7_1 (VTA_INSN_GEM_7_0 + VTA_LOOP_ITER_WIDTH - 1)
#define VTA_INSN_GEM_8_1 (VTA_INSN_GEM_8_0 + VTA_LOOP_ITER_WIDTH - 1)
/*! GEMM instruction start position of the iter_in field */
#define VTA_INSN_GEM_8_0 (VTA_INSN_GEM_7_1 + 1)
#define VTA_INSN_GEM_9_0 (VTA_INSN_GEM_8_1 + 1)
/*! GEMM instruction end position of the iter_in field */
#define VTA_INSN_GEM_8_1 (VTA_INSN_GEM_8_0 + VTA_LOOP_ITER_WIDTH - 1)
#define VTA_INSN_GEM_9_1 (VTA_INSN_GEM_9_0 + VTA_LOOP_ITER_WIDTH - 1)
/*! GEMM instruction start position of the dst_factor_out field */
#define VTA_INSN_GEM_9_0 64
#define VTA_INSN_GEM_A_0 64
/*! GEMM instruction end position of the dst_factor_out field */
#define VTA_INSN_GEM_9_1 (VTA_INSN_GEM_9_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_A_1 (VTA_INSN_GEM_A_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the dst_factor_in field */
#define VTA_INSN_GEM_A_0 (VTA_INSN_GEM_9_1 + 1)
#define VTA_INSN_GEM_B_0 (VTA_INSN_GEM_A_1 + 1)
/*! GEMM instruction end position of the dst_factor_in field */
#define VTA_INSN_GEM_A_1 (VTA_INSN_GEM_A_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_B_1 (VTA_INSN_GEM_B_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the src_factor_out field */
#define VTA_INSN_GEM_B_0 (VTA_INSN_GEM_A_1 + 1)
#define VTA_INSN_GEM_C_0 (VTA_INSN_GEM_B_1 + 1)
/*! GEMM instruction end position of the src_factor_out field */
#define VTA_INSN_GEM_B_1 (VTA_INSN_GEM_B_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_C_1 (VTA_INSN_GEM_C_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the src_factor_in field */
#define VTA_INSN_GEM_C_0 (VTA_INSN_GEM_B_1 + 1)
#define VTA_INSN_GEM_D_0 (VTA_INSN_GEM_C_1 + 1)
/*! GEMM instruction end position of the src_factor_in field */
#define VTA_INSN_GEM_C_1 (VTA_INSN_GEM_C_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_D_1 (VTA_INSN_GEM_D_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the wgt_factor_out field */
#define VTA_INSN_GEM_D_0 (VTA_INSN_GEM_C_1 + 1)
#define VTA_INSN_GEM_E_0 (VTA_INSN_GEM_D_1 + 1)
/*! GEMM instruction end position of the wgt_factor_out field */
#define VTA_INSN_GEM_D_1 (VTA_INSN_GEM_D_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_E_1 (VTA_INSN_GEM_E_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the wgt_factor_in field */
#define VTA_INSN_GEM_E_0 (VTA_INSN_GEM_D_1 + 1)
#define VTA_INSN_GEM_F_0 (VTA_INSN_GEM_E_1 + 1)
/*! GEMM instruction end position of the wgt_factor_in field */
#define VTA_INSN_GEM_E_1 (VTA_INSN_GEM_E_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
#define VTA_INSN_GEM_F_1 (VTA_INSN_GEM_F_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! ALU instruction start position of the alu_opcode field */
#define VTA_INSN_ALU_D_0 (VTA_INSN_GEM_C_1 + 1)
#define VTA_INSN_ALU_E_0 (VTA_INSN_GEM_D_1 + 1)
/*! ALU instruction end position of the alu_opcode field */
#define VTA_INSN_ALU_D_1 (VTA_INSN_ALU_D_0 + VTA_ALU_OPCODE_BIT_WIDTH - 1)
#define VTA_INSN_ALU_E_1 (VTA_INSN_ALU_E_0 + VTA_ALU_OPCODE_BIT_WIDTH - 1)
/*! ALU instruction position of the use_imm field */
#define VTA_INSN_ALU_E (VTA_INSN_ALU_D_1 + 1)
#define VTA_INSN_ALU_F (VTA_INSN_ALU_E_1 + 1)
/*! ALU instruction start position of the immediate field */
#define VTA_INSN_ALU_F_0 (VTA_INSN_ALU_E + 1)
#define VTA_INSN_ALU_G_0 (VTA_INSN_ALU_F + 1)
/*! ALU instruction end position of the immediate field */
#define VTA_INSN_ALU_F_1 (VTA_INSN_ALU_F_0 + VTA_ALUOP_IMM_BIT_WIDTH - 1)
#define VTA_INSN_ALU_G_1 (VTA_INSN_ALU_G_0 + VTA_ALUOP_IMM_BIT_WIDTH - 1)
/*! GEMM Micro-op position of the reset_out field */
#define VTA_UOP_GEM_0 0
/*! GEMM Micro-op start position of the acc_idx field */
#define VTA_UOP_GEM_1_0 (VTA_UOP_GEM_0 + 1)
#define VTA_UOP_GEM_0_0 0
/*! GEMM Micro-op end position of the acc_idx field */
#define VTA_UOP_GEM_1_1 (VTA_UOP_GEM_1_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_UOP_GEM_0_1 (VTA_UOP_GEM_0_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the inp_idx field */
#define VTA_UOP_GEM_2_0 (VTA_UOP_GEM_1_1 + 1)
#define VTA_UOP_GEM_1_0 (VTA_UOP_GEM_0_1 + 1)
/*! GEMM Micro-op end position of the inp_idx field */
#define VTA_UOP_GEM_2_1 (VTA_UOP_GEM_2_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_UOP_GEM_1_1 (VTA_UOP_GEM_1_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the wgt_idx field */
#define VTA_UOP_GEM_3_0 (VTA_UOP_GEM_2_1 + 1)
#define VTA_UOP_GEM_2_0 (VTA_UOP_GEM_1_1 + 1)
/*! GEMM Micro-op end position of the wgt_idx field */
#define VTA_UOP_GEM_3_1 (VTA_UOP_GEM_3_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
#define VTA_UOP_GEM_2_1 (VTA_UOP_GEM_2_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! GEMM Micro-op position of the reset_out field */
#define VTA_UOP_ALU_0 0
/*! GEMM Micro-op start position of the acc_idx field */
#define VTA_UOP_ALU_1_0 (VTA_UOP_ALU_0 + 1)
#define VTA_UOP_ALU_0_0 0
/*! GEMM Micro-op end position of the acc_idx field */
#define VTA_UOP_ALU_1_1 (VTA_UOP_ALU_1_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
#define VTA_UOP_ALU_0_1 (VTA_UOP_ALU_0_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the inp_idx field */
#define VTA_UOP_ALU_2_0 (VTA_UOP_ALU_1_1 + 1)
#define VTA_UOP_ALU_1_0 (VTA_UOP_ALU_0_1 + 1)
/*! GEMM Micro-op end position of the inp_idx field */
#define VTA_UOP_ALU_2_1 (VTA_UOP_ALU_2_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the wgt_idx field */
#define VTA_UOP_ALU_3_0 (VTA_UOP_ALU_2_1 + 1)
/*! GEMM Micro-op end position of the wgt_idx field */
#define VTA_UOP_ALU_3_1 (VTA_UOP_ALU_3_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
#define VTA_UOP_ALU_1_1 (VTA_UOP_ALU_1_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! \brief VTA generic instruction */
typedef struct {
......@@ -454,10 +433,12 @@ typedef struct {
uint64_t push_prev_dep : 1;
/*! \brief Push dependence token to store stage */
uint64_t push_next_dep : 1;
/*! \brief Reset register */
uint64_t reset_reg : 1;
/*! \brief Micro-op begin address */
uint64_t uop_bgn : VTA_LOG_UOP_BUFF_DEPTH;
/*! \brief Micro-op end address */
uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH+1;
uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH + 1;
/*! \brief Iterations in the outer uop execution loop */
uint64_t iter_out : VTA_LOOP_ITER_WIDTH;
/*! \brief Iterations in the inner uop execution loop */
......@@ -467,9 +448,9 @@ typedef struct {
/*! \brief Inner loop accumulator memory index factor */
uint64_t dst_factor_in : VTA_LOG_ACC_BUFF_DEPTH;
/*! \brief Outer loop input memory index factor */
uint64_t src_factor_out : VTA_LOG_ACC_BUFF_DEPTH;
uint64_t src_factor_out : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Inner loop input memory index factor */
uint64_t src_factor_in : VTA_LOG_ACC_BUFF_DEPTH;
uint64_t src_factor_in : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Outer loop weight memory index factor */
uint64_t wgt_factor_out : VTA_LOG_WGT_BUFF_DEPTH;
/*! \brief Inner loop weight memory index factor */
......@@ -516,10 +497,12 @@ typedef struct {
uint64_t push_prev_dep : 1;
/*! \brief Push dependence token to store stage */
uint64_t push_next_dep : 1;
/*! \brief Reset register */
uint64_t reset_reg : 1;
/*! \brief Micro-op begin address */
uint64_t uop_bgn : VTA_LOG_UOP_BUFF_DEPTH;
/*! \brief Micro-op end address */
uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH+1;
uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH + 1;
/*! \brief Iterations in the outer uop execution loop */
uint64_t iter_out : VTA_LOOP_ITER_WIDTH;
/*! \brief Iterations in the inner uop execution loop */
......@@ -529,9 +512,9 @@ typedef struct {
/*! \brief Inner loop accumulator memory destination index factor */
uint64_t dst_factor_in : VTA_LOG_ACC_BUFF_DEPTH;
/*! \brief Outer loop accumulator memory source index factor */
uint64_t src_factor_out : VTA_LOG_ACC_BUFF_DEPTH;
uint64_t src_factor_out : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Inner loop accumulator memory source index factor */
uint64_t src_factor_in : VTA_LOG_ACC_BUFF_DEPTH;
uint64_t src_factor_in : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief ALU opcode */
uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH;
/*! \brief Use immediate is true */
......@@ -554,12 +537,10 @@ union VTAInsn {
/*! \brief VTA micro-op for GEMM/ALU instruction */
typedef struct {
/*! \brief Initialize acc_mem at index dst_idx to 0*/
uint32_t reset_out : 1;
/*! \brief Destination index (indexes accum buffer) */
uint32_t dst_idx : VTA_LOG_ACC_BUFF_DEPTH;
/*! \brief Source index (indexes input buffer for GEMM or accum buffer for ALU) */
uint32_t src_idx : VTA_LOG_ACC_BUFF_DEPTH;
uint32_t src_idx : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Weight index (indexes weight buffer) */
uint32_t wgt_idx : VTA_LOG_WGT_BUFF_DEPTH;
} VTAUop;
......
"""VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs
# Log of input/activation width in bits (default 3 -> 8 bits)
VTA_LOG_INP_WIDTH = 3
# Log of kernel weight width in bits (default 3 -> 8 bits)
VTA_LOG_WGT_WIDTH = 3
# Log of accum width in bits (default 5 -> 32 bits)
VTA_LOG_ACC_WIDTH = 5
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BATCH = 0
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_IN = 4
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_OUT = 4
VTA_LOG_OUT_WIDTH = VTA_LOG_INP_WIDTH
# Log of uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = 15
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17
import os
python_vta_dir = os.path.dirname(__file__)
print python_vta_dir
filename = os.path.join(python_vta_dir, '../../config.mk')
VTA_PYNQ_BRAM_W = 32
VTA_PYNQ_BRAM_D = 1024
VTA_PYNQ_NUM_BRAM = 124
keys = ["VTA_LOG_INP_WIDTH", "VTA_LOG_WGT_WIDTH", "VTA_LOG_ACC_WIDTH",
"VTA_LOG_BATCH", "VTA_LOG_BLOCK_IN", "VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE", "VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE"]
params = {}
if os.path.isfile(filename):
with open(filename) as f:
for line in f:
for k in keys:
if k+" =" in line:
val = line.split("=")[1].strip()
params[k] = int(val)
# print params
else:
print "Error: {} not found. Please make sure you have config.mk in your vta root".format(filename)
exit()
# The Constants
VTA_WGT_WIDTH = 8
VTA_INP_WIDTH = VTA_WGT_WIDTH
VTA_OUT_WIDTH = 32
VTA_INP_WIDTH = 1 << params["VTA_LOG_INP_WIDTH"]
VTA_WGT_WIDTH = 1 << params["VTA_LOG_WGT_WIDTH"]
VTA_OUT_WIDTH = 1 << params["VTA_LOG_ACC_WIDTH"]
VTA_TARGET = "VTA_PYNQ_TARGET"
# Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1
VTA_BLOCK_IN = 16
VTA_BLOCK_OUT = 16
VTA_BATCH = 1 << params["VTA_LOG_BATCH"]
VTA_BLOCK_IN = 1 << params["VTA_LOG_BLOCK_IN"]
VTA_BLOCK_OUT = 1 << params["VTA_LOG_BLOCK_OUT"]
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = 15
# log-2 On-chip uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = params["VTA_LOG_UOP_BUFF_SIZE"]
# log-2 On-chip input buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = 15
VTA_LOG_INP_BUFF_SIZE = params["VTA_LOG_INP_BUFF_SIZE"]
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = params["VTA_LOG_WGT_BUFF_SIZE"]
# log-2 On-chip output buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = 17
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
VTA_LOG_OUT_BUFF_SIZE = params["VTA_LOG_ACC_BUFF_SIZE"]
# Uop buffer size
VTA_UOP_BUFF_SIZE = 1 << VTA_LOG_UOP_BUFF_SIZE
# Input buffer size
VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
# Output buffer size.
VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE
......@@ -69,11 +83,7 @@ VTA_MEM_ID_OUT = 4
VTA_ALU_OPCODE_MIN = 0
VTA_ALU_OPCODE_MAX = 1
VTA_ALU_OPCODE_ADD = 2
VTA_ALU_OPCODE_SUB = 3
VTA_ALU_OPCODE_MUL = 4
VTA_ALU_OPCODE_SHL = 5
VTA_ALU_OPCODE_SHR = 6
VTA_ALU_OPCODE_UNSET = 7
VTA_ALU_OPCODE_SHR = 3
# Task queue id (pipeline stage)
VTA_QID_LOAD_INP = 1
......
......@@ -617,7 +617,6 @@ def inject_alu_intrin(stmt_in):
extents.append(tmp_body.extent)
tmp_body = tmp_body.body
# Derive opcode
alu_opcode = spec.VTA_ALU_OPCODE_UNSET
if isinstance(loop_body.value, tvm.expr.Add):
alu_opcode = spec.VTA_ALU_OPCODE_ADD
lhs = loop_body.value.a
......@@ -640,9 +639,9 @@ def inject_alu_intrin(stmt_in):
rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Call):
if loop_body.value.name == 'shift_left':
alu_opcode = spec.VTA_ALU_OPCODE_SHL
alu_opcode = spec.VTA_ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1]
rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right':
alu_opcode = spec.VTA_ALU_OPCODE_SHR
lhs = loop_body.value.args[0]
......@@ -732,6 +731,7 @@ def inject_alu_intrin(stmt_in):
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops
if extents:
src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
# Insert ALU micro-ops
......
......@@ -117,7 +117,6 @@ class UopKernel {
// The loop nest structure
VerifyDep(dst_index);
VTAUop op;
op.reset_out = reset_out;
op.dst_idx = dst_index;
op.src_idx = src_index;
op.wgt_idx = wgt_index;
......@@ -128,6 +127,12 @@ class UopKernel {
} else {
assert(mode_ == mode);
}
// Set reset_out field if unset
if (reset_out_ == 0xFFFFFFFF) {
reset_out_ = reset_out;
} else {
assert(reset_out_ == reset_out);
}
// Check kernel op and imm/imm_val in ALU mode
if (mode == 1) {
if (opcode_ == 0xFFFFFFFF) {
......@@ -146,12 +151,11 @@ class UopKernel {
uint32_t size = seq_.size();
printf("There are %u uops\n", size);
for (uint32_t i = 0; i < size; ++i) {
printf("[%04u]\t acc=%u, inp=%u, wgt=%u, reset_out=%u\n",
printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n",
i,
seq_[i].dst_idx,
seq_[i].src_idx,
seq_[i].wgt_idx,
seq_[i].reset_out);
seq_[i].wgt_idx);
}
printf("\n");
}
......@@ -160,6 +164,7 @@ class UopKernel {
// The kernel's mode, opcode, immediate setting and value
uint32_t mode_{0xFFFFFFFF}; // UOP type: 0xFFFFFFFF - unset, 0 - GEMM, 1 - ALU
uint32_t opcode_{0xFFFFFFFF};
uint32_t reset_out_{0xFFFFFFFF};
bool use_imm_{false};
uint16_t imm_val_{0};
......@@ -521,20 +526,6 @@ class InsnQueue : public BaseQueue {
} else {
return "add";
}
} else if (opcode == VTA_ALU_OPCODE_SUB) {
if (use_imm) {
return "sub imm";
} else {
return "sub";
}
} else if (opcode == VTA_ALU_OPCODE_MUL) {
if (use_imm) {
return "mul imm";
} else {
return "mul";
}
} else if (opcode == VTA_ALU_OPCODE_SHL) {
return "shl";
} else if (opcode == VTA_ALU_OPCODE_SHR) {
return "shr";
}
......@@ -641,6 +632,7 @@ class InsnQueue : public BaseQueue {
static_cast<int>(c.mem.pop_next_dep),
static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.gemm.reset_reg));
printf("\trange (%d, %d)\n",
static_cast<int>(c.gemm.uop_bgn),
static_cast<int>(c.gemm.uop_end));
......@@ -662,6 +654,7 @@ class InsnQueue : public BaseQueue {
static_cast<int>(c.mem.pop_next_dep),
static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.alu.reset_reg));
printf("\trange (%d, %d)\n",
static_cast<int>(c.alu.uop_bgn),
static_cast<int>(c.alu.uop_end));
......@@ -1081,6 +1074,7 @@ class CommandQueue {
}
VTAGemInsn* insn = insn_queue_.CreateGemInsn();
insn->opcode = VTA_OPCODE_GEMM;
insn->reset_reg = kernel->reset_out_;
insn->uop_bgn = kernel->sram_begin_;
insn->uop_end = kernel->sram_end_;
const std::vector<UopKernel::LoopEntry> &loop = kernel->loop();
......@@ -1119,6 +1113,7 @@ class CommandQueue {
}
VTAAluInsn* insn = insn_queue_.CreateAluInsn();
insn->opcode = VTA_OPCODE_ALU;
insn->reset_reg = kernel->reset_out_;
insn->uop_bgn = kernel->sram_begin_;
insn->uop_end = kernel->sram_end_;
insn->alu_opcode = kernel->opcode_;
......
......@@ -28,20 +28,6 @@ const char* getOpcodeString(int opcode, bool use_imm) {
} else {
return "add";
}
} else if (opcode == VTA_ALU_OPCODE_SUB) {
if (use_imm) {
return "sub imm";
} else {
return "sub";
}
} else if (opcode == VTA_ALU_OPCODE_MUL) {
if (use_imm) {
return "mul imm";
} else {
return "mul";
}
} else if (opcode == VTA_ALU_OPCODE_SHL) {
return "shl";
} else if (opcode == VTA_ALU_OPCODE_SHR) {
return "shr";
}
......@@ -170,31 +156,6 @@ void freeBuffer(void * buffer) {
#endif
}
// Build a generic VTA instruction that resets a 2D region of an on-chip
// buffer. It is encoded as a LOAD whose actual transfer size is zero
// (y_size = 0, dram_base = 0) and whose top padding spans the whole target
// region — NOTE(review): this presumably relies on the load module writing
// padded elements as zeros into SRAM; confirm against the load stage.
VTAGenericInsn reset2DInsn(int type, int sram_offset, int y_size, int x_size, int x_stride,
int pop_prev_dep, int pop_next_dep, int push_prev_dep, int push_next_dep) {
// Converter union used to reinterpret the memory instruction as generic
union VTAInsn converter;
// Memory instruction initialization (zero-initialized, then fields set)
VTAMemInsn insn = {};
insn.opcode = VTA_OPCODE_LOAD;
// Dependence-token flags wire this instruction into the pipeline DAG
insn.pop_prev_dep = pop_prev_dep;
insn.pop_next_dep = pop_next_dep;
insn.push_prev_dep = push_prev_dep;
insn.push_next_dep = push_next_dep;
// Which on-chip memory (uop/inp/wgt/acc/out) is targeted
insn.memory_type = type;
insn.sram_base = sram_offset;
// No DRAM traffic: nothing is actually fetched
insn.dram_base = 0;
// Zero transfer rows — the y extent is expressed entirely as padding below
insn.y_size = 0;
insn.x_size = x_size;
insn.x_stride = x_stride;
// The full y_size rows become top padding, i.e. the region to be cleared
insn.y_pad_0 = y_size;
insn.y_pad_1 = 0;
insn.x_pad_0 = 0;
insn.x_pad_1 = 0;
converter.mem = insn;
return converter.generic;
}
VTAGenericInsn get2DLoadStoreInsn(int opcode, int type, int sram_offset, int dram_offset,
int y_size, int x_size, int x_stride, int y_pad, int x_pad, int pop_prev_dep, int pop_next_dep,
int push_prev_dep, int push_next_dep) {
......@@ -258,6 +219,7 @@ VTAGenericInsn getGEMMInsn(int uop_offset, int batch, int in_feat, int out_feat,
insn.pop_next_dep = pop_next_dep;
insn.push_prev_dep = push_prev_dep;
insn.push_next_dep = push_next_dep;
insn.reset_reg = false;
if (!uop_compression) {
insn.uop_bgn = uop_offset;
insn.uop_end = uop_offset + batch * in_feat * out_feat;
......@@ -296,6 +258,7 @@ VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bo
insn.pop_next_dep = pop_next_dep;
insn.push_prev_dep = push_prev_dep;
insn.push_next_dep = push_next_dep;
insn.reset_reg = false;
if (!uop_compression) {
insn.uop_bgn = 0;
insn.uop_end = vector_size;
......@@ -335,6 +298,7 @@ VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) {
insn.pop_next_dep = pop_next;
insn.push_prev_dep = 0;
insn.push_next_dep = 0;
insn.reset_reg = false;
insn.uop_bgn = 0;
insn.uop_end = 0;
insn.iter_out = 0;
......@@ -364,7 +328,6 @@ VTAUop * getCopyUops(int y_size, int x_size, int uop_compression) {
int uop_idx = 0;
for (int i = 0; i < y_size; i++) {
for (int j = 0; j < x_size; j++) {
uop_buf[uop_idx].reset_out = false;
uop_buf[uop_idx].dst_idx = i * x_size + j;
uop_buf[uop_idx].src_idx = 0;
uop_buf[uop_idx].wgt_idx = 0;
......@@ -372,7 +335,6 @@ VTAUop * getCopyUops(int y_size, int x_size, int uop_compression) {
}
}
} else {
uop_buf[0].reset_out = false;
uop_buf[0].dst_idx = 1;
uop_buf[0].src_idx = 0;
uop_buf[0].wgt_idx = 0;
......@@ -399,7 +361,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
for (int i = 0; i < batch; i++) {
for (int j = 0; j < in_feat; j++) {
for (int k = 0; k < out_feat; k++) {
uop_buf[uop_idx].reset_out = false;
uop_buf[uop_idx].dst_idx = i * out_feat + k;
uop_buf[uop_idx].src_idx = i * in_feat + j;
uop_buf[uop_idx].wgt_idx = k * in_feat + j;
......@@ -409,7 +370,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
}
} else {
for (int i = 0; i < batch; i++) {
uop_buf[i].reset_out = false;
uop_buf[i].dst_idx = i * out_feat;
uop_buf[i].src_idx = i * in_feat;
uop_buf[i].wgt_idx = 0;
......@@ -422,7 +382,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
for (int i = 0; i < batch; i++) {
for (int j = 0; j < in_feat; j++) {
for (int k = 0; k < out_feat; k++) {
uop_buf[uop_idx].reset_out = false;
uop_buf[uop_idx].dst_idx = i * out_feat + k;
uop_buf[uop_idx].src_idx = batch * in_feat + i * in_feat + j;
uop_buf[uop_idx].wgt_idx = out_feat * in_feat + k * in_feat + j;
......@@ -432,7 +391,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
}
} else {
for (int i = 0; i < batch; i++) {
uop_buf[batch+i].reset_out = false;
uop_buf[batch+i].dst_idx = i * out_feat;
uop_buf[batch+i].src_idx = batch * in_feat + i * in_feat;
uop_buf[batch+i].wgt_idx = out_feat * in_feat;
......@@ -456,12 +414,10 @@ VTAUop * getMapALUUops(int vector_size, bool uop_compression) {
if (!uop_compression) {
for (int i = 0; i < vector_size; i++) {
uop_buf[i].reset_out = 0;
uop_buf[i].dst_idx = i;
uop_buf[i].src_idx = vector_size + i;
}
} else {
uop_buf[0].reset_out = 0;
uop_buf[0].dst_idx = 0;
uop_buf[0].src_idx = vector_size;
}
......@@ -511,7 +467,7 @@ void printParameters() {
printf("VTA_INSN_GEM_2 [%d]\n", VTA_INSN_GEM_2);
printf("VTA_INSN_GEM_3 [%d]\n", VTA_INSN_GEM_3);
printf("VTA_INSN_GEM_4 [%d]\n", VTA_INSN_GEM_4);
printf("VTA_INSN_GEM_5 [%d-%d]\n", VTA_INSN_GEM_5_0, VTA_INSN_GEM_5_1);
printf("VTA_INSN_GEM_5 [%d]\n", VTA_INSN_GEM_5);
printf("VTA_INSN_GEM_6 [%d-%d]\n", VTA_INSN_GEM_6_0, VTA_INSN_GEM_6_1);
printf("VTA_INSN_GEM_7 [%d-%d]\n", VTA_INSN_GEM_7_0, VTA_INSN_GEM_7_1);
printf("VTA_INSN_GEM_8 [%d-%d]\n", VTA_INSN_GEM_8_0, VTA_INSN_GEM_8_1);
......@@ -521,17 +477,15 @@ void printParameters() {
printf("VTA_INSN_GEM_C [%d-%d]\n", VTA_INSN_GEM_C_0, VTA_INSN_GEM_C_1);
printf("VTA_INSN_GEM_D [%d-%d]\n", VTA_INSN_GEM_D_0, VTA_INSN_GEM_D_1);
printf("VTA_INSN_GEM_E [%d-%d]\n", VTA_INSN_GEM_E_0, VTA_INSN_GEM_E_1);
printf("VTA_INSN_ALU_D [%d-%d]\n", VTA_INSN_ALU_D_0, VTA_INSN_ALU_D_1);
printf("VTA_INSN_ALU_E [%d]\n", VTA_INSN_ALU_E);
printf("VTA_INSN_ALU_F [%d-%d]\n", VTA_INSN_ALU_F_0, VTA_INSN_ALU_F_1);
printf("VTA_UOP_GEM_0 [%d]\n", VTA_UOP_GEM_0);
printf("VTA_INSN_GEM_F [%d-%d]\n", VTA_INSN_GEM_F_0, VTA_INSN_GEM_F_1);
printf("VTA_INSN_ALU_E [%d-%d]\n", VTA_INSN_ALU_E_0, VTA_INSN_ALU_E_1);
printf("VTA_INSN_ALU_F [%d]\n", VTA_INSN_ALU_F);
printf("VTA_INSN_ALU_G [%d-%d]\n", VTA_INSN_ALU_G_0, VTA_INSN_ALU_G_1);
printf("VTA_UOP_GEM_0 [%d-%d]\n", VTA_UOP_GEM_0_0, VTA_UOP_GEM_0_1);
printf("VTA_UOP_GEM_1 [%d-%d]\n", VTA_UOP_GEM_1_0, VTA_UOP_GEM_1_1);
printf("VTA_UOP_GEM_2 [%d-%d]\n", VTA_UOP_GEM_2_0, VTA_UOP_GEM_2_1);
printf("VTA_UOP_GEM_3 [%d-%d]\n", VTA_UOP_GEM_3_0, VTA_UOP_GEM_3_1);
printf("VTA_UOP_ALU_0 [%d]\n", VTA_UOP_ALU_0);
printf("VTA_UOP_ALU_0 [%d-%d]\n", VTA_UOP_ALU_0_0, VTA_UOP_ALU_0_1);
printf("VTA_UOP_ALU_1 [%d-%d]\n", VTA_UOP_ALU_1_0, VTA_UOP_ALU_1_1);
printf("VTA_UOP_ALU_2 [%d-%d]\n", VTA_UOP_ALU_2_0, VTA_UOP_ALU_2_1);
printf("VTA_UOP_ALU_3 [%d-%d]\n", VTA_UOP_ALU_3_0, VTA_UOP_ALU_3_1);
}
void printInstruction(int num_insn, VTAGenericInsn *insns) {
......@@ -601,6 +555,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) {
printf("\trange (%d, %d)\n",
static_cast<int>(c.gemm.uop_bgn),
static_cast<int>(c.gemm.uop_end));
printf("\treset_out: %d\n", static_cast<int>(c.gemm.reset_reg));
printf("\touter loop - iter: %d, acc: %d, inp: %d, wgt: %d\n",
static_cast<int>(c.gemm.iter_out),
static_cast<int>(c.gemm.dst_factor_out),
......@@ -634,6 +589,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) {
static_cast<int>(c.mem.pop_next_dep),
static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.alu.reset_reg));
printf("\trange (%d, %d)\n",
static_cast<int>(c.alu.uop_bgn),
static_cast<int>(c.alu.uop_end));
......@@ -662,8 +618,7 @@ void printMicroOp(int num_uop, VTAUop *uops) {
for (int i = 0; i < num_uop; i++) {
// Read micro-op
printf("DEBUG - UOP %u: ", i);
printf("rst_out=%u, acc=%u, inp= %u, wgt=%u\n", uops[i].reset_out, uops[i].dst_idx,
uops[i].src_idx, uops[i].wgt_idx);
printf("acc=%u, inp= %u, wgt=%u\n", uops[i].dst_idx, uops[i].src_idx, uops[i].wgt_idx);
}
}
......@@ -671,7 +626,6 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
// Some assertions
assert(batch % VTA_BATCH == 0);
assert(vector_size % VTA_BLOCK_OUT == 0);
assert(!(opcode == VTA_ALU_OPCODE_SHL && !use_imm));
assert(!(opcode == VTA_ALU_OPCODE_SHR && !use_imm));
printf("=====================================================================================\n");
printf("INFO - ALU test of %s: batch=%d, vector_size=%d, uop_compression=%d\n",
......@@ -701,16 +655,9 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
} else if (opcode == VTA_ALU_OPCODE_ADD) {
immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_SUB) {
immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_MUL) {
immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_SHL) {
immediate[b] = static_cast<acc_T>(rand_r(&globalSeed) % (VTA_INP_WIDTH + 1));
} else if (opcode == VTA_ALU_OPCODE_SHR) {
immediate[b] = static_cast<acc_T>(rand_r(&globalSeed) % (VTA_INP_WIDTH + 1));
immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % VTA_ACC_WIDTH - VTA_ACC_WIDTH/2);
}
}
......@@ -783,18 +730,6 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
} else if (opcode == VTA_ALU_OPCODE_ADD) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} else if (opcode == VTA_ALU_OPCODE_SUB) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} else if (opcode == VTA_ALU_OPCODE_MUL) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_SHL) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} else if (opcode == VTA_ALU_OPCODE_SHR) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
}
}
}
......@@ -830,22 +765,12 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
} else {
tmp = inputs[i][j] + immediate[i / VTA_BATCH];
}
} else if (opcode == VTA_ALU_OPCODE_SUB) {
if (!use_imm) {
tmp = inputs[i][j] - inputs[i][j + vector_size];
} else {
tmp = inputs[i][j] - immediate[i / VTA_BATCH];
}
} else if (opcode == VTA_ALU_OPCODE_MUL) {
if (!use_imm) {
tmp = inputs[i][j] * inputs[i][j + vector_size];
} else {
tmp = inputs[i][j] * immediate[i / VTA_BATCH];
}
} else if (opcode == VTA_ALU_OPCODE_SHL) {
tmp = inputs[i][j] << immediate[i / VTA_BATCH];
} else if (opcode == VTA_ALU_OPCODE_SHR) {
if (immediate[i / VTA_BATCH] >= 0) {
tmp = inputs[i][j] >> immediate[i / VTA_BATCH];
} else {
tmp = inputs[i][j] << (0 - immediate[i / VTA_BATCH]);
}
}
// Set
outputs_ref[i][j] = (out_T) tmp;
......@@ -876,7 +801,7 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
(volatile inp_vec_T *) NULL,
(volatile wgt_vec_T *) NULL,
(volatile acc_vec_T *) bias_buf,
(volatile inp_vec_T *) output_buf);
(volatile out_vec_T *) output_buf);
#endif
// Unpack output buffer
......@@ -1149,7 +1074,7 @@ int blocked_gemm_test(int batch, int channels, int block, bool uop_compression,
(volatile inp_vec_T *) input_buf,
(volatile wgt_vec_T *) weight_buf,
(volatile acc_vec_T *) bias_buf,
(volatile inp_vec_T *) output_buf);
(volatile out_vec_T *) output_buf);
#endif
// Unpack output data
......
......@@ -16,18 +16,20 @@ inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter',
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
class Conv2DSchedule(object):
def __init__(self,
oc_factor,
b_factor=1,
oc_factor=1,
ko_factor=1,
h_factor=1,
w_factor=0,
oc_nthread=0,
h_nthread=0,
debug_sync=False):
self.b_factor = b_factor
self.oc_factor = oc_factor
self.ko_factor = ko_factor
self.h_factor = h_factor
......@@ -35,10 +37,140 @@ class Conv2DSchedule(object):
self.oc_nthread = oc_nthread
self.h_nthread = h_nthread
self.debug_sync = debug_sync
def __str__(self):
    """Render the schedule as a dot-separated string of its tuning knobs."""
    knobs = (self.b_factor, self.oc_factor, self.ko_factor,
             self.h_factor, self.w_factor,
             self.oc_nthread, self.h_nthread)
    return "{}.{}.{}.{}.{}.{}.{}".format(*knobs)
Schedule = Conv2DSchedule
def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
def get_insn_count(layer, sched):
    """Estimate how many VTA instructions a given tiling of `layer` emits.

    `sched` is a (b, h, w, ci, co) tile-size tuple as produced by
    find_schedules; `layer` is a Conv2DWorkload-like record.
    """
    tile_b, tile_h, tile_w, tile_ci, tile_co = sched
    # Outer-loop trip counts implied by each tile size.
    outer_b = tile_b
    outer_h = layer.height / tile_h
    outer_w = layer.width / tile_w
    outer_ci = int(np.ceil(float(layer.in_filter) / (tile_ci * vta.VTA_BLOCK_IN)))
    outer_co = int(np.ceil(float(layer.out_filter) / (tile_co * vta.VTA_BLOCK_OUT)))
    # One input transfer and one weight transfer per full loop-nest iteration.
    load_iters = outer_b * outer_h * outer_w * outer_co * outer_ci
    # Output tiles are visited once per (b, h, w, co) combination, and each
    # costs 4 instructions (INIT, GEMM, ALU, STORE).
    store_iters = outer_b * outer_h * outer_w * outer_co
    # Fixed offset: 3 uop kernels + 1 initial dep push + 1 finish, plus
    # co_factor extra instructions (per original accounting).
    return 2 * load_iters + 4 * store_iters + 5 + outer_co
def find_schedules(layer, mtOnly=False, bestOnly=False):
    """Enumerate conv2d tiling schedules that fit the on-chip buffers.

    Parameters
    ----------
    layer : Workload
        Conv2D workload (batch, height, width, filters, kernel, pad, stride).
    mtOnly : bool
        If True, keep only multi-threaded configurations (ht==2 or cot==2).
    bestOnly : bool
        If True, return a one-element list holding the schedule with the
        least total data movement (per get_data_movementB); otherwise
        return every candidate that passed the filters.

    Returns
    -------
    list of Schedule
    """
    # Helper function to get all positive divisors of n, ascending.
    def find_factors(n):
        return [i for i in range(1, n + 1) if n % i == 0]
    # Candidate tile sizes along each axis.
    batch_factors = find_factors(int(np.ceil(float(layer.batch) / vta.VTA_BATCH)))
    height_factors = find_factors(layer.height / layer.hstride)
    width_factors = find_factors(layer.width / layer.wstride)
    cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / vta.VTA_BLOCK_IN)))
    cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / vta.VTA_BLOCK_OUT)))
    # Explore schedules.
    schedules = []
    for b in batch_factors:
        for h in height_factors:
            for w in width_factors:
                for ci in cin_factors:
                    for co in cout_factors:
                        # FIXME: filter because of 2D load — either the full
                        # width is tiled in one shot, or co must be 1.
                        if w == layer.width / layer.wstride or co == 1:
                            # FIXME: filter because of 2D load
                            if ci == 1:
                                schedules.append([b, h, w, ci, co])
    # Per-element sizes (bits) and BRAM capacities (bits).
    input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH
    weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH
    output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH
    input_brams_capacity_b = vta.VTA_INP_BUFF_SIZE * 8
    weight_brams_capacity_b = vta.VTA_WGT_BUFF_SIZE * 8
    output_brams_capacity_b = vta.VTA_OUT_BUFF_SIZE * 8
    fil_sched = []
    xfer_size = []
    for sched in schedules:
        b, h, w, ci, co = sched
        for ht in [1, 2]:
            for cot in [1, 2]:
                # Skip threading on both axes, and thread factors that do
                # not evenly divide their axis.
                if (ht == 2 and cot == 2) or h % ht != 0 or co % cot != 0:
                    continue
                # If in multi-threaded mode, only allow for mt configs.
                if mtOnly and not (ht == 2 or cot == 2):
                    continue
                # BUGFIX: use per-config copies of the divided tile sizes.
                # The original mutated h/co in place (h /= ht; co /= cot),
                # leaking divided values into later ht/cot iterations of the
                # same schedule and corrupting both the capacity checks and
                # the emitted Schedule factors.
                h_t = h // ht
                co_t = co // cot
                input_tile_elems = b * \
                    ((h_t - 1) * layer.hstride + layer.hkernel) * \
                    ((w - 1) * layer.wstride + layer.wkernel) * ci
                weight_tile_elems = layer.hkernel * layer.wkernel * ci * co_t
                output_tile_elems = b * h_t * w * co_t
                insn_count = get_insn_count(layer, sched)
                # 1. Check input capacity
                # 2. Check weight capacity
                # 3. Check output capacity
                # 4. Check instruction capacity
                # 5. Make sure that we don't write to the same acc location
                #    within 2 consecutive cycles
                if input_tile_elems * input_elem_size_b <= input_brams_capacity_b / (cot * ht) and \
                   weight_tile_elems * weight_elem_size_b <= weight_brams_capacity_b and \
                   output_tile_elems * output_elem_size_b <= output_brams_capacity_b / (cot * ht) and \
                   insn_count <= vta.VTA_MAX_XFER / 16 and \
                   h_t > 2 and w > 2:
                    schedule = Schedule(oc_factor=co_t, ko_factor=ci, h_factor=h_t,
                                        w_factor=w, oc_nthread=cot, h_nthread=ht)
                    fil_sched.append(schedule)
                    xfer_size.append(get_data_movementB(schedule, layer))
    if bestOnly:
        # NOTE: raises ValueError if no schedule passed the filters, matching
        # the original behavior.
        return [fil_sched[xfer_size.index(min(xfer_size))]]
    return fil_sched
def get_data_movementB(sched, layer):
    """Total DRAM <-> SRAM data movement in bytes for `sched` on `layer`.

    Parameters
    ----------
    sched : Schedule
        Tiling/threading factors (b/h/w/ko/oc factors are read).
    layer : Workload
        Conv2D workload description.

    Returns
    -------
    Sum of input, weight, and output transfer sizes in bytes.
    """
    # Per-element sizes in bits.
    inp_bits = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH
    wgt_bits = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH
    out_bits = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH
    b, h, w = sched.b_factor, sched.h_factor, sched.w_factor
    ci, co = sched.ko_factor, sched.oc_factor
    # Tile footprints in elements; the input tile includes the halo implied
    # by kernel size and stride.
    inp_tile = b * \
        ((h - 1) * layer.hstride + layer.hkernel) * \
        ((w - 1) * layer.wstride + layer.wkernel) * ci
    # NOTE(review): no co term here, unlike the capacity check in
    # find_schedules — confirm whether this is intentional.
    wgt_tile = layer.hkernel * layer.wkernel * ci
    out_tile = b * h * w * co
    # Number of tile transfers along each axis.
    n_b = int(np.ceil(float(layer.batch) / (b * vta.VTA_BATCH)))
    n_h = (layer.height / layer.hstride) / h
    n_w = (layer.width / layer.wstride) / w
    n_ci = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN)))
    n_co = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT)))
    # Inputs and weights are re-fetched for every (outer x ci) tile; outputs
    # are written once per outer tile.
    outer = n_b * n_h * n_w * n_co
    inp_xfer_B = inp_tile * (outer * n_ci) * inp_bits / 8
    wgt_xfer_B = wgt_tile * (outer * n_ci) * wgt_bits / 8
    out_xfer_B = out_tile * outer * out_bits / 8
    return inp_xfer_B + wgt_xfer_B + out_xfer_B
def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True):
assert batch_size % vta.VTA_BATCH == 0
assert wl.in_filter % vta.VTA_BLOCK_IN == 0
assert wl.out_filter % vta.VTA_BLOCK_OUT == 0
......@@ -71,6 +203,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res")
num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
total_xfer_B = get_data_movementB(sched, wl)
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d")
......@@ -98,7 +231,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
data_arr = tvm.nd.array(data_packed, ctx)
kernel_arr = tvm.nd.array(kernel_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=10)
time_f = f.time_evaluator("conv2d", ctx, number=1)
cost = time_f(data_arr, kernel_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose(
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
......@@ -125,10 +258,10 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
s[res_cnv].set_scope(vta.SCOPE_OUT)
s[res_shf].set_scope(vta.SCOPE_OUT)
# tile
oc_factor = (plan.oc_factor if plan.oc_factor
oc_factor = (sched.oc_factor if sched.oc_factor
else wl.out_filter // vta.VTA_BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else fout_height)
w_factor = (plan.w_factor if plan.w_factor else fout_width)
h_factor = (sched.h_factor if sched.h_factor else fout_height)
w_factor = (sched.w_factor if sched.w_factor else fout_width)
xbo, xco, xi, xj, xbi, xci = s[res].op.axis
xco0, xco1 = s[res].split(xco, factor=oc_factor)
xi0, xi1 = s[res].split(xi, factor=h_factor)
......@@ -137,21 +270,21 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
s[res_cnv].compute_at(s[res], xj0)
s[res_shf].compute_at(s[res], xj0)
if plan.oc_nthread:
_, tx = s[res].split(xco0, factor=plan.oc_nthread)
if sched.oc_nthread > 1:
_, tx = s[res].split(xco0, factor=sched.oc_nthread)
s[res].reorder(tx, xbo)
s[res].bind(tx, tvm.thread_axis("cthread"))
if plan.h_nthread:
xo, tx = s[res].split(xi0, factor=plan.h_nthread)
if sched.h_nthread > 1:
xo, tx = s[res].split(xi0, factor=sched.h_nthread)
s[res].reorder(tx, xbo)
s[res].bind(tx, tvm.thread_axis("cthread"))
xbo, xco, xi, xj, xbi, xci = s[res_cnv].op.axis
s[res_cnv].reorder(xbo, ko, xj, dj, di, xco, xi, xbi, xci, ki)
if plan.ko_factor:
ko0, ko1 = s[res_cnv].split(ko, factor=plan.ko_factor)
if sched.ko_factor:
ko0, ko1 = s[res_cnv].split(ko, factor=sched.ko_factor)
s[data_buf].compute_at(s[res_cnv], ko0)
s[kernel_buf].compute_at(s[res_cnv], ko0)
# Use VTA instructions
......@@ -160,7 +293,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
s[res_cnv].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res].pragma(xco1, store_out)
if plan.debug_sync:
if sched.debug_sync:
s[res].pragma(xco0, "coproc_sync")
if print_ir:
print(tvm.lower(s, [data, kernel, res], simple_mode=True))
......@@ -169,6 +302,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
s = [1, sched.oc_factor, sched.ko_factor, sched.h_factor, sched.w_factor]
cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, vta.ALU, vta.DMA_COPY,
......@@ -177,8 +311,27 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["key"].append(key)
log_frame["layer"].append(layer)
log_frame["total-data"].append(total_xfer_B)
log_frame["total-gops"].append(gops)
log_frame["total-cost"].append(cost.mean)
log_frame["total-insn"].append(get_insn_count(wl, s))
log_frame["block-batch"].append(vta.VTA_BATCH)
log_frame["block-in"].append(vta.VTA_BLOCK_IN)
log_frame["block-out"].append(vta.VTA_BLOCK_OUT)
log_frame["inp-width"].append(vta.VTA_INP_WIDTH)
log_frame["wgt-width"].append(vta.VTA_WGT_WIDTH)
log_frame["uop-size"].append(vta.VTA_UOP_BUFF_SIZE)
log_frame["inp-size"].append(vta.VTA_INP_BUFF_SIZE)
log_frame["wgt-size"].append(vta.VTA_WGT_BUFF_SIZE)
log_frame["out-size"].append(vta.VTA_OUT_BUFF_SIZE)
log_frame["oc-factor"].append(sched.oc_factor)
log_frame["ic-factor"].append(sched.ko_factor)
log_frame["h-factor"].append(sched.h_factor)
log_frame["w-factor"].append(sched.w_factor)
log_frame["oc-threads"].append(sched.oc_nthread)
log_frame["h-threads"].append(sched.h_nthread)
log_frame["threaded"].append(True if sched.oc_nthread > 1 or sched.h_nthread > 1 else False)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir, True)
......@@ -325,90 +478,63 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
load_wgt_unittest(False)
store_out_unittest(False)
# Perform profiling
profile = False
# Data set batch size
batch_size = 1
# Use multi-threading for latency hiding
multi_threaded = False
# ResNet18 workloads
resnet = {
# Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
# List of simple benchmarks
simple = [
Workload(height=22, width=22, in_filter=256, out_filter=64,
hkernel=3, wkernel=3, hpad=1, wpad=1, hstride=1, wstride=1)
]
# Serial schedule
resnet_serial = [
[None, None],
[resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[2], Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0)],
[resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[4], Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0)],
[resnet[5], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[6], Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[7], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[8], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[9], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[10], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[11], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
]
# SMT schedule
resnet_smt = [
[resnet[0], Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56)],
[resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2)],
[resnet[2], Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2)],
[resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)],
[resnet[4], Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2)],
[resnet[5], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)],
[resnet[6], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[7], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[8], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[9], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[10], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[11], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
]
begin = 0
end = len(resnet)
# Perform profiling
profile = False
# Whether use SMT
use_smt = True
# Data set batch size
batch_size = 1
resnet_schedules = []
for i in range(begin, end):
scheds = find_schedules(resnet[i], mtOnly=multi_threaded, bestOnly=True)
resnet_schedules.append([i, scheds])
resnet_schedule = resnet_smt if use_smt else resnet_serial
keys = ["key", "layer", "total-data", "total-gops", "total-cost", "total-insn",
"block-batch", "block-in", "block-out", "wgt-width", "inp-width",
"uop-size", "inp-size", "wgt-size", "out-size",
"oc-factor", "ic-factor", "h-factor", "w-factor",
"oc-threads", "h-threads", "threaded"]
begin = 0
end = len(resnet_schedule)
keys = ["key", "total-gops", "total-cost",
"skip-alu-gops", "skip-alu-cost",
if profile:
keys += ["skip-alu-gops", "skip-alu-cost",
"gemm-gops", "gemm-cost", "alu-gops", "alu-cost",
"ld-inp-cost", "ld-wgt-cost", "st-out-cost",
"ld-inp-gbits", "ld-wgt-gbits", "st-out-gbits",]
"ld-inp-gbits", "ld-wgt-gbits", "st-out-gbits"]
log_frame = {
k : [] for k in keys
}
for i, x in enumerate(resnet_schedule[begin:end]):
wl, plan = x
if not wl:
continue
key = "resnet-cfg[%d]" % i
test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile)
if profile:
pd.set_option('expand_frame_repr', False)
log_df = pd.DataFrame()
for k in keys:
for x in resnet_schedules:
l, plans = x
for plan in plans:
key = "resnet-cfg[{}-{}]".format(l, plan)
test_conv2d_chwv(l, key, batch_size, resnet[l], plan, log_frame, profile)
pd.set_option('expand_frame_repr', False)
log_df = pd.DataFrame()
for k in keys:
log_df[k] = log_frame[k]
print(log_df)
print(log_df)
log_df.to_csv("conv2d.csv")
......@@ -6,10 +6,14 @@ from tvm.contrib import rpc, util
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
bit = "vta.bit"
bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format(
vta.VTA_BATCH, vta.VTA_BLOCK_IN, vta.VTA_BLOCK_OUT,
vta.VTA_INP_WIDTH, vta.VTA_WGT_WIDTH,
vta.VTA_LOG_UOP_BUFF_SIZE, vta.VTA_LOG_INP_BUFF_SIZE,
vta.VTA_LOG_WGT_BUFF_SIZE, vta.VTA_LOG_OUT_BUFF_SIZE)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
bitstream = os.path.join(curr_path, "./", bit)
bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit)
def test_program_rpc():
assert tvm.module.enabled("rpc")
......
......@@ -480,8 +480,8 @@ if __name__ == "__main__":
test_padded_load()
print("Load/store test")
test_save_load_out()
print("GEMM test")
test_gemm()
# print("GEMM test")
# test_gemm()
print("Max immediate")
test_alu(tvm.max, np.maximum, use_imm=True)
print("Max")
......@@ -492,6 +492,8 @@ if __name__ == "__main__":
test_alu(lambda x, y: x + y)
print("Shift right immediate")
test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
print("Shift left immediate")
test_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
print("Relu")
test_relu()
# print("Shift and scale")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment