Commit f5960ba6 by Thierry Moreau Committed by Tianqi Chen

[SCHEDULER, HW] Auto scheduler for conv2d, hardware generation (#20)

* Hardware generation fixes/sweep, auto scheduling for VTA conv2d

* Hardware generation fixes/sweep, auto scheduling for VTA conv2d

* derive hw spec from config file

* up to date hardware spec
parent 33a309b2
# Directories # Directories
ROOTDIR = $(CURDIR) ROOTDIR = $(CURDIR)
BUILD_DIR = $(ROOTDIR)/../../build/hardware/xilinx BUILD_NAME = build
BUILD_DIR = $(ROOTDIR)/../../$(BUILD_NAME)/hardware/xilinx
SCRIPT_DIR = $(ROOTDIR)/scripts SCRIPT_DIR = $(ROOTDIR)/scripts
SRC_DIR = $(ROOTDIR)/src SRC_DIR = $(ROOTDIR)/src
SIM_DIR = $(ROOTDIR)/sim SIM_DIR = $(ROOTDIR)/sim
...@@ -12,6 +13,15 @@ VIVADO_HLS = vivado_hls ...@@ -12,6 +13,15 @@ VIVADO_HLS = vivado_hls
VIVADO = vivado VIVADO = vivado
HSI = hsi HSI = hsi
# HLS Mode
MODE = all
# SLURM
SLURM = false
# Prevent generation of DSP
NO_DSP = false
# Prevent generation of ALU
NO_ALU = false
# Include top-level config file # Include top-level config file
ifndef config ifndef config
ifneq ("$(wildcard ../../config.mk)", "") ifneq ("$(wildcard ../../config.mk)", "")
...@@ -41,13 +51,18 @@ $(shell echo "$$(( (1000 + $(VTA_HW_COMP_CLOCK_FREQ) - 1) / $(VTA_HW_COMP_CLOCK_ ...@@ -41,13 +51,18 @@ $(shell echo "$$(( (1000 + $(VTA_HW_COMP_CLOCK_FREQ) - 1) / $(VTA_HW_COMP_CLOCK_
# Derive config name # Derive config name
CONF = \ CONF = \
$(VTA_BATCH)x$(VTA_IN_BLOCK)x$(VTA_OUT_BLOCK)_$(VTA_INP_WIDTH)bx$(VTA_WGT_WIDTH)b_$(VTA_HW_COMP_CLOCK_FREQ)MHz_$(TARGET_PER)ns $(VTA_BATCH)x$(VTA_IN_BLOCK)x$(VTA_OUT_BLOCK)_$(VTA_INP_WIDTH)bx$(VTA_WGT_WIDTH)b_$(VTA_LOG_UOP_BUFF_SIZE)_$(VTA_LOG_INP_BUFF_SIZE)_$(VTA_LOG_WGT_BUFF_SIZE)_$(VTA_LOG_ACC_BUFF_SIZE)_$(VTA_HW_COMP_CLOCK_FREQ)MHz_$(TARGET_PER)ns
IP_BUILD_PATH = $(BUILD_DIR)/hls/$(CONF) IP_BUILD_PATH = $(BUILD_DIR)/hls/$(CONF)
HW_BUILD_PATH = $(BUILD_DIR)/vivado/$(CONF) HW_BUILD_PATH = $(BUILD_DIR)/vivado/$(CONF)
.PHONY: all ip bit driver clean ifeq ($(SLURM), true)
IP_BUILD_PATH = /scratch/hls/$(CONF)
HW_BUILD_PATH = /scratch/vivado/$(CONF)
endif
.PHONY: all ip bit driver clean clean_all
all: driver all: bit
ip: ip:
mkdir -p $(IP_BUILD_PATH) mkdir -p $(IP_BUILD_PATH)
...@@ -57,20 +72,32 @@ ip: ...@@ -57,20 +72,32 @@ ip:
$(VTA_LOG_INP_WIDTH) $(VTA_LOG_WGT_WIDTH) $(VTA_LOG_ACC_WIDTH) $(VTA_LOG_OUT_WIDTH) \ $(VTA_LOG_INP_WIDTH) $(VTA_LOG_WGT_WIDTH) $(VTA_LOG_ACC_WIDTH) $(VTA_LOG_OUT_WIDTH) \
$(VTA_LOG_BATCH) $(VTA_LOG_BLOCK_OUT) $(VTA_LOG_BLOCK_IN) \ $(VTA_LOG_BATCH) $(VTA_LOG_BLOCK_OUT) $(VTA_LOG_BLOCK_IN) \
$(VTA_LOG_UOP_BUFF_SIZE) $(VTA_LOG_INP_BUFF_SIZE) $(VTA_LOG_WGT_BUFF_SIZE) \ $(VTA_LOG_UOP_BUFF_SIZE) $(VTA_LOG_INP_BUFF_SIZE) $(VTA_LOG_WGT_BUFF_SIZE) \
$(VTA_LOG_ACC_BUFF_SIZE) $(VTA_LOG_OUT_BUFF_SIZE) $(VTA_LOG_ACC_BUFF_SIZE) $(VTA_LOG_OUT_BUFF_SIZE) \
$(MODE) $(NO_DSP) $(NO_ALU)
ifeq ($(SLURM), true)
mkdir -p $(BUILD_DIR)/hls
mv $(IP_BUILD_PATH) $(BUILD_DIR)/hls/.
endif
bit: ip bit: ip
mkdir -p $(HW_BUILD_PATH) mkdir -p $(HW_BUILD_PATH)
cd $(HW_BUILD_PATH) && \ cd $(HW_BUILD_PATH) && \
$(VIVADO) -mode tcl -source $(SCRIPT_DIR)/vivado.tcl \ $(VIVADO) -mode tcl -source $(SCRIPT_DIR)/vivado.tcl \
-tclargs $(IP_BUILD_PATH) $(VTA_HW_COMP_THREADS) $(VTA_HW_COMP_CLOCK_FREQ) \ -tclargs $(BUILD_DIR)/hls/$(CONF) $(VTA_HW_COMP_THREADS) $(VTA_HW_COMP_CLOCK_FREQ) \
$(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \ $(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \
$(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \ $(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \
$(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE) $(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE)
ifeq ($(SLURM), true)
mkdir -p $(BUILD_DIR)/vivado
mv $(HW_BUILD_PATH) $(BUILD_DIR)/vivado/.
endif
driver: bit driver: bit
cd $(HW_BUILD_PATH) && $(HSI) -mode tcl -source $(SCRIPT_DIR)/hsi.tcl -nojournal -nolog cd $(HW_BUILD_PATH) && $(HSI) -mode tcl -source $(SCRIPT_DIR)/hsi.tcl -nojournal -nolog
cd $(HW_BUILD_PATH)/bsp && make cd $(HW_BUILD_PATH)/bsp && make
clean: clean:
rm -rf *.out *.log *.sb figures
clean_all: clean
rm -rf $(BUILD_DIR) rm -rf $(BUILD_DIR)
#!/usr/bin/env python
import argparse
import datetime
import logging
import numpy as np
import os
import pandas as pd
import re
import time
from collections import namedtuple
from numpy import floor, ceil, log2, log10
from subprocess import call
# BRAM resources of the target FPGA: port width in bits (bram_w), depth
# (bram_d) and total number of BRAM blocks available (num_bram).
FPGA = namedtuple("FPGAConstraints",
['bram_w', 'bram_d', 'num_bram'])
# One hardware design point: tile dimensions (batch, block_in, block_out)
# and the bit-widths of inputs, weights, accumulators, outputs and micro-ops.
Hardware = namedtuple("HWConfig",
['batch', 'block_in', 'block_out',
'input_w', 'weight_w', 'accum_w', 'out_w', 'uop_w'])
def find_bram_confs(fpga, hw_conf, log_uop_sizeB):
    """Enumerate the non-dominated buffer sizings that fit in the FPGA's BRAM.

    Every power-of-two BRAM allocation for the input, weight and accumulator
    buffers is tried. An allocation is kept when the total BRAM count
    (including the micro-op buffer and the output buffer derived from the
    accumulator allocation) does not exceed ``fpga.num_bram`` and when the
    summed log2 element counts of the three buffers fit in ``hw_conf.uop_w``.

    Args:
        fpga: object with ``bram_w``, ``bram_d`` and ``num_bram`` attributes.
        hw_conf: object with ``batch``, ``block_in``, ``block_out``,
            ``input_w``, ``weight_w``, ``accum_w``, ``out_w`` and ``uop_w``.
        log_uop_sizeB: log2 of the micro-op buffer size in bytes.

    Returns:
        A list of ``[log_uop_sizeB, log_inp_sizeB, log_wgt_sizeB,
        log_acc_sizeB]`` entries (log2 of buffer sizes in bytes), keeping only
        those for which no other entry is element-wise greater or equal.
    """
    # NOTE(review): the "/" divisions below look like they were written for
    # integer (Python 2) semantics; under Python 3 they produce floats.
    # Confirm which interpreter this script is meant to run on.

    # Width in bits of a single element of each buffer
    inp_elem_bits = hw_conf.batch * hw_conf.block_in * hw_conf.input_w
    wgt_elem_bits = hw_conf.block_in * hw_conf.block_out * hw_conf.weight_w
    acc_elem_bits = hw_conf.batch * hw_conf.block_out * hw_conf.accum_w

    def first_log_bram(elem_bits):
        # log2 (truncated) of the number of BRAMs needed to hold one element
        return int(log2((elem_bits + fpga.bram_w - 1) / fpga.bram_w))

    # Exclusive upper bound on log2(#BRAM) handed to any single buffer
    log_bram_limit = int(ceil(log2(fpga.num_bram)))
    uop_bram = pow(2, log_uop_sizeB) * 8 / (fpga.bram_w * fpga.bram_d)

    candidates = []
    for log_i in range(first_log_bram(inp_elem_bits), log_bram_limit):
        i_bram = 2 ** log_i
        for log_w in range(first_log_bram(wgt_elem_bits), log_bram_limit):
            w_bram = 2 ** log_w
            for log_a in range(first_log_bram(acc_elem_bits), log_bram_limit):
                a_bram = 2 ** log_a
                total_bram = uop_bram + i_bram + w_bram + a_bram + a_bram / hw_conf.accum_w * hw_conf.out_w
                if total_bram > fpga.num_bram:
                    continue
                # The three buffer indices must fit within one micro-op
                n_inp = i_bram * fpga.bram_w * fpga.bram_d / inp_elem_bits
                n_wgt = w_bram * fpga.bram_w * fpga.bram_d / wgt_elem_bits
                n_acc = a_bram * fpga.bram_w * fpga.bram_d / acc_elem_bits
                if log2(n_inp) + log2(n_wgt) + log2(n_acc) > hw_conf.uop_w:
                    continue
                candidates.append([
                    log_uop_sizeB,
                    int(log2(i_bram * fpga.bram_d * fpga.bram_w / 8)),
                    int(log2(w_bram * fpga.bram_d * fpga.bram_w / 8)),
                    int(log2(a_bram * fpga.bram_d * fpga.bram_w / 8)),
                ])

    def is_dominated(pos):
        # True when some other candidate is >= this one in every field
        mine = candidates[pos]
        return any(
            other_pos != pos and all(a <= b for a, b in zip(mine, other))
            for other_pos, other in enumerate(candidates)
        )

    return [conf for pos, conf in enumerate(candidates) if not is_dominated(pos)]
def get_make_command(job, build_dir, hw_conf, bram_conf, mode, slurm=False):
    """Build the make invocation (or slurm batch script) for one design point.

    Args:
        job: job name, used for the slurm job name and its ``.out`` file.
        build_dir: value passed to make as ``BUILD_NAME``.
        hw_conf: object with ``input_w``, ``weight_w``, ``batch``,
            ``block_in`` and ``block_out`` attributes; each is passed to make
            as its truncated log2.
        bram_conf: sequence holding the log2 uop, input, weight and
            accumulator buffer sizes, in that order.
        mode: ``"hls"`` runs ``make ip``; anything else runs plain ``make``.
        slurm: when true, prepend an ``#SBATCH`` header and launch via ``srun``.

    Returns:
        The command text, terminated by a newline.
    """
    if slurm:
        preamble = "\n".join([
            "#!/bin/bash",
            "#SBATCH --job-name={}".format(job),
            "#SBATCH --output={}.out".format(job),
            "srun ",
        ])
    else:
        preamble = ""

    target = "make ip" if mode == "hls" else "make"

    # Variable assignments, in the order they appear on the command line
    assignments = [
        ("SLURM", "true" if slurm else "false"),
        ("MODE", "skip_sim"),
        ("NO_DSP", "false"),
        ("NO_ALU", "false"),
        ("BUILD_NAME", build_dir),
        ("VTA_LOG_INP_WIDTH", int(log2(hw_conf.input_w))),
        ("VTA_LOG_WGT_WIDTH", int(log2(hw_conf.weight_w))),
        ("VTA_LOG_BATCH", int(log2(hw_conf.batch))),
        ("VTA_LOG_BLOCK_IN", int(log2(hw_conf.block_in))),
        ("VTA_LOG_BLOCK_OUT", int(log2(hw_conf.block_out))),
        ("VTA_LOG_UOP_BUFF_SIZE", bram_conf[0]),
        ("VTA_LOG_INP_BUFF_SIZE", bram_conf[1]),
        ("VTA_LOG_WGT_BUFF_SIZE", bram_conf[2]),
        ("VTA_LOG_ACC_BUFF_SIZE", bram_conf[3]),
    ]
    make_vars = "".join(" {}={}".format(key, value) for key, value in assignments)
    return preamble + target + make_vars + "\n"
def cli():
    """Sweep hardware design points and launch a build for each one.

    Parses the command line, then iterates over every combination of input
    and weight bit-width, batch size, and input/output channel count in the
    requested log2 ranges. For each combination, every buffer sizing returned
    by ``find_bram_confs`` is turned into a make command, written to
    ``<job>.sb``, echoed, and then either submitted with ``sbatch``
    (``-slurm``) or executed directly.
    """
    parser = argparse.ArgumentParser(
        description='Analyze HLS experiments'
    )
    parser.add_argument(
        '-mode', dest='mode', action='store', type=str, required=True,
        choices=["hls", "vivado"], help='hls synthesis or full compilation'
    )
    # NOTE(review): base_dir is parsed but not used anywhere in this function.
    parser.add_argument(
        '-base_dir', dest='base_dir', action='store', type=str, required=False,
        default="../../build/hardware/xilinx/", help='path to build directory'
    )
    # Optional integer knobs: (name, default, help text)
    int_options = [
        ('min_ibw', 3, 'log2 of minimum input bit-width'),
        ('max_ibw', 3, 'log2 of maximum input bit-width'),
        ('min_wbw', 3, 'log2 of minimum weight bit-width'),
        ('max_wbw', 3, 'log2 of maximum weight bit-width'),
        ('acc_bw', 32, 'accumulator bit-width'),
        ('uop_bw', 32, 'micro-op bit-width'),
        ('min_batch', 0, 'log2 of minimum batch size'),
        ('max_batch', 8, 'log2 of maximum batch size'),
        ('min_ic', 0, 'log2 of minimum input channels'),
        ('max_ic', 8, 'log2 of maximum input channels'),
        ('min_oc', 0, 'log2 of minimum output channels'),
        ('max_oc', 8, 'log2 of maximum output channels'),
        ('uop_sizeB', 14, 'log2 of uop buffer in B'),
        ('bram_w', 32, 'FPGA BRAM port width in b'),
        ('bram_d', 1024, 'FPGA BRAM depth'),
        ('num_bram', 124, 'FPGA total BRAM'),
    ]
    for opt_name, opt_default, opt_help in int_options:
        parser.add_argument(
            '-' + opt_name, dest=opt_name, action='store', type=int,
            required=False, default=opt_default, help=opt_help
        )
    parser.add_argument(
        '-slurm', dest='slurm', action='store_true',
        help='Run on cluster using slurm'
    )
    args = parser.parse_args()

    # Logging
    logging.basicConfig(filename='compile_designs.log', level=logging.DEBUG)

    # FPGA config
    pynq = FPGA(args.bram_w, args.bram_d, args.num_bram)

    # Every design of this sweep is built under one timestamped directory
    timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    build_dir = "build_{}".format(timestamp)

    num_confs = 0
    for log_ibw in range(args.min_ibw, args.max_ibw + 1):
        ibw = pow(2, log_ibw)
        for log_wbw in range(args.min_wbw, args.max_wbw + 1):
            wbw = pow(2, log_wbw)
            for log_batch in range(args.min_batch, args.max_batch + 1):
                batch = pow(2, log_batch)
                for log_ic in range(args.min_ic, args.max_ic + 1):
                    ic = pow(2, log_ic)
                    for log_oc in range(args.min_oc, args.max_oc + 1):
                        oc = pow(2, log_oc)
                        # The output bit-width is set equal to the input bit-width
                        conf = Hardware(batch, ic, oc, ibw, wbw, args.acc_bw, ibw, args.uop_bw)
                        bram_confs = find_bram_confs(pynq, conf, args.uop_sizeB)
                        for b in bram_confs:
                            # NOTE(review): clock frequency and period are
                            # hard-coded in the job name - confirm they match
                            # the values the build actually uses.
                            job = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns".format(
                                batch, ic, oc, ibw, wbw, b[0], b[1], b[2], b[3])
                            num_confs += 1
                            cmd = get_make_command(job, build_dir, conf, b, args.mode, args.slurm)
                            sb_file = job + ".sb"
                            # Context manager closes the script even if the write fails
                            with open(sb_file, "w") as script:
                                script.write(cmd)
                            call(["echo", cmd])
                            if args.slurm:
                                status = call(["sbatch", sb_file])
                            else:
                                # split() with no separator also strips the
                                # trailing newline, which split(" ") would leave
                                # attached to the last make variable.
                                status = call(cmd.split())
                            if status != 0:
                                logging.error("job %s exited with status %d", job, status)
    logging.info("launched %d design configurations", num_confs)
if __name__ == '__main__':
cli()
...@@ -22,8 +22,11 @@ ...@@ -22,8 +22,11 @@
# Arg 15: wgt buffer size in B (log) # Arg 15: wgt buffer size in B (log)
# Arg 16: acc buffer size in B (log) # Arg 16: acc buffer size in B (log)
# Arg 17: out buffer size in B (log) # Arg 17: out buffer size in B (log)
# Arg 18: mode
# Arg 19: no_dsp
# Arg 20: no_alu
if { [llength $argv] eq 19 } { if { [llength $argv] eq 22 } {
set src_dir [lindex $argv 2] set src_dir [lindex $argv 2]
set sim_dir [lindex $argv 3] set sim_dir [lindex $argv 3]
set test_dir [lindex $argv 4] set test_dir [lindex $argv 4]
...@@ -41,10 +44,13 @@ if { [llength $argv] eq 19 } { ...@@ -41,10 +44,13 @@ if { [llength $argv] eq 19 } {
set wgt_buff_size [lindex $argv 16] set wgt_buff_size [lindex $argv 16]
set acc_buff_size [lindex $argv 17] set acc_buff_size [lindex $argv 17]
set out_buff_size [lindex $argv 18] set out_buff_size [lindex $argv 18]
set mode [lindex $argv 19]
set no_dsp [lindex $argv 20]
set no_alu [lindex $argv 21]
} else { } else {
set src_dir "../src/" set src_dir "../src"
set sim_dir "../sim/" set sim_dir "../sim"
set test_dir "../../src/test/" set test_dir "../../src/test"
set include_dir "../../include" set include_dir "../../include"
set target_period 10 set target_period 10
set inp_width 3 set inp_width 3
...@@ -59,17 +65,11 @@ if { [llength $argv] eq 19 } { ...@@ -59,17 +65,11 @@ if { [llength $argv] eq 19 } {
set wgt_buff_size 15 set wgt_buff_size 15
set acc_buff_size 17 set acc_buff_size 17
set out_buff_size 15 set out_buff_size 15
set mode "all"
set no_dsp "true"
set no_alu "false"
} }
# C define flags to pass to compiler
set cflags "-I $include_dir -I $src_dir -I $test_dir \
-DVTA_DEBUG=0 -DVTA_LOG_WGT_WIDTH=$wgt_width -DVTA_LOG_INP_WIDTH=$inp_width \
-DVTA_LOG_ACC_WIDTH=$acc_width -DVTA_LOG_OUT_WIDTH=$out_width \
-DVTA_LOG_BATCH=$batch -DVTA_LOG_BLOCK_OUT=$block_out -DVTA_LOG_BLOCK_IN=$block_in \
-DVTA_LOG_UOP_BUFF_SIZE=$uop_buff_size -DVTA_LOG_INP_BUFF_SIZE=$inp_buff_size \
-DVTA_LOG_WGT_BUFF_SIZE=$wgt_buff_size -DVTA_LOG_ACC_BUFF_SIZE=$acc_buff_size \
-DVTA_LOG_OUT_BUFF_SIZE=$out_buff_size"
# Initializes the HLS design and sets HLS pragmas for memory partitioning. # Initializes the HLS design and sets HLS pragmas for memory partitioning.
# This is necessary because of a Vivado restriction that doesn't allow for # This is necessary because of a Vivado restriction that doesn't allow for
# buses wider than 1024 bits. # buses wider than 1024 bits.
...@@ -122,56 +122,89 @@ proc init_design {per inp_width wgt_width out_width batch block_in block_out} { ...@@ -122,56 +122,89 @@ proc init_design {per inp_width wgt_width out_width batch block_in block_out} {
} }
} }
# C define flags to pass to compiler
set cflags "-I $include_dir -I $src_dir -I $test_dir \
-DVTA_DEBUG=0 -DVTA_LOG_WGT_WIDTH=$wgt_width -DVTA_LOG_INP_WIDTH=$inp_width \
-DVTA_LOG_ACC_WIDTH=$acc_width -DVTA_LOG_OUT_WIDTH=$out_width \
-DVTA_LOG_BATCH=$batch -DVTA_LOG_BLOCK_OUT=$block_out -DVTA_LOG_BLOCK_IN=$block_in \
-DVTA_LOG_UOP_BUFF_SIZE=$uop_buff_size -DVTA_LOG_INP_BUFF_SIZE=$inp_buff_size \
-DVTA_LOG_WGT_BUFF_SIZE=$wgt_buff_size -DVTA_LOG_ACC_BUFF_SIZE=$acc_buff_size \
-DVTA_LOG_OUT_BUFF_SIZE=$out_buff_size"
if {$no_dsp=="true"} {
append cflags " -DNO_DSP"
}
if {$no_alu=="true"} {
append cflags " -DNO_ALU"
}
# HLS behavioral sim # HLS behavioral sim
open_project vta_sim if {$mode=="all" || $mode=="sim"} {
set_top vta open_project vta_sim
add_files $src_dir/vta.cc -cflags $cflags set_top vta
add_files -tb $sim_dir/vta_test.cc -cflags $cflags add_files $src_dir/vta.cc -cflags $cflags
add_files -tb $test_dir/test_lib.cc -cflags $cflags add_files -tb $sim_dir/vta_test.cc -cflags $cflags
open_solution "solution0" add_files -tb $test_dir/test_lib.cc -cflags $cflags
init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out open_solution "solution0"
csim_design -clean init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
close_project csim_design -clean
close_project
}
# Generate fetch stage # Generate fetch stage
open_project vta_fetch if {$mode=="all" || $mode=="skip_sim" || $mode=="fetch"} {
set_top fetch open_project vta_fetch
add_files $src_dir/vta.cc -cflags $cflags set_top fetch
open_solution "solution0" add_files $src_dir/vta.cc -cflags $cflags
init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out open_solution "solution0"
csynth_design init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
export_design -format ip_catalog csynth_design
close_project if {$mode=="all" || $mode=="skip_sim"} {
export_design -format ip_catalog
}
close_project
}
# Generate load stage # Generate load stage
open_project vta_load if {$mode=="all" || $mode=="skip_sim" || $mode=="load"} {
set_top load open_project vta_load
add_files $src_dir/vta.cc -cflags $cflags set_top load
open_solution "solution0" add_files $src_dir/vta.cc -cflags $cflags
init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out open_solution "solution0"
csynth_design init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
export_design -format ip_catalog csynth_design
close_project if {$mode=="all" || $mode=="skip_sim"} {
export_design -format ip_catalog
}
close_project
}
# Generate compute stage # Generate compute stage
open_project vta_compute if {$mode=="all" || $mode=="skip_sim" || $mode=="compute"} {
set_top compute open_project vta_compute
add_files $src_dir/vta.cc -cflags $cflags set_top compute
open_solution "solution0" add_files $src_dir/vta.cc -cflags $cflags
init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out open_solution "solution0"
csynth_design init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
export_design -format ip_catalog csynth_design
close_project if {$mode=="all" || $mode=="skip_sim"} {
export_design -format ip_catalog
}
close_project
}
# Generate store stage # Generate store stage
open_project vta_store if {$mode=="all" || $mode=="skip_sim" || $mode=="store"} {
set_top store open_project vta_store
add_files $src_dir/vta.cc -cflags $cflags set_top store
open_solution "solution0" add_files $src_dir/vta.cc -cflags $cflags
init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out open_solution "solution0"
csynth_design init_design $target_period $inp_width $wgt_width $out_width $batch $block_in $block_out
export_design -format ip_catalog csynth_design
close_project if {$mode=="all" || $mode=="skip_sim"} {
export_design -format ip_catalog
}
close_project
}
exit exit
...@@ -59,44 +59,31 @@ if { [llength $argv] eq 12 } { ...@@ -59,44 +59,31 @@ if { [llength $argv] eq 12 } {
# Derive input mem parameters # Derive input mem parameters
set inp_mem_width [expr $inp_width * $batch * $in_block] set inp_mem_width [expr $inp_width * $batch * $in_block]
set inp_mem_depth [expr $inp_mem_size * 8 / $inp_mem_width]
set inp_bus_width 1024 set inp_bus_width 1024
set inp_part [expr $inp_mem_width / $inp_bus_width] set inp_part [expr $inp_mem_width / $inp_bus_width]
if {[expr $inp_part == 0]} { if {[expr $inp_part == 0]} {
set inp_part 1 set inp_part 1
set inp_bus_width $inp_mem_width set inp_bus_width $inp_mem_width
} }
set inp_mem_depth [expr $inp_mem_size * 8 / ($inp_mem_width * $inp_part)]
# Derive weight mem parameters # Derive weight mem parameters
set wgt_mem_width [expr $wgt_width * $out_block * $in_block] set wgt_mem_width [expr $wgt_width * $out_block * $in_block]
set wgt_mem_depth [expr $wgt_mem_size * 8 / $wgt_mem_width]
set wgt_bus_width 1024 set wgt_bus_width 1024
set wgt_part [expr $wgt_mem_width / $wgt_bus_width] set wgt_part [expr $wgt_mem_width / $wgt_bus_width]
if {[expr $wgt_part == 0]} { if {[expr $wgt_part == 0]} {
set wgt_part 1 set wgt_part 1
set wgt_bus_width $wgt_mem_width set wgt_bus_width $wgt_mem_width
} }
set wgt_mem_depth [expr $wgt_mem_size * 8 / ($wgt_mem_width * $wgt_part)]
# Derive output mem parameters # Derive output mem parameters
set out_mem_width [expr $out_width * $batch * $out_block] set out_mem_width [expr $out_width * $batch * $out_block]
set out_mem_depth [expr $out_mem_size * 8 / $out_mem_width]
set out_bus_width 1024 set out_bus_width 1024
set out_part [expr $out_mem_width / $out_bus_width] set out_part [expr $out_mem_width / $out_bus_width]
if {[expr $out_part == 0]} { if {[expr $out_part == 0]} {
set out_part 1 set out_part 1
set out_bus_width $out_mem_width set out_bus_width $out_mem_width
} }
set out_mem_depth [expr $out_mem_size * 8 / ($out_mem_width * $out_part)]
puts $inp_mem_width
puts $inp_mem_depth
puts $inp_bus_width
puts $inp_part
puts $wgt_mem_width
puts $wgt_mem_depth
puts $wgt_bus_width
puts $wgt_part
puts $out_mem_width
puts $out_mem_depth
puts $out_bus_width
puts $out_part
# User defined paths # User defined paths
set proj_name vta set proj_name vta
...@@ -937,10 +924,12 @@ launch_runs impl_1 -to_step write_bitstream -jobs $num_threads ...@@ -937,10 +924,12 @@ launch_runs impl_1 -to_step write_bitstream -jobs $num_threads
wait_on_run impl_1 wait_on_run impl_1
# Export hardware description file and bitstream files to export/ dir # Export hardware description file and bitstream files to export/ dir
file mkdir $proj_path/export if {[file exist $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.bit]} {
file copy -force $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.sysdef \ file mkdir $proj_path/export
file copy -force $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.sysdef \
$proj_path/export/vta.hdf $proj_path/export/vta.hdf
file copy -force $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.bit \ file copy -force $proj_path/$proj_name.runs/impl_1/${proj_name}_wrapper.bit \
$proj_path/export/vta.bit $proj_path/export/vta.bit
}
exit exit
...@@ -16,40 +16,36 @@ int main(void) { ...@@ -16,40 +16,36 @@ int main(void) {
printParameters(); printParameters();
#endif #endif
// Buffer indexing
assert(VTA_LOG_ACC_BUFF_DEPTH >= VTA_LOG_INP_BUFF_DEPTH);
// Micro op bound // Micro op bound
assert(VTA_UOP_GEM_3_1 < VTA_UOP_WIDTH); assert(VTA_UOP_GEM_2_1 < VTA_UOP_WIDTH);
assert(VTA_UOP_ALU_3_1 < VTA_UOP_WIDTH); assert(VTA_UOP_ALU_1_1 < VTA_UOP_WIDTH);
// Instruction alignment checks // Make sure there is no misalignment
assert(VTA_INSN_GEM_9_1 < VTA_INSN_GEM_A_0);
assert(VTA_INSN_MEM_7_1 < VTA_INSN_MEM_8_0); assert(VTA_INSN_MEM_7_1 < VTA_INSN_MEM_8_0);
assert(VTA_INSN_GEM_8_1 < VTA_INSN_GEM_9_0);
// Instruction bounds // Instruction bounds
assert(VTA_INSN_MEM_E_1 < VTA_INS_WIDTH); assert(VTA_INSN_MEM_E_1 < VTA_INS_WIDTH);
assert(VTA_INSN_GEM_E_1 < VTA_INS_WIDTH); assert(VTA_INSN_GEM_F_1 < VTA_INS_WIDTH);
assert(VTA_INSN_ALU_F_1 < VTA_INS_WIDTH); assert(VTA_INSN_ALU_G_1 < VTA_INS_WIDTH);
int status = 0; int status = 0;
// Run ALU test (vector-scalar operators) // Run ALU test (vector-scalar operators)
status |= alu_test(VTA_ALU_OPCODE_MIN, true, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_MIN, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MIN, true, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_MIN, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MAX, true, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_MAX, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MAX, true, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_MAX, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_ADD, true, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_ADD, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_ADD, true, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_ADD, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_SHR, true, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHR, true, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_SHR, true, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, true);
status |= alu_test(VTA_ALU_OPCODE_SHL, true, 16, 128, false);
// Run ALU test (vector-vector operators) // Run ALU test (vector-vector operators)
status |= alu_test(VTA_ALU_OPCODE_MIN, false, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_MIN, false, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MIN, false, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_MIN, false, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_MAX, false, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_MAX, false, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_MAX, false, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_MAX, false, VTA_BLOCK_OUT, 128, false);
status |= alu_test(VTA_ALU_OPCODE_ADD, false, 16, 128, true); status |= alu_test(VTA_ALU_OPCODE_ADD, false, VTA_BLOCK_OUT, 128, true);
status |= alu_test(VTA_ALU_OPCODE_ADD, false, 16, 128, false); status |= alu_test(VTA_ALU_OPCODE_ADD, false, VTA_BLOCK_OUT, 128, false);
// Run blocked GEMM test // Run blocked GEMM test
status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2); status |= blocked_gemm_test(256, 256, VTA_BLOCK_OUT*4, true, 2);
......
...@@ -183,7 +183,7 @@ void compute( ...@@ -183,7 +183,7 @@ void compute(
hls::stream<bool> &s2g_dep_queue, hls::stream<bool> &s2g_dep_queue,
hls::stream<bool> &g2l_dep_queue, hls::stream<bool> &g2l_dep_queue,
hls::stream<bool> &g2s_dep_queue, hls::stream<bool> &g2s_dep_queue,
out_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH], inp_vec_T inp_mem[VTA_INP_BUFF_DEPTH][VTA_BATCH],
wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT], wgt_vec_T wgt_mem[VTA_WGT_BUFF_DEPTH][VTA_BLOCK_OUT],
out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH] out_vec_T out_mem[VTA_ACC_BUFF_DEPTH][VTA_BATCH]
) { ) {
...@@ -286,23 +286,24 @@ void compute( ...@@ -286,23 +286,24 @@ void compute(
done = 0; done = 0;
// Decode // Decode
uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_5_1, VTA_INSN_GEM_5_0); bool reset_out = insn[VTA_INSN_GEM_5];
uop_idx_T uop_end = insn.range(VTA_INSN_GEM_6_1, VTA_INSN_GEM_6_0); uop_idx_T uop_bgn = insn.range(VTA_INSN_GEM_6_1, VTA_INSN_GEM_6_0);
loop_T iter_out = insn.range(VTA_INSN_GEM_7_1, VTA_INSN_GEM_7_0); uop_idx_T uop_end = insn.range(VTA_INSN_GEM_7_1, VTA_INSN_GEM_7_0);
loop_T iter_in = insn.range(VTA_INSN_GEM_8_1, VTA_INSN_GEM_8_0); loop_T iter_out = insn.range(VTA_INSN_GEM_8_1, VTA_INSN_GEM_8_0);
acc_idx_T dst_factor_out = insn.range(VTA_INSN_GEM_9_1, VTA_INSN_GEM_9_0); loop_T iter_in = insn.range(VTA_INSN_GEM_9_1, VTA_INSN_GEM_9_0);
acc_idx_T dst_factor_in = insn.range(VTA_INSN_GEM_A_1, VTA_INSN_GEM_A_0); acc_idx_T dst_factor_out = insn.range(VTA_INSN_GEM_A_1, VTA_INSN_GEM_A_0);
inp_idx_T src_factor_out = insn.range(VTA_INSN_GEM_B_1, VTA_INSN_GEM_B_0); acc_idx_T dst_factor_in = insn.range(VTA_INSN_GEM_B_1, VTA_INSN_GEM_B_0);
inp_idx_T src_factor_in = insn.range(VTA_INSN_GEM_C_1, VTA_INSN_GEM_C_0); inp_idx_T src_factor_out = insn.range(VTA_INSN_GEM_C_1, VTA_INSN_GEM_C_0);
inp_idx_T src_factor_in = insn.range(VTA_INSN_GEM_D_1, VTA_INSN_GEM_D_0);
// GEMM-specific fields // GEMM-specific fields
wgt_idx_T wgt_factor_out = insn.range(VTA_INSN_GEM_D_1, VTA_INSN_GEM_D_0); wgt_idx_T wgt_factor_out = insn.range(VTA_INSN_GEM_E_1, VTA_INSN_GEM_E_0);
wgt_idx_T wgt_factor_in = insn.range(VTA_INSN_GEM_E_1, VTA_INSN_GEM_E_0); wgt_idx_T wgt_factor_in = insn.range(VTA_INSN_GEM_F_1, VTA_INSN_GEM_F_0);
// ALU-specific field // ALU-specific field
aluop_opcode_T alu_opcode = insn.range(VTA_INSN_ALU_D_1, VTA_INSN_ALU_D_0); aluop_opcode_T alu_opcode = insn.range(VTA_INSN_ALU_E_1, VTA_INSN_ALU_E_0);
bool use_imm = insn[VTA_INSN_ALU_E]; bool use_imm = insn[VTA_INSN_ALU_F];
aluop_imm_T imm = insn.range(VTA_INSN_ALU_F_1, VTA_INSN_ALU_F_0); aluop_imm_T imm = insn.range(VTA_INSN_ALU_G_1, VTA_INSN_ALU_G_0);
acc_idx_T dst_offset_out = 0; acc_idx_T dst_offset_out = 0;
inp_idx_T src_offset_out = 0; inp_idx_T src_offset_out = 0;
wgt_idx_T wgt_offset_out = 0; wgt_idx_T wgt_offset_out = 0;
...@@ -326,13 +327,12 @@ void compute( ...@@ -326,13 +327,12 @@ void compute(
uop_T uop = uop_mem[upc]; uop_T uop = uop_mem[upc];
// Decode indices // Decode indices
bool reset_out = uop[VTA_UOP_GEM_0];
acc_idx_T dst_idx = acc_idx_T dst_idx =
uop.range(VTA_UOP_GEM_1_1, VTA_UOP_GEM_1_0) + dst_offset_in; uop.range(VTA_UOP_GEM_0_1, VTA_UOP_GEM_0_0) + dst_offset_in;
acc_idx_T src_idx = inp_idx_T src_idx =
uop.range(VTA_UOP_GEM_2_1, VTA_UOP_GEM_2_0) + src_offset_in; uop.range(VTA_UOP_GEM_1_1, VTA_UOP_GEM_1_0) + src_offset_in;
wgt_idx_T wgt_idx = wgt_idx_T wgt_idx =
uop.range(VTA_UOP_GEM_3_1, VTA_UOP_GEM_3_0) + wgt_offset_in; uop.range(VTA_UOP_GEM_2_1, VTA_UOP_GEM_2_0) + wgt_offset_in;
// Read weight matrix // Read weight matrix
wgt_vec_T w_matrix[VTA_BLOCK_OUT]; wgt_vec_T w_matrix[VTA_BLOCK_OUT];
...@@ -341,7 +341,7 @@ void compute( ...@@ -341,7 +341,7 @@ void compute(
} }
// Read input matrix and accum matrix // Read input matrix and accum matrix
acc_vec_T o_matrix[VTA_BATCH]; acc_vec_T o_matrix[VTA_BATCH];
out_vec_T i_matrix[VTA_BATCH]; inp_vec_T i_matrix[VTA_BATCH];
for (int i = 0; i < VTA_BATCH; i++) { for (int i = 0; i < VTA_BATCH; i++) {
o_matrix[i] = acc_mem[dst_idx][i]; o_matrix[i] = acc_mem[dst_idx][i];
i_matrix[i] = inp_mem[src_idx][i]; i_matrix[i] = inp_mem[src_idx][i];
...@@ -365,6 +365,9 @@ void compute( ...@@ -365,6 +365,9 @@ void compute(
inp_T i_elem = inp_T i_elem =
i_matrix[i].range((k + 1) * VTA_INP_WIDTH - 1, k * VTA_INP_WIDTH); i_matrix[i].range((k + 1) * VTA_INP_WIDTH - 1, k * VTA_INP_WIDTH);
mul_T prod = i_elem * w_elem; mul_T prod = i_elem * w_elem;
#ifdef NO_DSP
#pragma HLS RESOURCE variable = prod core = Mul_LUT
#endif // NO_DSP
tmp += (sum_T) prod; tmp += (sum_T) prod;
} }
// Update summation // Update summation
...@@ -373,101 +376,85 @@ void compute( ...@@ -373,101 +376,85 @@ void compute(
acc_mem_val[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = acc_mem_val[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) =
reset_out ? (acc_T) 0 : accum; reset_out ? (acc_T) 0 : accum;
st_buf_val[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = st_buf_val[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) accum.range(VTA_OUT_WIDTH - 1, 0); (out_T) accum.range(VTA_OUT_WIDTH - 1, 0);
} }
// Write to buffers // Write to buffers
acc_mem[dst_idx][i] = acc_mem_val[i]; acc_mem[dst_idx][i] = acc_mem_val[i];
out_mem[dst_idx][i] = st_buf_val[i]; out_mem[dst_idx][i] = st_buf_val[i];
} }
} }
} else if (opcode == VTA_OPCODE_ALU) { }
#ifndef NO_ALU
else if (opcode == VTA_OPCODE_ALU) {
// Iterate over micro op // Iterate over micro op
READ_ALU_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) { READ_ALU_UOP: for (int upc = uop_bgn; upc < uop_end; upc++) {
#pragma HLS PIPELINE II = 2 rewind
// Read micro-op fields // Read micro-op fields
uop_T uop = uop_mem[upc]; uop_T uop = uop_mem[upc];
// Decode // Decode
bool reset_out = uop[VTA_UOP_ALU_0];
acc_idx_T dst_idx = acc_idx_T dst_idx =
uop.range(VTA_UOP_ALU_1_1, VTA_UOP_ALU_1_0) + dst_offset_in; uop.range(VTA_UOP_ALU_0_1, VTA_UOP_ALU_0_0) + dst_offset_in;
acc_idx_T src_idx = acc_idx_T src_idx =
uop.range(VTA_UOP_ALU_2_1, VTA_UOP_ALU_2_0) + src_offset_in; uop.range(VTA_UOP_ALU_1_1, VTA_UOP_ALU_1_0) + src_offset_in;
// Read input matrix and accum matrix
acc_vec_T dst_matrix[VTA_BATCH];
acc_vec_T src_matrix[VTA_BATCH];
for (int i = 0; i < VTA_BATCH; i++) {
#pragma HLS UNROLL complete
dst_matrix[i] = acc_mem[dst_idx][i];
src_matrix[i] = acc_mem[src_idx][i];
}
// Result matrices
acc_vec_T cmp_res[VTA_BATCH];
acc_vec_T add_res[VTA_BATCH];
acc_vec_T rshr_res[VTA_BATCH];
acc_vec_T lshr_res[VTA_BATCH];
out_vec_T short_cmp_res[VTA_BATCH];
out_vec_T short_add_res[VTA_BATCH];
out_vec_T short_rshr_res[VTA_BATCH];
out_vec_T short_lshr_res[VTA_BATCH];
// Perform ALU op over matrix elements // Perform ALU op over matrix elements
for (int i = 0; i < VTA_BATCH; i++) { for (int i = 0; i < VTA_BATCH; i++) {
// Read input matrix and accum matrix
acc_vec_T dst_vector = acc_mem[dst_idx][i];
acc_vec_T src_vector = acc_mem[src_idx][i];
// Result matrices
acc_vec_T cmp_res;
acc_vec_T add_res;
acc_vec_T shr_res;
out_vec_T short_cmp_res;
out_vec_T short_add_res;
out_vec_T short_shr_res;
// Results vector // Results vector
acc_vec_T res_vec = 0; acc_vec_T res_vec = 0;
for (int b = 0; b < VTA_BLOCK_OUT; b++) { for (int b = 0; b < VTA_BLOCK_OUT; b++) {
#pragma HLS PIPELINE II = 1 rewind
// Read in operands // Read in operands
acc_T src_0 = dst_matrix[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH); acc_T src_0 = dst_vector.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
acc_T src_1 = use_imm ? acc_T src_1 = use_imm ?
(acc_T) imm : (acc_T) imm :
src_matrix[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH); src_vector.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH);
// Compute Min/Max // Compute Min/Max
acc_T mix_val = src_0 < src_1 ? acc_T mix_val = src_0 < src_1 ?
(alu_opcode == VTA_ALU_OPCODE_MIN ? src_0 : src_1) : (alu_opcode == VTA_ALU_OPCODE_MIN ? src_0 : src_1) :
(alu_opcode == VTA_ALU_OPCODE_MIN ? src_1 : src_0); (alu_opcode == VTA_ALU_OPCODE_MIN ? src_1 : src_0);
cmp_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = mix_val; cmp_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = mix_val;
short_cmp_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = short_cmp_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) mix_val.range(VTA_OUT_WIDTH - 1, 0); (out_T) mix_val.range(VTA_OUT_WIDTH - 1, 0);
// Compute Sum // Compute Sum
acc_T add_val = acc_T add_val =
src_0.range(VTA_ACC_WIDTH - 1, 0) + src_1.range(VTA_ACC_WIDTH - 1, 0); src_0.range(VTA_ACC_WIDTH - 1, 0) + src_1.range(VTA_ACC_WIDTH - 1, 0);
add_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val; add_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = add_val;
short_add_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = short_add_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) add_val.range(VTA_OUT_WIDTH - 1, 0); (out_T) add_val.range(VTA_OUT_WIDTH - 1, 0);
// Compute Right Shift // Compute Shift Right
acc_T rshr_val = acc_T shr_val =
src_0 >> (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0); src_0 >> (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
rshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = rshr_val; shr_res.range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = shr_val;
short_rshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) = short_shr_res.range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) rshr_val.range(VTA_OUT_WIDTH-1, 0); (out_T) shr_val.range(VTA_OUT_WIDTH-1, 0);
// Compute Left Shift
acc_T lshr_val =
src_0 << (aluop_sh_imm_T) src_1.range(VTA_LOG_ACC_WIDTH - 1, 0);
lshr_res[i].range((b + 1) * VTA_ACC_WIDTH - 1, b * VTA_ACC_WIDTH) = lshr_val;
short_lshr_res[i].range((b + 1) * VTA_OUT_WIDTH - 1, b * VTA_OUT_WIDTH) =
(inp_T) lshr_val.range(VTA_OUT_WIDTH-1, 0);
} }
// Store to accum memory/store buffer // Store to accum memory/store buffer
if (alu_opcode == VTA_ALU_OPCODE_MIN || if (alu_opcode == VTA_ALU_OPCODE_MIN ||
alu_opcode == VTA_ALU_OPCODE_MAX) { alu_opcode == VTA_ALU_OPCODE_MAX) {
acc_mem[dst_idx][i] = cmp_res[i]; acc_mem[dst_idx][i] = cmp_res;
out_mem[dst_idx][i] = short_cmp_res[i]; out_mem[dst_idx][i] = short_cmp_res;
} else if (alu_opcode == VTA_ALU_OPCODE_ADD) { } else if (alu_opcode == VTA_ALU_OPCODE_ADD) {
acc_mem[dst_idx][i] = add_res[i]; acc_mem[dst_idx][i] = add_res;
out_mem[dst_idx][i] = short_add_res[i]; out_mem[dst_idx][i] = short_add_res;
} else if (alu_opcode == VTA_ALU_OPCODE_SHR) { } else if (alu_opcode == VTA_ALU_OPCODE_SHR) {
acc_mem[dst_idx][i] = rshr_res[i]; acc_mem[dst_idx][i] = shr_res;
out_mem[dst_idx][i] = short_rshr_res[i]; out_mem[dst_idx][i] = short_shr_res;
} else if (alu_opcode == VTA_ALU_OPCODE_SHL) {
acc_mem[dst_idx][i] = lshr_res[i];
out_mem[dst_idx][i] = short_lshr_res[i];
} }
} }
} }
} }
#endif // NO_ALU
// Update offsets // Update offsets
dst_offset_in += dst_factor_in; dst_offset_in += dst_factor_in;
......
...@@ -92,7 +92,7 @@ typedef ap_uint<VTA_ALU_OPCODE_BIT_WIDTH> aluop_opcode_T; ...@@ -92,7 +92,7 @@ typedef ap_uint<VTA_ALU_OPCODE_BIT_WIDTH> aluop_opcode_T;
typedef ap_int<VTA_ALUOP_IMM_BIT_WIDTH> aluop_imm_T; typedef ap_int<VTA_ALUOP_IMM_BIT_WIDTH> aluop_imm_T;
/* \typedef aluop_sh_imm_T ALU operation shift immediate datatype*/ /* \typedef aluop_sh_imm_T ALU operation shift immediate datatype*/
typedef ap_uint<VTA_LOG_ACC_WIDTH> aluop_sh_imm_T; typedef ap_int<VTA_LOG_ACC_WIDTH> aluop_sh_imm_T;
/*! /*!
* \brief Fetch module. * \brief Fetch module.
......
...@@ -93,9 +93,7 @@ extern "C" { ...@@ -93,9 +93,7 @@ extern "C" {
/*! Instruction opcode field bitwidth */ /*! Instruction opcode field bitwidth */
#define VTA_OPCODE_BIT_WIDTH 3 #define VTA_OPCODE_BIT_WIDTH 3
/*! ALU opcode field bitwidth */ /*! ALU opcode field bitwidth */
#define VTA_ALU_OPCODE_BIT_WIDTH 3 #define VTA_ALU_OPCODE_BIT_WIDTH 2
/*! ALU instruction reset mode bitwidth */
#define VTA_ALU_RESET_BIT_WIDTH 2
/*! Opcode: load encoding */ /*! Opcode: load encoding */
#define VTA_OPCODE_LOAD 0 #define VTA_OPCODE_LOAD 0
...@@ -114,23 +112,8 @@ extern "C" { ...@@ -114,23 +112,8 @@ extern "C" {
#define VTA_ALU_OPCODE_MAX 1 #define VTA_ALU_OPCODE_MAX 1
/*! ALU opcode: binary add op */ /*! ALU opcode: binary add op */
#define VTA_ALU_OPCODE_ADD 2 #define VTA_ALU_OPCODE_ADD 2
/*! ALU opcode: binary sub op [NOT IMPLEMENTED] */ /*! ALU opcode: shift right by immediate op */
#define VTA_ALU_OPCODE_SUB 3 #define VTA_ALU_OPCODE_SHR 3
/*! ALU opcode: binary mul op [NOT IMPLEMENTED] */
#define VTA_ALU_OPCODE_MUL 4
/*! ALU opcode: shift left by immediate op */
#define VTA_ALU_OPCODE_SHL 5
/*! ALU opcode: shift right by immediate op [NOT IMPLEMENTED] */
#define VTA_ALU_OPCODE_SHR 6
/*! ALU instruction reset mode: set to min */
#define VTA_ALU_RESET_MIN 3
/*! ALU instruction reset mode: set to zero */
#define VTA_ALU_RESET_ZERO 0
/*! ALU instruction reset mode: no reset */
#define VTA_ALU_NO_RESET 2
/*! ALU instruction reset mode: set to max */
#define VTA_ALU_RESET_MAX 1
/*! Memory type field bitwidth */ /*! Memory type field bitwidth */
#define VTA_MEMOP_ID_BIT_WIDTH 2 #define VTA_MEMOP_ID_BIT_WIDTH 2
...@@ -149,7 +132,7 @@ extern "C" { ...@@ -149,7 +132,7 @@ extern "C" {
/*! ALU Instruction: immediate bitwidth*/ /*! ALU Instruction: immediate bitwidth*/
#define VTA_ALUOP_IMM_BIT_WIDTH 16 #define VTA_ALUOP_IMM_BIT_WIDTH 16
/*! GEMM/ALU Instruction: loop max iter bits */ /*! GEMM/ALU Instruction: loop max iter bits */
#define VTA_LOOP_ITER_WIDTH 15 #define VTA_LOOP_ITER_WIDTH 14
/*! Mem ID constant: uop memory */ /*! Mem ID constant: uop memory */
#define VTA_MEM_ID_UOP 0 #define VTA_MEM_ID_UOP 0
...@@ -190,16 +173,17 @@ extern "C" { ...@@ -190,16 +173,17 @@ extern "C" {
// arg 2: pop_next_dependence | bool | // arg 2: pop_next_dependence | bool |
// arg 3: push_prev_dependence | bool | // arg 3: push_prev_dependence | bool |
// arg 4: push_next_dependence | bool | // arg 4: push_next_dependence | bool |
// arg 5: uop_bgn | uop_idx_T | // arg 5: reset_reg | bool |
// arg 6: uop_end | uop_idx_T | // arg 6: uop_bgn | uop_idx_T |
// arg 7: iteration count ax0 | loop_T | // arg 7: uop_end | uop_idx_T |
// arg 8: iteration count ax1 | loop_T | // arg 8: iteration count ax0 | loop_T |
// arg 9: accum idx factor ax0 | acc_idx_T | // arg 9: iteration count ax1 | loop_T |
// arg a: accum idx factor ax1 | acc_idx_T | // arg a: accum idx factor ax0 | acc_idx_T |
// arg b: input idx factor ax0 | acc_idx_T | // arg b: accum idx factor ax1 | acc_idx_T |
// arg c: input idx factor ax1 | acc_idx_T | // arg c: input idx factor ax0 | inp_idx_T |
// arg d: weight idx factor ax0 | wgt_idx_T | // arg d: input idx factor ax1 | inp_idx_T |
// arg e: weight idx factor ax1 | wgt_idx_T | // arg e: weight idx factor ax0 | wgt_idx_T |
// arg f: weight idx factor ax1 | wgt_idx_T |
// //
// ALU // ALU
// _____________________________|_type______________| // _____________________________|_type______________|
...@@ -208,17 +192,18 @@ extern "C" { ...@@ -208,17 +192,18 @@ extern "C" {
// arg 2: pop_next_dependence | bool | // arg 2: pop_next_dependence | bool |
// arg 3: push_prev_dependence | bool | // arg 3: push_prev_dependence | bool |
// arg 4: push_next_dependence | bool | // arg 4: push_next_dependence | bool |
// arg 5: uop_bgn | uop_idx_T | // arg 5: reset_reg | bool |
// arg 6: uop_end | uop_idx_T | // arg 6: uop_bgn | uop_idx_T |
// arg 7: iteration count ax0 | loop_T | // arg 7: uop_end | uop_idx_T |
// arg 8: iteration count ax1 | loop_T | // arg 8: iteration count ax0 | loop_T |
// arg 9: dst idx factor ax0 | acc_idx_T | // arg 9: iteration count ax1 | loop_T |
// arg a: dst idx factor ax1 | acc_idx_T | // arg a: dst idx factor ax0 | acc_idx_T |
// arg b: src idx factor ax0 | acc_idx_T | // arg b: dst idx factor ax1 | acc_idx_T |
// arg c: src idx factor ax1 | acc_idx_T | // arg c: src idx factor ax0 | inp_idx_T |
// arg d: alu_opcode | aluop_opcode_T | // arg d: src idx factor ax1 | inp_idx_T |
// arg e: use_imm | bool | // arg e: alu_opcode | aluop_opcode_T |
// arg f: imm | alu_imm_T | // arg f: use_imm | bool |
// arg g: imm | aluop_imm_T |
/*! Load/Store instruction start position of the opcode field */ /*! Load/Store instruction start position of the opcode field */
#define VTA_INSN_MEM_0_0 0 #define VTA_INSN_MEM_0_0 0
...@@ -285,88 +270,82 @@ extern "C" { ...@@ -285,88 +270,82 @@ extern "C" {
#define VTA_INSN_GEM_3 (VTA_INSN_GEM_2 + 1) #define VTA_INSN_GEM_3 (VTA_INSN_GEM_2 + 1)
/*! GEMM instruction position of the push_next_dependence field */ /*! GEMM instruction position of the push_next_dependence field */
#define VTA_INSN_GEM_4 (VTA_INSN_GEM_3 + 1) #define VTA_INSN_GEM_4 (VTA_INSN_GEM_3 + 1)
/*! GEMM instruction position of the reset register bit */
#define VTA_INSN_GEM_5 (VTA_INSN_GEM_4 + 1)
/*! GEMM instruction start position of the uop_bgn field */ /*! GEMM instruction start position of the uop_bgn field */
#define VTA_INSN_GEM_5_0 (VTA_INSN_GEM_4 + 1) #define VTA_INSN_GEM_6_0 (VTA_INSN_GEM_5 + 1)
/*! GEMM instruction end position of the uop_bgn field */ /*! GEMM instruction end position of the uop_bgn field */
#define VTA_INSN_GEM_5_1 (VTA_INSN_GEM_5_0 + VTA_LOG_UOP_BUFF_DEPTH - 1) #define VTA_INSN_GEM_6_1 (VTA_INSN_GEM_6_0 + VTA_LOG_UOP_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the uop_end field */ /*! GEMM instruction start position of the uop_end field */
#define VTA_INSN_GEM_6_0 (VTA_INSN_GEM_5_1 + 1) #define VTA_INSN_GEM_7_0 (VTA_INSN_GEM_6_1 + 1)
/*! GEMM instruction end position of the uop_end field */ /*! GEMM instruction end position of the uop_end field */
#define VTA_INSN_GEM_6_1 (VTA_INSN_GEM_6_0 + VTA_LOG_UOP_BUFF_DEPTH + 1 - 1) #define VTA_INSN_GEM_7_1 (VTA_INSN_GEM_7_0 + VTA_LOG_UOP_BUFF_DEPTH + 1 - 1)
/*! GEMM instruction start position of the iter_out field */ /*! GEMM instruction start position of the iter_out field */
#define VTA_INSN_GEM_7_0 (VTA_INSN_GEM_6_1 + 1) #define VTA_INSN_GEM_8_0 (VTA_INSN_GEM_7_1 + 1)
/*! GEMM instruction end position of the iter_out field */ /*! GEMM instruction end position of the iter_out field */
#define VTA_INSN_GEM_7_1 (VTA_INSN_GEM_7_0 + VTA_LOOP_ITER_WIDTH - 1) #define VTA_INSN_GEM_8_1 (VTA_INSN_GEM_8_0 + VTA_LOOP_ITER_WIDTH - 1)
/*! GEMM instruction start position of the iter_in field */ /*! GEMM instruction start position of the iter_in field */
#define VTA_INSN_GEM_8_0 (VTA_INSN_GEM_7_1 + 1) #define VTA_INSN_GEM_9_0 (VTA_INSN_GEM_8_1 + 1)
/*! GEMM instruction end position of the iter_in field */ /*! GEMM instruction end position of the iter_in field */
#define VTA_INSN_GEM_8_1 (VTA_INSN_GEM_8_0 + VTA_LOOP_ITER_WIDTH - 1) #define VTA_INSN_GEM_9_1 (VTA_INSN_GEM_9_0 + VTA_LOOP_ITER_WIDTH - 1)
/*! GEMM instruction start position of the dst_factor_out field */ /*! GEMM instruction start position of the dst_factor_out field */
#define VTA_INSN_GEM_9_0 64 #define VTA_INSN_GEM_A_0 64
/*! GEMM instruction end position of the dst_factor_out field */ /*! GEMM instruction end position of the dst_factor_out field */
#define VTA_INSN_GEM_9_1 (VTA_INSN_GEM_9_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_INSN_GEM_A_1 (VTA_INSN_GEM_A_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the dst_factor_in field */ /*! GEMM instruction start position of the dst_factor_in field */
#define VTA_INSN_GEM_A_0 (VTA_INSN_GEM_9_1 + 1) #define VTA_INSN_GEM_B_0 (VTA_INSN_GEM_A_1 + 1)
/*! GEMM instruction end position of the dst_factor_in field */ /*! GEMM instruction end position of the dst_factor_in field */
#define VTA_INSN_GEM_A_1 (VTA_INSN_GEM_A_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_INSN_GEM_B_1 (VTA_INSN_GEM_B_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the src_factor_out field */ /*! GEMM instruction start position of the src_factor_out field */
#define VTA_INSN_GEM_B_0 (VTA_INSN_GEM_A_1 + 1) #define VTA_INSN_GEM_C_0 (VTA_INSN_GEM_B_1 + 1)
/*! GEMM instruction end position of the src_factor_out field */ /*! GEMM instruction end position of the src_factor_out field */
#define VTA_INSN_GEM_B_1 (VTA_INSN_GEM_B_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_INSN_GEM_C_1 (VTA_INSN_GEM_C_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the src_factor_in field */ /*! GEMM instruction start position of the src_factor_in field */
#define VTA_INSN_GEM_C_0 (VTA_INSN_GEM_B_1 + 1) #define VTA_INSN_GEM_D_0 (VTA_INSN_GEM_C_1 + 1)
/*! GEMM instruction end position of the src_factor_in field */ /*! GEMM instruction end position of the src_factor_in field */
#define VTA_INSN_GEM_C_1 (VTA_INSN_GEM_C_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_INSN_GEM_D_1 (VTA_INSN_GEM_D_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the wgt_factor_out field */ /*! GEMM instruction start position of the wgt_factor_out field */
#define VTA_INSN_GEM_D_0 (VTA_INSN_GEM_C_1 + 1) #define VTA_INSN_GEM_E_0 (VTA_INSN_GEM_D_1 + 1)
/*! GEMM instruction end position of the wgt_factor_out field */ /*! GEMM instruction end position of the wgt_factor_out field */
#define VTA_INSN_GEM_D_1 (VTA_INSN_GEM_D_0 + VTA_LOG_WGT_BUFF_DEPTH - 1) #define VTA_INSN_GEM_E_1 (VTA_INSN_GEM_E_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! GEMM instruction start position of the wgt_factor_in field */ /*! GEMM instruction start position of the wgt_factor_in field */
#define VTA_INSN_GEM_E_0 (VTA_INSN_GEM_D_1 + 1) #define VTA_INSN_GEM_F_0 (VTA_INSN_GEM_E_1 + 1)
/*! GEMM instruction end position of the wgt_factor_in field */ /*! GEMM instruction end position of the wgt_factor_in field */
#define VTA_INSN_GEM_E_1 (VTA_INSN_GEM_E_0 + VTA_LOG_WGT_BUFF_DEPTH - 1) #define VTA_INSN_GEM_F_1 (VTA_INSN_GEM_F_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! ALU instruction start position of the alu_opcode field */ /*! ALU instruction start position of the alu_opcode field */
#define VTA_INSN_ALU_D_0 (VTA_INSN_GEM_C_1 + 1) #define VTA_INSN_ALU_E_0 (VTA_INSN_GEM_D_1 + 1)
/*! ALU instruction end position of the alu_opcode field */ /*! ALU instruction end position of the alu_opcode field */
#define VTA_INSN_ALU_D_1 (VTA_INSN_ALU_D_0 + VTA_ALU_OPCODE_BIT_WIDTH - 1) #define VTA_INSN_ALU_E_1 (VTA_INSN_ALU_E_0 + VTA_ALU_OPCODE_BIT_WIDTH - 1)
/*! ALU instruction position of the use_imm field */ /*! ALU instruction position of the use_imm field */
#define VTA_INSN_ALU_E (VTA_INSN_ALU_D_1 + 1) #define VTA_INSN_ALU_F (VTA_INSN_ALU_E_1 + 1)
/*! ALU instruction start position of the immediate field */ /*! ALU instruction start position of the immediate field */
#define VTA_INSN_ALU_F_0 (VTA_INSN_ALU_E + 1) #define VTA_INSN_ALU_G_0 (VTA_INSN_ALU_F + 1)
/*! ALU instruction end position of the immediate field */ /*! ALU instruction end position of the immediate field */
#define VTA_INSN_ALU_F_1 (VTA_INSN_ALU_F_0 + VTA_ALUOP_IMM_BIT_WIDTH - 1) #define VTA_INSN_ALU_G_1 (VTA_INSN_ALU_G_0 + VTA_ALUOP_IMM_BIT_WIDTH - 1)
/*! GEMM Micro-op position of the reset_out field */
#define VTA_UOP_GEM_0 0
/*! GEMM Micro-op start position of the acc_idx field */ /*! GEMM Micro-op start position of the acc_idx field */
#define VTA_UOP_GEM_1_0 (VTA_UOP_GEM_0 + 1) #define VTA_UOP_GEM_0_0 0
/*! GEMM Micro-op end position of the acc_idx field */ /*! GEMM Micro-op end position of the acc_idx field */
#define VTA_UOP_GEM_1_1 (VTA_UOP_GEM_1_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_UOP_GEM_0_1 (VTA_UOP_GEM_0_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the inp_idx field */ /*! GEMM Micro-op start position of the inp_idx field */
#define VTA_UOP_GEM_2_0 (VTA_UOP_GEM_1_1 + 1) #define VTA_UOP_GEM_1_0 (VTA_UOP_GEM_0_1 + 1)
/*! GEMM Micro-op end position of the inp_idx field */ /*! GEMM Micro-op end position of the inp_idx field */
#define VTA_UOP_GEM_2_1 (VTA_UOP_GEM_2_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_UOP_GEM_1_1 (VTA_UOP_GEM_1_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the wgt_idx field */ /*! GEMM Micro-op start position of the wgt_idx field */
#define VTA_UOP_GEM_3_0 (VTA_UOP_GEM_2_1 + 1) #define VTA_UOP_GEM_2_0 (VTA_UOP_GEM_1_1 + 1)
/*! GEMM Micro-op end position of the wgt_idx field */ /*! GEMM Micro-op end position of the wgt_idx field */
#define VTA_UOP_GEM_3_1 (VTA_UOP_GEM_3_0 + VTA_LOG_WGT_BUFF_DEPTH - 1) #define VTA_UOP_GEM_2_1 (VTA_UOP_GEM_2_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! GEMM Micro-op position of the reset_out field */
#define VTA_UOP_ALU_0 0
/*! ALU Micro-op start position of the dst_idx field */ /*! ALU Micro-op start position of the dst_idx field */
#define VTA_UOP_ALU_1_0 (VTA_UOP_ALU_0 + 1) #define VTA_UOP_ALU_0_0 0
/*! ALU Micro-op end position of the dst_idx field */ /*! ALU Micro-op end position of the dst_idx field */
#define VTA_UOP_ALU_1_1 (VTA_UOP_ALU_1_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_UOP_ALU_0_1 (VTA_UOP_ALU_0_0 + VTA_LOG_ACC_BUFF_DEPTH - 1)
/*! ALU Micro-op start position of the src_idx field */ /*! ALU Micro-op start position of the src_idx field */
#define VTA_UOP_ALU_2_0 (VTA_UOP_ALU_1_1 + 1) #define VTA_UOP_ALU_1_0 (VTA_UOP_ALU_0_1 + 1)
/*! ALU Micro-op end position of the src_idx field */ /*! ALU Micro-op end position of the src_idx field */
#define VTA_UOP_ALU_2_1 (VTA_UOP_ALU_2_0 + VTA_LOG_ACC_BUFF_DEPTH - 1) #define VTA_UOP_ALU_1_1 (VTA_UOP_ALU_1_0 + VTA_LOG_INP_BUFF_DEPTH - 1)
/*! GEMM Micro-op start position of the wgt_idx field */
#define VTA_UOP_ALU_3_0 (VTA_UOP_ALU_2_1 + 1)
/*! GEMM Micro-op end position of the wgt_idx field */
#define VTA_UOP_ALU_3_1 (VTA_UOP_ALU_3_0 + VTA_LOG_WGT_BUFF_DEPTH - 1)
/*! \brief VTA generic instruction */ /*! \brief VTA generic instruction */
typedef struct { typedef struct {
...@@ -454,10 +433,12 @@ typedef struct { ...@@ -454,10 +433,12 @@ typedef struct {
uint64_t push_prev_dep : 1; uint64_t push_prev_dep : 1;
/*! \brief Push dependence token to store stage */ /*! \brief Push dependence token to store stage */
uint64_t push_next_dep : 1; uint64_t push_next_dep : 1;
/*! \brief Reset register */
uint64_t reset_reg : 1;
/*! \brief Micro-op begin address */ /*! \brief Micro-op begin address */
uint64_t uop_bgn : VTA_LOG_UOP_BUFF_DEPTH; uint64_t uop_bgn : VTA_LOG_UOP_BUFF_DEPTH;
/*! \brief Micro-op end address */ /*! \brief Micro-op end address */
uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH+1; uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH + 1;
/*! \brief Iterations in the outer uop execution loop */ /*! \brief Iterations in the outer uop execution loop */
uint64_t iter_out : VTA_LOOP_ITER_WIDTH; uint64_t iter_out : VTA_LOOP_ITER_WIDTH;
/*! \brief Iterations in the inner uop execution loop */ /*! \brief Iterations in the inner uop execution loop */
...@@ -467,9 +448,9 @@ typedef struct { ...@@ -467,9 +448,9 @@ typedef struct {
/*! \brief Inner loop accumulator memory index factor */ /*! \brief Inner loop accumulator memory index factor */
uint64_t dst_factor_in : VTA_LOG_ACC_BUFF_DEPTH; uint64_t dst_factor_in : VTA_LOG_ACC_BUFF_DEPTH;
/*! \brief Outer loop input memory index factor */ /*! \brief Outer loop input memory index factor */
uint64_t src_factor_out : VTA_LOG_ACC_BUFF_DEPTH; uint64_t src_factor_out : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Inner loop input memory index factor */ /*! \brief Inner loop input memory index factor */
uint64_t src_factor_in : VTA_LOG_ACC_BUFF_DEPTH; uint64_t src_factor_in : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Outer loop weight memory index factor */ /*! \brief Outer loop weight memory index factor */
uint64_t wgt_factor_out : VTA_LOG_WGT_BUFF_DEPTH; uint64_t wgt_factor_out : VTA_LOG_WGT_BUFF_DEPTH;
/*! \brief Inner loop weight memory index factor */ /*! \brief Inner loop weight memory index factor */
...@@ -516,10 +497,12 @@ typedef struct { ...@@ -516,10 +497,12 @@ typedef struct {
uint64_t push_prev_dep : 1; uint64_t push_prev_dep : 1;
/*! \brief Push dependence token to store stage */ /*! \brief Push dependence token to store stage */
uint64_t push_next_dep : 1; uint64_t push_next_dep : 1;
/*! \brief Reset register */
uint64_t reset_reg : 1;
/*! \brief Micro-op begin address */ /*! \brief Micro-op begin address */
uint64_t uop_bgn : VTA_LOG_UOP_BUFF_DEPTH; uint64_t uop_bgn : VTA_LOG_UOP_BUFF_DEPTH;
/*! \brief Micro-op end address */ /*! \brief Micro-op end address */
uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH+1; uint64_t uop_end : VTA_LOG_UOP_BUFF_DEPTH + 1;
/*! \brief Iterations in the outer uop execution loop */ /*! \brief Iterations in the outer uop execution loop */
uint64_t iter_out : VTA_LOOP_ITER_WIDTH; uint64_t iter_out : VTA_LOOP_ITER_WIDTH;
/*! \brief Iterations in the inner uop execution loop */ /*! \brief Iterations in the inner uop execution loop */
...@@ -529,9 +512,9 @@ typedef struct { ...@@ -529,9 +512,9 @@ typedef struct {
/*! \brief Inner loop accumulator memory destination index factor */ /*! \brief Inner loop accumulator memory destination index factor */
uint64_t dst_factor_in : VTA_LOG_ACC_BUFF_DEPTH; uint64_t dst_factor_in : VTA_LOG_ACC_BUFF_DEPTH;
/*! \brief Outer loop accumulator memory source index factor */ /*! \brief Outer loop accumulator memory source index factor */
uint64_t src_factor_out : VTA_LOG_ACC_BUFF_DEPTH; uint64_t src_factor_out : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Inner loop accumulator memory source index factor */ /*! \brief Inner loop accumulator memory source index factor */
uint64_t src_factor_in : VTA_LOG_ACC_BUFF_DEPTH; uint64_t src_factor_in : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief ALU opcode */ /*! \brief ALU opcode */
uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH; uint64_t alu_opcode : VTA_ALU_OPCODE_BIT_WIDTH;
/*! \brief Use immediate is true */ /*! \brief Use immediate is true */
...@@ -554,12 +537,10 @@ union VTAInsn { ...@@ -554,12 +537,10 @@ union VTAInsn {
/*! \brief VTA micro-op for GEMM/ALU instruction */ /*! \brief VTA micro-op for GEMM/ALU instruction */
typedef struct { typedef struct {
/*! \brief Initialize acc_mem at index dst_idx to 0*/
uint32_t reset_out : 1;
/*! \brief Destination index (indexes accum buffer) */ /*! \brief Destination index (indexes accum buffer) */
uint32_t dst_idx : VTA_LOG_ACC_BUFF_DEPTH; uint32_t dst_idx : VTA_LOG_ACC_BUFF_DEPTH;
/*! \brief Source index (indexes input buffer for GEMM or accum buffer for ALU) */ /*! \brief Source index (indexes input buffer for GEMM or accum buffer for ALU) */
uint32_t src_idx : VTA_LOG_ACC_BUFF_DEPTH; uint32_t src_idx : VTA_LOG_INP_BUFF_DEPTH;
/*! \brief Weight index (indexes weight buffer) */ /*! \brief Weight index (indexes weight buffer) */
uint32_t wgt_idx : VTA_LOG_WGT_BUFF_DEPTH; uint32_t wgt_idx : VTA_LOG_WGT_BUFF_DEPTH;
} VTAUop; } VTAUop;
......
"""VTA configuration constants (should match hw_spec.h""" """VTA configuration constants (should match hw_spec.h"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
# Log of input/activation width in bits (default 3 -> 8 bits) import os
VTA_LOG_INP_WIDTH = 3 python_vta_dir = os.path.dirname(__file__)
# Log of kernel weight width in bits (default 3 -> 8 bits) print python_vta_dir
VTA_LOG_WGT_WIDTH = 3 filename = os.path.join(python_vta_dir, '../../config.mk')
# Log of accum width in bits (default 5 -> 32 bits)
VTA_LOG_ACC_WIDTH = 5 VTA_PYNQ_BRAM_W = 32
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication) VTA_PYNQ_BRAM_D = 1024
VTA_LOG_BATCH = 0 VTA_PYNQ_NUM_BRAM = 124
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_IN = 4 keys = ["VTA_LOG_INP_WIDTH", "VTA_LOG_WGT_WIDTH", "VTA_LOG_ACC_WIDTH",
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication) "VTA_LOG_BATCH", "VTA_LOG_BLOCK_IN", "VTA_LOG_BLOCK_OUT",
VTA_LOG_BLOCK_OUT = 4 "VTA_LOG_UOP_BUFF_SIZE", "VTA_LOG_INP_BUFF_SIZE",
VTA_LOG_OUT_WIDTH = VTA_LOG_INP_WIDTH "VTA_LOG_WGT_BUFF_SIZE", "VTA_LOG_ACC_BUFF_SIZE"]
# Log of uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = 15 params = {}
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17 if os.path.isfile(filename):
with open(filename) as f:
for line in f:
for k in keys:
if k+" =" in line:
val = line.split("=")[1].strip()
params[k] = int(val)
# print params
else:
print "Error: {} not found. Please make sure you have config.mk in your vta root".format(filename)
exit()
# The Constants # The Constants
VTA_WGT_WIDTH = 8 VTA_INP_WIDTH = 1 << params["VTA_LOG_INP_WIDTH"]
VTA_INP_WIDTH = VTA_WGT_WIDTH VTA_WGT_WIDTH = 1 << params["VTA_LOG_WGT_WIDTH"]
VTA_OUT_WIDTH = 32 VTA_OUT_WIDTH = 1 << params["VTA_LOG_ACC_WIDTH"]
VTA_TARGET = "VTA_PYNQ_TARGET" VTA_TARGET = "VTA_PYNQ_TARGET"
# Dimensions of the GEMM unit # Dimensions of the GEMM unit
# (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT) # (BATCH,BLOCK_IN) x (BLOCK_IN,BLOCK_OUT)
VTA_BATCH = 1 VTA_BATCH = 1 << params["VTA_LOG_BATCH"]
VTA_BLOCK_IN = 16 VTA_BLOCK_IN = 1 << params["VTA_LOG_BLOCK_IN"]
VTA_BLOCK_OUT = 16 VTA_BLOCK_OUT = 1 << params["VTA_LOG_BLOCK_OUT"]
# log-2 On-chip wgt buffer size in Bytes # log-2 On-chip uop buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = 15 VTA_LOG_UOP_BUFF_SIZE = params["VTA_LOG_UOP_BUFF_SIZE"]
# log-2 On-chip input buffer size in Bytes # log-2 On-chip input buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = 15 VTA_LOG_INP_BUFF_SIZE = params["VTA_LOG_INP_BUFF_SIZE"]
# log-2 On-chip wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = params["VTA_LOG_WGT_BUFF_SIZE"]
# log-2 On-chip output buffer size in Bytes # log-2 On-chip output buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = 17 VTA_LOG_OUT_BUFF_SIZE = params["VTA_LOG_ACC_BUFF_SIZE"]
# On-chip wgt buffer size in Bytes # Uop buffer size
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE VTA_UOP_BUFF_SIZE = 1 << VTA_LOG_UOP_BUFF_SIZE
# Input buffer size # Input buffer size
VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE VTA_INP_BUFF_SIZE = 1 << VTA_LOG_INP_BUFF_SIZE
# On-chip wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = 1 << VTA_LOG_WGT_BUFF_SIZE
# Output buffer size. # Output buffer size.
VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE VTA_OUT_BUFF_SIZE = 1 << VTA_LOG_OUT_BUFF_SIZE
...@@ -69,11 +83,7 @@ VTA_MEM_ID_OUT = 4 ...@@ -69,11 +83,7 @@ VTA_MEM_ID_OUT = 4
VTA_ALU_OPCODE_MIN = 0 VTA_ALU_OPCODE_MIN = 0
VTA_ALU_OPCODE_MAX = 1 VTA_ALU_OPCODE_MAX = 1
VTA_ALU_OPCODE_ADD = 2 VTA_ALU_OPCODE_ADD = 2
VTA_ALU_OPCODE_SUB = 3 VTA_ALU_OPCODE_SHR = 3
VTA_ALU_OPCODE_MUL = 4
VTA_ALU_OPCODE_SHL = 5
VTA_ALU_OPCODE_SHR = 6
VTA_ALU_OPCODE_UNSET = 7
# Task queue id (pipeline stage) # Task queue id (pipeline stage)
VTA_QID_LOAD_INP = 1 VTA_QID_LOAD_INP = 1
......
...@@ -617,7 +617,6 @@ def inject_alu_intrin(stmt_in): ...@@ -617,7 +617,6 @@ def inject_alu_intrin(stmt_in):
extents.append(tmp_body.extent) extents.append(tmp_body.extent)
tmp_body = tmp_body.body tmp_body = tmp_body.body
# Derive opcode # Derive opcode
alu_opcode = spec.VTA_ALU_OPCODE_UNSET
if isinstance(loop_body.value, tvm.expr.Add): if isinstance(loop_body.value, tvm.expr.Add):
alu_opcode = spec.VTA_ALU_OPCODE_ADD alu_opcode = spec.VTA_ALU_OPCODE_ADD
lhs = loop_body.value.a lhs = loop_body.value.a
...@@ -640,9 +639,9 @@ def inject_alu_intrin(stmt_in): ...@@ -640,9 +639,9 @@ def inject_alu_intrin(stmt_in):
rhs = loop_body.value.b rhs = loop_body.value.b
elif isinstance(loop_body.value, tvm.expr.Call): elif isinstance(loop_body.value, tvm.expr.Call):
if loop_body.value.name == 'shift_left': if loop_body.value.name == 'shift_left':
alu_opcode = spec.VTA_ALU_OPCODE_SHL alu_opcode = spec.VTA_ALU_OPCODE_SHR
lhs = loop_body.value.args[0] lhs = loop_body.value.args[0]
rhs = loop_body.value.args[1] rhs = tvm.ir_pass.Simplify(-loop_body.value.args[1])
elif loop_body.value.name == 'shift_right': elif loop_body.value.name == 'shift_right':
alu_opcode = spec.VTA_ALU_OPCODE_SHR alu_opcode = spec.VTA_ALU_OPCODE_SHR
lhs = loop_body.value.args[0] lhs = loop_body.value.args[0]
...@@ -732,6 +731,7 @@ def inject_alu_intrin(stmt_in): ...@@ -732,6 +731,7 @@ def inject_alu_intrin(stmt_in):
tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff] tvm.ir_pass.Simplify(c/(spec.VTA_BATCH*spec.VTA_BLOCK_OUT)) for c in dst_coeff]
# Flatten the outer loops # Flatten the outer loops
if extents:
src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents) src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
# Insert ALU micro-ops # Insert ALU micro-ops
......
...@@ -117,7 +117,6 @@ class UopKernel { ...@@ -117,7 +117,6 @@ class UopKernel {
// The loop nest structure // The loop nest structure
VerifyDep(dst_index); VerifyDep(dst_index);
VTAUop op; VTAUop op;
op.reset_out = reset_out;
op.dst_idx = dst_index; op.dst_idx = dst_index;
op.src_idx = src_index; op.src_idx = src_index;
op.wgt_idx = wgt_index; op.wgt_idx = wgt_index;
...@@ -128,6 +127,12 @@ class UopKernel { ...@@ -128,6 +127,12 @@ class UopKernel {
} else { } else {
assert(mode_ == mode); assert(mode_ == mode);
} }
// Set reset_out field if unset
if (reset_out_ == 0xFFFFFFFF) {
reset_out_ = reset_out;
} else {
assert(reset_out_ == reset_out);
}
// Check kernel op and imm/imm_val in ALU mode // Check kernel op and imm/imm_val in ALU mode
if (mode == 1) { if (mode == 1) {
if (opcode_ == 0xFFFFFFFF) { if (opcode_ == 0xFFFFFFFF) {
...@@ -146,12 +151,11 @@ class UopKernel { ...@@ -146,12 +151,11 @@ class UopKernel {
uint32_t size = seq_.size(); uint32_t size = seq_.size();
printf("There are %u uops\n", size); printf("There are %u uops\n", size);
for (uint32_t i = 0; i < size; ++i) { for (uint32_t i = 0; i < size; ++i) {
printf("[%04u]\t acc=%u, inp=%u, wgt=%u, reset_out=%u\n", printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n",
i, i,
seq_[i].dst_idx, seq_[i].dst_idx,
seq_[i].src_idx, seq_[i].src_idx,
seq_[i].wgt_idx, seq_[i].wgt_idx);
seq_[i].reset_out);
} }
printf("\n"); printf("\n");
} }
...@@ -160,6 +164,7 @@ class UopKernel { ...@@ -160,6 +164,7 @@ class UopKernel {
// The kernel's mode, opcode, immediate setting and value // The kernel's mode, opcode, immediate setting and value
uint32_t mode_{0xFFFFFFFF}; // UOP type: 0xFFFFFFFF - unset, 0 - GEMM, 1 - ALU uint32_t mode_{0xFFFFFFFF}; // UOP type: 0xFFFFFFFF - unset, 0 - GEMM, 1 - ALU
uint32_t opcode_{0xFFFFFFFF}; uint32_t opcode_{0xFFFFFFFF};
uint32_t reset_out_{0xFFFFFFFF};
bool use_imm_{false}; bool use_imm_{false};
uint16_t imm_val_{0}; uint16_t imm_val_{0};
...@@ -521,20 +526,6 @@ class InsnQueue : public BaseQueue { ...@@ -521,20 +526,6 @@ class InsnQueue : public BaseQueue {
} else { } else {
return "add"; return "add";
} }
} else if (opcode == VTA_ALU_OPCODE_SUB) {
if (use_imm) {
return "sub imm";
} else {
return "sub";
}
} else if (opcode == VTA_ALU_OPCODE_MUL) {
if (use_imm) {
return "mul imm";
} else {
return "mul";
}
} else if (opcode == VTA_ALU_OPCODE_SHL) {
return "shl";
} else if (opcode == VTA_ALU_OPCODE_SHR) { } else if (opcode == VTA_ALU_OPCODE_SHR) {
return "shr"; return "shr";
} }
...@@ -641,6 +632,7 @@ class InsnQueue : public BaseQueue { ...@@ -641,6 +632,7 @@ class InsnQueue : public BaseQueue {
static_cast<int>(c.mem.pop_next_dep), static_cast<int>(c.mem.pop_next_dep),
static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep)); static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.gemm.reset_reg));
printf("\trange (%d, %d)\n", printf("\trange (%d, %d)\n",
static_cast<int>(c.gemm.uop_bgn), static_cast<int>(c.gemm.uop_bgn),
static_cast<int>(c.gemm.uop_end)); static_cast<int>(c.gemm.uop_end));
...@@ -662,6 +654,7 @@ class InsnQueue : public BaseQueue { ...@@ -662,6 +654,7 @@ class InsnQueue : public BaseQueue {
static_cast<int>(c.mem.pop_next_dep), static_cast<int>(c.mem.pop_next_dep),
static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep)); static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.alu.reset_reg));
printf("\trange (%d, %d)\n", printf("\trange (%d, %d)\n",
static_cast<int>(c.alu.uop_bgn), static_cast<int>(c.alu.uop_bgn),
static_cast<int>(c.alu.uop_end)); static_cast<int>(c.alu.uop_end));
...@@ -1081,6 +1074,7 @@ class CommandQueue { ...@@ -1081,6 +1074,7 @@ class CommandQueue {
} }
VTAGemInsn* insn = insn_queue_.CreateGemInsn(); VTAGemInsn* insn = insn_queue_.CreateGemInsn();
insn->opcode = VTA_OPCODE_GEMM; insn->opcode = VTA_OPCODE_GEMM;
insn->reset_reg = kernel->reset_out_;
insn->uop_bgn = kernel->sram_begin_; insn->uop_bgn = kernel->sram_begin_;
insn->uop_end = kernel->sram_end_; insn->uop_end = kernel->sram_end_;
const std::vector<UopKernel::LoopEntry> &loop = kernel->loop(); const std::vector<UopKernel::LoopEntry> &loop = kernel->loop();
...@@ -1119,6 +1113,7 @@ class CommandQueue { ...@@ -1119,6 +1113,7 @@ class CommandQueue {
} }
VTAAluInsn* insn = insn_queue_.CreateAluInsn(); VTAAluInsn* insn = insn_queue_.CreateAluInsn();
insn->opcode = VTA_OPCODE_ALU; insn->opcode = VTA_OPCODE_ALU;
insn->reset_reg = kernel->reset_out_;
insn->uop_bgn = kernel->sram_begin_; insn->uop_bgn = kernel->sram_begin_;
insn->uop_end = kernel->sram_end_; insn->uop_end = kernel->sram_end_;
insn->alu_opcode = kernel->opcode_; insn->alu_opcode = kernel->opcode_;
......
...@@ -28,20 +28,6 @@ const char* getOpcodeString(int opcode, bool use_imm) { ...@@ -28,20 +28,6 @@ const char* getOpcodeString(int opcode, bool use_imm) {
} else { } else {
return "add"; return "add";
} }
} else if (opcode == VTA_ALU_OPCODE_SUB) {
if (use_imm) {
return "sub imm";
} else {
return "sub";
}
} else if (opcode == VTA_ALU_OPCODE_MUL) {
if (use_imm) {
return "mul imm";
} else {
return "mul";
}
} else if (opcode == VTA_ALU_OPCODE_SHL) {
return "shl";
} else if (opcode == VTA_ALU_OPCODE_SHR) { } else if (opcode == VTA_ALU_OPCODE_SHR) {
return "shr"; return "shr";
} }
...@@ -170,31 +156,6 @@ void freeBuffer(void * buffer) { ...@@ -170,31 +156,6 @@ void freeBuffer(void * buffer) {
#endif #endif
} }
// Builds a LOAD memory instruction whose y_size field is zero and whose
// leading y-padding (y_pad_0) carries the caller's y_size; all other padding
// fields and the DRAM base are zero.
// NOTE(review): presumably this makes the load fill the SRAM region with
// padding values instead of reading DRAM (i.e. a "reset") - confirm against
// the hardware load semantics.
//
// type          - value stored in the instruction's memory_type field
// sram_offset   - value stored in sram_base
// y_size        - value stored in y_pad_0 (NOT in the y_size field)
// x_size        - value stored in x_size
// x_stride      - value stored in x_stride
// pop_*/push_*  - dependence flags copied verbatim into the instruction
//
// Returns the instruction reinterpreted as a VTAGenericInsn through the
// VTAInsn union.
VTAGenericInsn reset2DInsn(int type, int sram_offset, int y_size, int x_size, int x_stride,
    int pop_prev_dep, int pop_next_dep, int push_prev_dep, int push_next_dep) {
  // Zero-initialized memory instruction
  VTAMemInsn mem_insn = {};
  mem_insn.opcode = VTA_OPCODE_LOAD;

  // Dependence flags
  mem_insn.pop_prev_dep = pop_prev_dep;
  mem_insn.pop_next_dep = pop_next_dep;
  mem_insn.push_prev_dep = push_prev_dep;
  mem_insn.push_next_dep = push_next_dep;

  // Target buffer and addressing
  mem_insn.memory_type = type;
  mem_insn.sram_base = sram_offset;
  mem_insn.dram_base = 0;

  // Extent: no rows of payload, y_size rows expressed as leading y padding
  mem_insn.y_size = 0;
  mem_insn.x_size = x_size;
  mem_insn.x_stride = x_stride;
  mem_insn.y_pad_0 = y_size;
  mem_insn.y_pad_1 = 0;
  mem_insn.x_pad_0 = 0;
  mem_insn.x_pad_1 = 0;

  // Reinterpret as a generic instruction
  union VTAInsn packed;
  packed.mem = mem_insn;
  return packed.generic;
}
VTAGenericInsn get2DLoadStoreInsn(int opcode, int type, int sram_offset, int dram_offset, VTAGenericInsn get2DLoadStoreInsn(int opcode, int type, int sram_offset, int dram_offset,
int y_size, int x_size, int x_stride, int y_pad, int x_pad, int pop_prev_dep, int pop_next_dep, int y_size, int x_size, int x_stride, int y_pad, int x_pad, int pop_prev_dep, int pop_next_dep,
int push_prev_dep, int push_next_dep) { int push_prev_dep, int push_next_dep) {
...@@ -258,6 +219,7 @@ VTAGenericInsn getGEMMInsn(int uop_offset, int batch, int in_feat, int out_feat, ...@@ -258,6 +219,7 @@ VTAGenericInsn getGEMMInsn(int uop_offset, int batch, int in_feat, int out_feat,
insn.pop_next_dep = pop_next_dep; insn.pop_next_dep = pop_next_dep;
insn.push_prev_dep = push_prev_dep; insn.push_prev_dep = push_prev_dep;
insn.push_next_dep = push_next_dep; insn.push_next_dep = push_next_dep;
insn.reset_reg = false;
if (!uop_compression) { if (!uop_compression) {
insn.uop_bgn = uop_offset; insn.uop_bgn = uop_offset;
insn.uop_end = uop_offset + batch * in_feat * out_feat; insn.uop_end = uop_offset + batch * in_feat * out_feat;
...@@ -296,6 +258,7 @@ VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bo ...@@ -296,6 +258,7 @@ VTAGenericInsn getALUInsn(int opcode, int vector_size, bool use_imm, int imm, bo
insn.pop_next_dep = pop_next_dep; insn.pop_next_dep = pop_next_dep;
insn.push_prev_dep = push_prev_dep; insn.push_prev_dep = push_prev_dep;
insn.push_next_dep = push_next_dep; insn.push_next_dep = push_next_dep;
insn.reset_reg = false;
if (!uop_compression) { if (!uop_compression) {
insn.uop_bgn = 0; insn.uop_bgn = 0;
insn.uop_end = vector_size; insn.uop_end = vector_size;
...@@ -335,6 +298,7 @@ VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) { ...@@ -335,6 +298,7 @@ VTAGenericInsn getFinishInsn(bool pop_prev, bool pop_next) {
insn.pop_next_dep = pop_next; insn.pop_next_dep = pop_next;
insn.push_prev_dep = 0; insn.push_prev_dep = 0;
insn.push_next_dep = 0; insn.push_next_dep = 0;
insn.reset_reg = false;
insn.uop_bgn = 0; insn.uop_bgn = 0;
insn.uop_end = 0; insn.uop_end = 0;
insn.iter_out = 0; insn.iter_out = 0;
...@@ -364,7 +328,6 @@ VTAUop * getCopyUops(int y_size, int x_size, int uop_compression) { ...@@ -364,7 +328,6 @@ VTAUop * getCopyUops(int y_size, int x_size, int uop_compression) {
int uop_idx = 0; int uop_idx = 0;
for (int i = 0; i < y_size; i++) { for (int i = 0; i < y_size; i++) {
for (int j = 0; j < x_size; j++) { for (int j = 0; j < x_size; j++) {
uop_buf[uop_idx].reset_out = false;
uop_buf[uop_idx].dst_idx = i * x_size + j; uop_buf[uop_idx].dst_idx = i * x_size + j;
uop_buf[uop_idx].src_idx = 0; uop_buf[uop_idx].src_idx = 0;
uop_buf[uop_idx].wgt_idx = 0; uop_buf[uop_idx].wgt_idx = 0;
...@@ -372,7 +335,6 @@ VTAUop * getCopyUops(int y_size, int x_size, int uop_compression) { ...@@ -372,7 +335,6 @@ VTAUop * getCopyUops(int y_size, int x_size, int uop_compression) {
} }
} }
} else { } else {
uop_buf[0].reset_out = false;
uop_buf[0].dst_idx = 1; uop_buf[0].dst_idx = 1;
uop_buf[0].src_idx = 0; uop_buf[0].src_idx = 0;
uop_buf[0].wgt_idx = 0; uop_buf[0].wgt_idx = 0;
...@@ -399,7 +361,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression, ...@@ -399,7 +361,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
for (int i = 0; i < batch; i++) { for (int i = 0; i < batch; i++) {
for (int j = 0; j < in_feat; j++) { for (int j = 0; j < in_feat; j++) {
for (int k = 0; k < out_feat; k++) { for (int k = 0; k < out_feat; k++) {
uop_buf[uop_idx].reset_out = false;
uop_buf[uop_idx].dst_idx = i * out_feat + k; uop_buf[uop_idx].dst_idx = i * out_feat + k;
uop_buf[uop_idx].src_idx = i * in_feat + j; uop_buf[uop_idx].src_idx = i * in_feat + j;
uop_buf[uop_idx].wgt_idx = k * in_feat + j; uop_buf[uop_idx].wgt_idx = k * in_feat + j;
...@@ -409,7 +370,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression, ...@@ -409,7 +370,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
} }
} else { } else {
for (int i = 0; i < batch; i++) { for (int i = 0; i < batch; i++) {
uop_buf[i].reset_out = false;
uop_buf[i].dst_idx = i * out_feat; uop_buf[i].dst_idx = i * out_feat;
uop_buf[i].src_idx = i * in_feat; uop_buf[i].src_idx = i * in_feat;
uop_buf[i].wgt_idx = 0; uop_buf[i].wgt_idx = 0;
...@@ -422,7 +382,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression, ...@@ -422,7 +382,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
for (int i = 0; i < batch; i++) { for (int i = 0; i < batch; i++) {
for (int j = 0; j < in_feat; j++) { for (int j = 0; j < in_feat; j++) {
for (int k = 0; k < out_feat; k++) { for (int k = 0; k < out_feat; k++) {
uop_buf[uop_idx].reset_out = false;
uop_buf[uop_idx].dst_idx = i * out_feat + k; uop_buf[uop_idx].dst_idx = i * out_feat + k;
uop_buf[uop_idx].src_idx = batch * in_feat + i * in_feat + j; uop_buf[uop_idx].src_idx = batch * in_feat + i * in_feat + j;
uop_buf[uop_idx].wgt_idx = out_feat * in_feat + k * in_feat + j; uop_buf[uop_idx].wgt_idx = out_feat * in_feat + k * in_feat + j;
...@@ -432,7 +391,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression, ...@@ -432,7 +391,6 @@ VTAUop * getGEMMUops(int batch, int in_feat, int out_feat, bool uop_compression,
} }
} else { } else {
for (int i = 0; i < batch; i++) { for (int i = 0; i < batch; i++) {
uop_buf[batch+i].reset_out = false;
uop_buf[batch+i].dst_idx = i * out_feat; uop_buf[batch+i].dst_idx = i * out_feat;
uop_buf[batch+i].src_idx = batch * in_feat + i * in_feat; uop_buf[batch+i].src_idx = batch * in_feat + i * in_feat;
uop_buf[batch+i].wgt_idx = out_feat * in_feat; uop_buf[batch+i].wgt_idx = out_feat * in_feat;
...@@ -456,12 +414,10 @@ VTAUop * getMapALUUops(int vector_size, bool uop_compression) { ...@@ -456,12 +414,10 @@ VTAUop * getMapALUUops(int vector_size, bool uop_compression) {
if (!uop_compression) { if (!uop_compression) {
for (int i = 0; i < vector_size; i++) { for (int i = 0; i < vector_size; i++) {
uop_buf[i].reset_out = 0;
uop_buf[i].dst_idx = i; uop_buf[i].dst_idx = i;
uop_buf[i].src_idx = vector_size + i; uop_buf[i].src_idx = vector_size + i;
} }
} else { } else {
uop_buf[0].reset_out = 0;
uop_buf[0].dst_idx = 0; uop_buf[0].dst_idx = 0;
uop_buf[0].src_idx = vector_size; uop_buf[0].src_idx = vector_size;
} }
...@@ -511,7 +467,7 @@ void printParameters() { ...@@ -511,7 +467,7 @@ void printParameters() {
printf("VTA_INSN_GEM_2 [%d]\n", VTA_INSN_GEM_2); printf("VTA_INSN_GEM_2 [%d]\n", VTA_INSN_GEM_2);
printf("VTA_INSN_GEM_3 [%d]\n", VTA_INSN_GEM_3); printf("VTA_INSN_GEM_3 [%d]\n", VTA_INSN_GEM_3);
printf("VTA_INSN_GEM_4 [%d]\n", VTA_INSN_GEM_4); printf("VTA_INSN_GEM_4 [%d]\n", VTA_INSN_GEM_4);
printf("VTA_INSN_GEM_5 [%d-%d]\n", VTA_INSN_GEM_5_0, VTA_INSN_GEM_5_1); printf("VTA_INSN_GEM_5 [%d]\n", VTA_INSN_GEM_5);
printf("VTA_INSN_GEM_6 [%d-%d]\n", VTA_INSN_GEM_6_0, VTA_INSN_GEM_6_1); printf("VTA_INSN_GEM_6 [%d-%d]\n", VTA_INSN_GEM_6_0, VTA_INSN_GEM_6_1);
printf("VTA_INSN_GEM_7 [%d-%d]\n", VTA_INSN_GEM_7_0, VTA_INSN_GEM_7_1); printf("VTA_INSN_GEM_7 [%d-%d]\n", VTA_INSN_GEM_7_0, VTA_INSN_GEM_7_1);
printf("VTA_INSN_GEM_8 [%d-%d]\n", VTA_INSN_GEM_8_0, VTA_INSN_GEM_8_1); printf("VTA_INSN_GEM_8 [%d-%d]\n", VTA_INSN_GEM_8_0, VTA_INSN_GEM_8_1);
...@@ -521,17 +477,15 @@ void printParameters() { ...@@ -521,17 +477,15 @@ void printParameters() {
printf("VTA_INSN_GEM_C [%d-%d]\n", VTA_INSN_GEM_C_0, VTA_INSN_GEM_C_1); printf("VTA_INSN_GEM_C [%d-%d]\n", VTA_INSN_GEM_C_0, VTA_INSN_GEM_C_1);
printf("VTA_INSN_GEM_D [%d-%d]\n", VTA_INSN_GEM_D_0, VTA_INSN_GEM_D_1); printf("VTA_INSN_GEM_D [%d-%d]\n", VTA_INSN_GEM_D_0, VTA_INSN_GEM_D_1);
printf("VTA_INSN_GEM_E [%d-%d]\n", VTA_INSN_GEM_E_0, VTA_INSN_GEM_E_1); printf("VTA_INSN_GEM_E [%d-%d]\n", VTA_INSN_GEM_E_0, VTA_INSN_GEM_E_1);
printf("VTA_INSN_ALU_D [%d-%d]\n", VTA_INSN_ALU_D_0, VTA_INSN_ALU_D_1); printf("VTA_INSN_GEM_F [%d-%d]\n", VTA_INSN_GEM_F_0, VTA_INSN_GEM_F_1);
printf("VTA_INSN_ALU_E [%d]\n", VTA_INSN_ALU_E); printf("VTA_INSN_ALU_E [%d-%d]\n", VTA_INSN_ALU_E_0, VTA_INSN_ALU_E_1);
printf("VTA_INSN_ALU_F [%d-%d]\n", VTA_INSN_ALU_F_0, VTA_INSN_ALU_F_1); printf("VTA_INSN_ALU_F [%d]\n", VTA_INSN_ALU_F);
printf("VTA_UOP_GEM_0 [%d]\n", VTA_UOP_GEM_0); printf("VTA_INSN_ALU_G [%d-%d]\n", VTA_INSN_ALU_G_0, VTA_INSN_ALU_G_1);
printf("VTA_UOP_GEM_0 [%d-%d]\n", VTA_UOP_GEM_0_0, VTA_UOP_GEM_0_1);
printf("VTA_UOP_GEM_1 [%d-%d]\n", VTA_UOP_GEM_1_0, VTA_UOP_GEM_1_1); printf("VTA_UOP_GEM_1 [%d-%d]\n", VTA_UOP_GEM_1_0, VTA_UOP_GEM_1_1);
printf("VTA_UOP_GEM_2 [%d-%d]\n", VTA_UOP_GEM_2_0, VTA_UOP_GEM_2_1); printf("VTA_UOP_GEM_2 [%d-%d]\n", VTA_UOP_GEM_2_0, VTA_UOP_GEM_2_1);
printf("VTA_UOP_GEM_3 [%d-%d]\n", VTA_UOP_GEM_3_0, VTA_UOP_GEM_3_1); printf("VTA_UOP_ALU_0 [%d-%d]\n", VTA_UOP_ALU_0_0, VTA_UOP_ALU_0_1);
printf("VTA_UOP_ALU_0 [%d]\n", VTA_UOP_ALU_0);
printf("VTA_UOP_ALU_1 [%d-%d]\n", VTA_UOP_ALU_1_0, VTA_UOP_ALU_1_1); printf("VTA_UOP_ALU_1 [%d-%d]\n", VTA_UOP_ALU_1_0, VTA_UOP_ALU_1_1);
printf("VTA_UOP_ALU_2 [%d-%d]\n", VTA_UOP_ALU_2_0, VTA_UOP_ALU_2_1);
printf("VTA_UOP_ALU_3 [%d-%d]\n", VTA_UOP_ALU_3_0, VTA_UOP_ALU_3_1);
} }
void printInstruction(int num_insn, VTAGenericInsn *insns) { void printInstruction(int num_insn, VTAGenericInsn *insns) {
...@@ -601,6 +555,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) { ...@@ -601,6 +555,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) {
printf("\trange (%d, %d)\n", printf("\trange (%d, %d)\n",
static_cast<int>(c.gemm.uop_bgn), static_cast<int>(c.gemm.uop_bgn),
static_cast<int>(c.gemm.uop_end)); static_cast<int>(c.gemm.uop_end));
printf("\treset_out: %d\n", static_cast<int>(c.gemm.reset_reg));
printf("\touter loop - iter: %d, acc: %d, inp: %d, wgt: %d\n", printf("\touter loop - iter: %d, acc: %d, inp: %d, wgt: %d\n",
static_cast<int>(c.gemm.iter_out), static_cast<int>(c.gemm.iter_out),
static_cast<int>(c.gemm.dst_factor_out), static_cast<int>(c.gemm.dst_factor_out),
...@@ -634,6 +589,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) { ...@@ -634,6 +589,7 @@ void printInstruction(int num_insn, VTAGenericInsn *insns) {
static_cast<int>(c.mem.pop_next_dep), static_cast<int>(c.mem.pop_next_dep),
static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_prev_dep),
static_cast<int>(c.mem.push_next_dep)); static_cast<int>(c.mem.push_next_dep));
printf("\treset_out: %d\n", static_cast<int>(c.alu.reset_reg));
printf("\trange (%d, %d)\n", printf("\trange (%d, %d)\n",
static_cast<int>(c.alu.uop_bgn), static_cast<int>(c.alu.uop_bgn),
static_cast<int>(c.alu.uop_end)); static_cast<int>(c.alu.uop_end));
...@@ -662,8 +618,7 @@ void printMicroOp(int num_uop, VTAUop *uops) { ...@@ -662,8 +618,7 @@ void printMicroOp(int num_uop, VTAUop *uops) {
for (int i = 0; i < num_uop; i++) { for (int i = 0; i < num_uop; i++) {
// Read micro-op // Read micro-op
printf("DEBUG - UOP %u: ", i); printf("DEBUG - UOP %u: ", i);
printf("rst_out=%u, acc=%u, inp= %u, wgt=%u\n", uops[i].reset_out, uops[i].dst_idx, printf("acc=%u, inp= %u, wgt=%u\n", uops[i].dst_idx, uops[i].src_idx, uops[i].wgt_idx);
uops[i].src_idx, uops[i].wgt_idx);
} }
} }
...@@ -671,7 +626,6 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp ...@@ -671,7 +626,6 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
// Some assertions // Some assertions
assert(batch % VTA_BATCH == 0); assert(batch % VTA_BATCH == 0);
assert(vector_size % VTA_BLOCK_OUT == 0); assert(vector_size % VTA_BLOCK_OUT == 0);
assert(!(opcode == VTA_ALU_OPCODE_SHL && !use_imm));
assert(!(opcode == VTA_ALU_OPCODE_SHR && !use_imm)); assert(!(opcode == VTA_ALU_OPCODE_SHR && !use_imm));
printf("=====================================================================================\n"); printf("=====================================================================================\n");
printf("INFO - ALU test of %s: batch=%d, vector_size=%d, uop_compression=%d\n", printf("INFO - ALU test of %s: batch=%d, vector_size=%d, uop_compression=%d\n",
...@@ -701,16 +655,9 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp ...@@ -701,16 +655,9 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
} else if (opcode == VTA_ALU_OPCODE_ADD) { } else if (opcode == VTA_ALU_OPCODE_ADD) {
immediate[b] = static_cast<acc_T>( immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1))); rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_SUB) {
immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_MUL) {
immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_SHL) {
immediate[b] = static_cast<acc_T>(rand_r(&globalSeed) % (VTA_INP_WIDTH + 1));
} else if (opcode == VTA_ALU_OPCODE_SHR) { } else if (opcode == VTA_ALU_OPCODE_SHR) {
immediate[b] = static_cast<acc_T>(rand_r(&globalSeed) % (VTA_INP_WIDTH + 1)); immediate[b] = static_cast<acc_T>(
rand_r(&globalSeed) % VTA_ACC_WIDTH - VTA_ACC_WIDTH/2);
} }
} }
...@@ -783,18 +730,6 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp ...@@ -783,18 +730,6 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
} else if (opcode == VTA_ALU_OPCODE_ADD) { } else if (opcode == VTA_ALU_OPCODE_ADD) {
inputs[i][j] = static_cast<acc_T>( inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2))); rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} else if (opcode == VTA_ALU_OPCODE_SUB) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} else if (opcode == VTA_ALU_OPCODE_MUL) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH / 2)) - (1LL << (VTA_INP_WIDTH / 2 - 1)));
} else if (opcode == VTA_ALU_OPCODE_SHL) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} else if (opcode == VTA_ALU_OPCODE_SHR) {
inputs[i][j] = static_cast<acc_T>(
rand_r(&globalSeed) % (1LL << (VTA_INP_WIDTH - 1)) - (1LL << (VTA_INP_WIDTH - 2)));
} }
} }
} }
...@@ -830,22 +765,12 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp ...@@ -830,22 +765,12 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
} else { } else {
tmp = inputs[i][j] + immediate[i / VTA_BATCH]; tmp = inputs[i][j] + immediate[i / VTA_BATCH];
} }
} else if (opcode == VTA_ALU_OPCODE_SUB) {
if (!use_imm) {
tmp = inputs[i][j] - inputs[i][j + vector_size];
} else {
tmp = inputs[i][j] - immediate[i / VTA_BATCH];
}
} else if (opcode == VTA_ALU_OPCODE_MUL) {
if (!use_imm) {
tmp = inputs[i][j] * inputs[i][j + vector_size];
} else {
tmp = inputs[i][j] * immediate[i / VTA_BATCH];
}
} else if (opcode == VTA_ALU_OPCODE_SHL) {
tmp = inputs[i][j] << immediate[i / VTA_BATCH];
} else if (opcode == VTA_ALU_OPCODE_SHR) { } else if (opcode == VTA_ALU_OPCODE_SHR) {
if (immediate[i / VTA_BATCH] >= 0) {
tmp = inputs[i][j] >> immediate[i / VTA_BATCH]; tmp = inputs[i][j] >> immediate[i / VTA_BATCH];
} else {
tmp = inputs[i][j] << (0 - immediate[i / VTA_BATCH]);
}
} }
// Set // Set
outputs_ref[i][j] = (out_T) tmp; outputs_ref[i][j] = (out_T) tmp;
...@@ -876,7 +801,7 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp ...@@ -876,7 +801,7 @@ int alu_test(int opcode, bool use_imm, int batch, int vector_size, bool uop_comp
(volatile inp_vec_T *) NULL, (volatile inp_vec_T *) NULL,
(volatile wgt_vec_T *) NULL, (volatile wgt_vec_T *) NULL,
(volatile acc_vec_T *) bias_buf, (volatile acc_vec_T *) bias_buf,
(volatile inp_vec_T *) output_buf); (volatile out_vec_T *) output_buf);
#endif #endif
// Unpack output buffer // Unpack output buffer
...@@ -1149,7 +1074,7 @@ int blocked_gemm_test(int batch, int channels, int block, bool uop_compression, ...@@ -1149,7 +1074,7 @@ int blocked_gemm_test(int batch, int channels, int block, bool uop_compression,
(volatile inp_vec_T *) input_buf, (volatile inp_vec_T *) input_buf,
(volatile wgt_vec_T *) weight_buf, (volatile wgt_vec_T *) weight_buf,
(volatile acc_vec_T *) bias_buf, (volatile acc_vec_T *) bias_buf,
(volatile inp_vec_T *) output_buf); (volatile out_vec_T *) output_buf);
#endif #endif
// Unpack output data // Unpack output data
......
...@@ -16,18 +16,20 @@ inp_dtype = "int%d" % vta.VTA_INP_WIDTH ...@@ -16,18 +16,20 @@ inp_dtype = "int%d" % vta.VTA_INP_WIDTH
wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH wgt_dtype = "int%d" % vta.VTA_WGT_WIDTH
Workload = namedtuple("Conv2DWorkload", Workload = namedtuple("Conv2DWorkload",
['height', 'width', 'in_filter', 'out_filter', ['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
class Conv2DSchedule(object): class Conv2DSchedule(object):
def __init__(self, def __init__(self,
oc_factor, b_factor=1,
oc_factor=1,
ko_factor=1, ko_factor=1,
h_factor=1, h_factor=1,
w_factor=0, w_factor=0,
oc_nthread=0, oc_nthread=0,
h_nthread=0, h_nthread=0,
debug_sync=False): debug_sync=False):
self.b_factor = b_factor
self.oc_factor = oc_factor self.oc_factor = oc_factor
self.ko_factor = ko_factor self.ko_factor = ko_factor
self.h_factor = h_factor self.h_factor = h_factor
...@@ -35,10 +37,140 @@ class Conv2DSchedule(object): ...@@ -35,10 +37,140 @@ class Conv2DSchedule(object):
self.oc_nthread = oc_nthread self.oc_nthread = oc_nthread
self.h_nthread = h_nthread self.h_nthread = h_nthread
self.debug_sync = debug_sync self.debug_sync = debug_sync
def __str__(self):
return "{}.{}.{}.{}.{}.{}.{}".format(
self.b_factor, self.oc_factor, self.ko_factor,
self.h_factor, self.w_factor,
self.oc_nthread, self.h_nthread)
Schedule = Conv2DSchedule Schedule = Conv2DSchedule
def get_insn_count(layer, sched):
    """Estimate how many instructions a conv2d schedule issues.

    Parameters
    ----------
    layer : object with ``height``, ``width``, ``in_filter`` and
        ``out_filter`` attributes (e.g. the ``Workload`` namedtuple).
    sched : sequence of five values ``(b, h, w, ci, co)``: batch factor,
        height tile, width tile, input-channel tile and output-channel tile.

    Returns
    -------
    int
        Input transfers + weight transfers + 4 instructions per output
        transfer + a fixed offset of 5 + ``co_factor``.
    """
    b, h, w, ci, co = sched
    b_factor = b
    # Floor division keeps the factors integral regardless of whether the
    # interpreter uses true division for `/`.
    # NOTE(review): unlike get_data_movementB, the spatial extents here are
    # not divided by the layer stride - confirm this is intended.
    h_factor = layer.height // h
    w_factor = layer.width // w
    ci_factor = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN)))
    co_factor = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT)))
    # Number of tile transfers of each kind
    spatial_tiles = b_factor * h_factor * w_factor * co_factor
    input_xfers = spatial_tiles * ci_factor
    weight_xfers = spatial_tiles * ci_factor
    output_xfers = spatial_tiles
    # compute instruction count
    # output transfer factor: 4 (INIT, GEMM, ALU, STORE)
    # offset: 5 (3 uop kernels, 1 initial dep push, 1 finish), plus co_factor
    insn_count = input_xfers + weight_xfers + (output_xfers * 4) + 5 + co_factor
    return insn_count
def find_schedules(layer, mtOnly=False, bestOnly=False):
    """Enumerate conv2d schedules for ``layer`` that pass the capacity checks.

    Tiling factors are enumerated for batch, height, width, input channels
    and output channels; each tiling is then combined with a threading
    factor of 1 or 2 on the height axis or on the output-channel axis
    (never both) and kept only if it passes the buffer, instruction-count
    and minimum-tile-size checks below.

    Parameters
    ----------
    layer : workload with ``batch``, ``height``, ``width``, ``in_filter``,
        ``out_filter``, ``hkernel``, ``wkernel``, ``hstride`` and ``wstride``.
    mtOnly : bool
        When true, keep only schedules that use a threading factor of 2.
    bestOnly : bool
        When true, return only the schedule with the smallest value of
        ``get_data_movementB`` (an empty list if nothing fits).

    Returns
    -------
    list of Schedule
    """
    # Helper function to get factors
    def find_factors(n):
        return [i for i in range(1, n + 1) if n % i == 0]

    # Output spatial extent; floor division keeps these usable with range()
    out_height = layer.height // layer.hstride
    out_width = layer.width // layer.wstride

    # Scheduling exploration
    batch_factors = find_factors(int(np.ceil(float(layer.batch) / vta.VTA_BATCH)))
    height_factors = find_factors(out_height)
    width_factors = find_factors(out_width)
    cin_factors = find_factors(int(np.ceil(float(layer.in_filter) / vta.VTA_BLOCK_IN)))
    cout_factors = find_factors(int(np.ceil(float(layer.out_filter) / vta.VTA_BLOCK_OUT)))
    ht_factors = [1, 2]
    cot_factors = [1, 2]

    # Explore schedules
    schedules = []
    for b in batch_factors:
        for h in height_factors:
            for w in width_factors:
                for ci in cin_factors:
                    for co in cout_factors:
                        # FIXME: filter because of 2D load
                        # (full-width tiles, or a single output-channel block,
                        # and a single input-channel block)
                        if (w == out_width or co == 1) and ci == 1:
                            schedules.append([b, h, w, ci, co])

    # Filter the schedules that wouldn't work in the available BRAM sizes
    input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH
    weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH
    output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH
    input_brams_capacity_b = vta.VTA_INP_BUFF_SIZE * 8
    weight_brams_capacity_b = vta.VTA_WGT_BUFF_SIZE * 8
    output_brams_capacity_b = vta.VTA_OUT_BUFF_SIZE * 8
    fil_sched = []
    xfer_size = []
    for sched in schedules:
        b, h, w, ci, co = sched
        # Depends only on the un-threaded tiling, so compute it once
        insn_count = get_insn_count(layer, sched)
        for ht in ht_factors:
            for cot in cot_factors:
                # Never thread on both axes at once
                if ht == 2 and cot == 2:
                    continue
                # Threading factors must evenly divide the tile they split
                if h % ht != 0 or co % cot != 0:
                    continue
                # If in multi-threaded mode, only allow for mt configs
                if mtOnly and ht == 1 and cot == 1:
                    continue
                # Per-thread tile sizes. Kept in fresh locals so that one
                # thread configuration does not shrink h/co for the next.
                h_t = h // ht
                co_t = co // cot
                nthreads = cot * ht
                input_tile_elems = b * \
                    ((h_t - 1) * layer.hstride + layer.hkernel) * \
                    ((w - 1) * layer.wstride + layer.wkernel) * ci
                weight_tile_elems = layer.hkernel * layer.wkernel * ci * co_t
                output_tile_elems = b * h_t * w * co_t
                # 1. Check input capacity
                # 2. Check weight capacity
                # 3. Check output capacity
                # 4. Check instruction capacity
                # 5. Make sure that we don't write to the same acc location
                #    within 2 consecutive cycles
                if input_tile_elems * input_elem_size_b <= input_brams_capacity_b // nthreads and \
                   weight_tile_elems * weight_elem_size_b <= weight_brams_capacity_b and \
                   output_tile_elems * output_elem_size_b <= output_brams_capacity_b // nthreads and \
                   insn_count <= vta.VTA_MAX_XFER // 16 and \
                   h_t > 2 and w > 2:
                    # Record the batch factor too, so downstream consumers
                    # (e.g. get_data_movementB) see the tiling that was checked.
                    schedule = Schedule(b_factor=b, oc_factor=co_t, ko_factor=ci,
                                        h_factor=h_t, w_factor=w,
                                        oc_nthread=cot, h_nthread=ht)
                    fil_sched.append(schedule)
                    xfer_size.append(get_data_movementB(schedule, layer))

    if bestOnly:
        # Nothing fits: return an empty list instead of failing in min()
        if not fil_sched:
            return []
        return [fil_sched[xfer_size.index(min(xfer_size))]]
    return fil_sched
def get_data_movementB(sched, layer):
    """Estimate the total bytes moved for ``layer`` under schedule ``sched``.

    Parameters
    ----------
    sched : schedule exposing ``b_factor``, ``h_factor``, ``w_factor``,
        ``ko_factor`` (input-channel tile) and ``oc_factor``
        (output-channel tile).
    layer : workload with ``batch``, ``height``, ``width``, ``in_filter``,
        ``out_filter``, ``hkernel``, ``wkernel``, ``hstride`` and ``wstride``.

    Returns
    -------
    int
        Sum of input, weight and output transfer sizes, in bytes.
    """
    # Size in bits of one input / weight / output element block
    input_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_IN * vta.VTA_INP_WIDTH
    weight_elem_size_b = vta.VTA_BLOCK_IN * vta.VTA_BLOCK_OUT * vta.VTA_WGT_WIDTH
    output_elem_size_b = vta.VTA_BATCH * vta.VTA_BLOCK_OUT * vta.VTA_OUT_WIDTH
    b = sched.b_factor
    h = sched.h_factor
    w = sched.w_factor
    ci = sched.ko_factor
    co = sched.oc_factor
    # Elements per tile
    input_tile_elems = b * \
        ((h - 1) * layer.hstride + layer.hkernel) * \
        ((w - 1) * layer.wstride + layer.wkernel) * ci
    # NOTE(review): find_schedules sizes the weight tile as
    # hkernel * wkernel * ci * co; the `co` term is absent here - confirm
    # whether that is intentional before relying on the weight figure.
    weight_tile_elems = layer.hkernel * layer.wkernel * ci
    output_tile_elems = b * h * w * co
    # Derive factors (floor division keeps them integral under true division)
    b_factor = int(np.ceil(float(layer.batch) / (b * vta.VTA_BATCH)))
    h_factor = (layer.height // layer.hstride) // h
    w_factor = (layer.width // layer.wstride) // w
    ci_factor = int(np.ceil(float(layer.in_filter) / (ci * vta.VTA_BLOCK_IN)))
    co_factor = int(np.ceil(float(layer.out_filter) / (co * vta.VTA_BLOCK_OUT)))
    # Derive transfers
    input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
    weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor
    output_xfers = b_factor * h_factor * w_factor * co_factor
    # Compute total transfer sizes (bits -> bytes)
    input_xfer_B = input_tile_elems * input_xfers * input_elem_size_b // 8
    weight_xfer_B = weight_tile_elems * weight_xfers * weight_elem_size_b // 8
    output_xfer_B = output_tile_elems * output_xfers * output_elem_size_b // 8
    total_xfer_B = input_xfer_B + weight_xfer_B + output_xfer_B
    return total_xfer_B
def test_conv2d_chwv(layer, key, batch_size, wl, sched, log_frame, profile=True):
assert batch_size % vta.VTA_BATCH == 0 assert batch_size % vta.VTA_BATCH == 0
assert wl.in_filter % vta.VTA_BLOCK_IN == 0 assert wl.in_filter % vta.VTA_BLOCK_IN == 0
assert wl.out_filter % vta.VTA_BLOCK_OUT == 0 assert wl.out_filter % vta.VTA_BLOCK_OUT == 0
...@@ -71,6 +203,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -71,6 +203,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf") res_shf = tvm.compute(res_shape, lambda *i: res_cnv(*i) >> 8, name="res_shf")
res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res") res = tvm.compute(res_shape, lambda *i: res_shf(*i).astype(inp_dtype), name="res")
num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter num_ops = batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
total_xfer_B = get_data_movementB(sched, wl)
def verify(s, check_correctness): def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d") mod = tvm.build(s, [data, kernel, res], "ext_dev", target, name="conv2d")
...@@ -98,7 +231,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -98,7 +231,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
data_arr = tvm.nd.array(data_packed, ctx) data_arr = tvm.nd.array(data_packed, ctx)
kernel_arr = tvm.nd.array(kernel_packed, ctx) kernel_arr = tvm.nd.array(kernel_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx) res_arr = tvm.nd.array(res_np, ctx)
time_f = f.time_evaluator("conv2d", ctx, number=10) time_f = f.time_evaluator("conv2d", ctx, number=1)
cost = time_f(data_arr, kernel_arr, res_arr) cost = time_f(data_arr, kernel_arr, res_arr)
res_unpack = res_arr.asnumpy().transpose( res_unpack = res_arr.asnumpy().transpose(
(0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
...@@ -125,10 +258,10 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -125,10 +258,10 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
s[res_cnv].set_scope(vta.SCOPE_OUT) s[res_cnv].set_scope(vta.SCOPE_OUT)
s[res_shf].set_scope(vta.SCOPE_OUT) s[res_shf].set_scope(vta.SCOPE_OUT)
# tile # tile
oc_factor = (plan.oc_factor if plan.oc_factor oc_factor = (sched.oc_factor if sched.oc_factor
else wl.out_filter // vta.VTA_BLOCK_OUT) else wl.out_filter // vta.VTA_BLOCK_OUT)
h_factor = (plan.h_factor if plan.h_factor else fout_height) h_factor = (sched.h_factor if sched.h_factor else fout_height)
w_factor = (plan.w_factor if plan.w_factor else fout_width) w_factor = (sched.w_factor if sched.w_factor else fout_width)
xbo, xco, xi, xj, xbi, xci = s[res].op.axis xbo, xco, xi, xj, xbi, xci = s[res].op.axis
xco0, xco1 = s[res].split(xco, factor=oc_factor) xco0, xco1 = s[res].split(xco, factor=oc_factor)
xi0, xi1 = s[res].split(xi, factor=h_factor) xi0, xi1 = s[res].split(xi, factor=h_factor)
...@@ -137,21 +270,21 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -137,21 +270,21 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
s[res_cnv].compute_at(s[res], xj0) s[res_cnv].compute_at(s[res], xj0)
s[res_shf].compute_at(s[res], xj0) s[res_shf].compute_at(s[res], xj0)
if plan.oc_nthread: if sched.oc_nthread > 1:
_, tx = s[res].split(xco0, factor=plan.oc_nthread) _, tx = s[res].split(xco0, factor=sched.oc_nthread)
s[res].reorder(tx, xbo) s[res].reorder(tx, xbo)
s[res].bind(tx, tvm.thread_axis("cthread")) s[res].bind(tx, tvm.thread_axis("cthread"))
if plan.h_nthread: if sched.h_nthread > 1:
xo, tx = s[res].split(xi0, factor=plan.h_nthread) xo, tx = s[res].split(xi0, factor=sched.h_nthread)
s[res].reorder(tx, xbo) s[res].reorder(tx, xbo)
s[res].bind(tx, tvm.thread_axis("cthread")) s[res].bind(tx, tvm.thread_axis("cthread"))
xbo, xco, xi, xj, xbi, xci = s[res_cnv].op.axis xbo, xco, xi, xj, xbi, xci = s[res_cnv].op.axis
s[res_cnv].reorder(xbo, ko, xj, dj, di, xco, xi, xbi, xci, ki) s[res_cnv].reorder(xbo, ko, xj, dj, di, xco, xi, xbi, xci, ki)
if plan.ko_factor: if sched.ko_factor:
ko0, ko1 = s[res_cnv].split(ko, factor=plan.ko_factor) ko0, ko1 = s[res_cnv].split(ko, factor=sched.ko_factor)
s[data_buf].compute_at(s[res_cnv], ko0) s[data_buf].compute_at(s[res_cnv], ko0)
s[kernel_buf].compute_at(s[res_cnv], ko0) s[kernel_buf].compute_at(s[res_cnv], ko0)
# Use VTA instructions # Use VTA instructions
...@@ -160,7 +293,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -160,7 +293,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
s[res_cnv].tensorize(xbi, gemm) s[res_cnv].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu) s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res].pragma(xco1, store_out) s[res].pragma(xco1, store_out)
if plan.debug_sync: if sched.debug_sync:
s[res].pragma(xco0, "coproc_sync") s[res].pragma(xco0, "coproc_sync")
if print_ir: if print_ir:
print(tvm.lower(s, [data, kernel, res], simple_mode=True)) print(tvm.lower(s, [data, kernel, res], simple_mode=True))
...@@ -169,6 +302,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -169,6 +302,7 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
def conv_normal(print_ir): def conv_normal(print_ir):
print("----- CONV2D End-to-End Test-------") print("----- CONV2D End-to-End Test-------")
def run_test(header, print_ir, check_correctness): def run_test(header, print_ir, check_correctness):
s = [1, sched.oc_factor, sched.ko_factor, sched.h_factor, sched.w_factor]
cost = run_schedule( cost = run_schedule(
vta.DMA_COPY, vta.DMA_COPY, vta.DMA_COPY, vta.DMA_COPY,
vta.GEMM, vta.ALU, vta.DMA_COPY, vta.GEMM, vta.ALU, vta.DMA_COPY,
...@@ -177,8 +311,27 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -177,8 +311,27 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
print(header) print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops)) print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
log_frame["key"].append(key) log_frame["key"].append(key)
log_frame["layer"].append(layer)
log_frame["total-data"].append(total_xfer_B)
log_frame["total-gops"].append(gops) log_frame["total-gops"].append(gops)
log_frame["total-cost"].append(cost.mean) log_frame["total-cost"].append(cost.mean)
log_frame["total-insn"].append(get_insn_count(wl, s))
log_frame["block-batch"].append(vta.VTA_BATCH)
log_frame["block-in"].append(vta.VTA_BLOCK_IN)
log_frame["block-out"].append(vta.VTA_BLOCK_OUT)
log_frame["inp-width"].append(vta.VTA_INP_WIDTH)
log_frame["wgt-width"].append(vta.VTA_WGT_WIDTH)
log_frame["uop-size"].append(vta.VTA_UOP_BUFF_SIZE)
log_frame["inp-size"].append(vta.VTA_INP_BUFF_SIZE)
log_frame["wgt-size"].append(vta.VTA_WGT_BUFF_SIZE)
log_frame["out-size"].append(vta.VTA_OUT_BUFF_SIZE)
log_frame["oc-factor"].append(sched.oc_factor)
log_frame["ic-factor"].append(sched.ko_factor)
log_frame["h-factor"].append(sched.h_factor)
log_frame["w-factor"].append(sched.w_factor)
log_frame["oc-threads"].append(sched.oc_nthread)
log_frame["h-threads"].append(sched.h_nthread)
log_frame["threaded"].append(True if sched.oc_nthread > 1 or sched.h_nthread > 1 else False)
with tvm.build_config(add_lower_pass=vta.debug_mode(0)): with tvm.build_config(add_lower_pass=vta.debug_mode(0)):
run_test("NORMAL", print_ir, True) run_test("NORMAL", print_ir, True)
...@@ -325,90 +478,63 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True): ...@@ -325,90 +478,63 @@ def test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile=True):
load_wgt_unittest(False) load_wgt_unittest(False)
store_out_unittest(False) store_out_unittest(False)
# Perform profiling
profile = False
# Data set batch size
batch_size = 1
# Use multi-threading for latency hiding
multi_threaded = False
# ResNet18 workloads # ResNet18 workloads
resnet = { resnet = {
# Workloads of resnet18 on imagenet # Workloads of resnet18 on imagenet
0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2), 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1), 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1), 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2), 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2), 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1), 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2), 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2), 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1), 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2), 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2), 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1), 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
} }
# List of simple benchmarks begin = 0
simple = [ end = len(resnet)
Workload(height=22, width=22, in_filter=256, out_filter=64,
hkernel=3, wkernel=3, hpad=1, wpad=1, hstride=1, wstride=1)
]
# Serial schedule
resnet_serial = [
[None, None],
[resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[2], Schedule(oc_factor=4, ko_factor=4, h_factor=8, w_factor=0)],
[resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[4], Schedule(oc_factor=8, ko_factor=1, h_factor=4, w_factor=0)],
[resnet[5], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[6], Schedule(oc_factor=8, ko_factor=1, h_factor=14, w_factor=0)],
[resnet[7], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[8], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[9], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[10], Schedule(oc_factor=16, ko_factor=1, h_factor=7, w_factor=0)],
[resnet[11], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0)],
]
# SMT schedule
resnet_smt = [
[resnet[0], Schedule(oc_factor=1, ko_factor=1, h_factor=4, w_factor=56)],
[resnet[1], Schedule(oc_factor=2, ko_factor=1, h_factor=7, h_nthread=2)],
[resnet[2], Schedule(oc_factor=4, ko_factor=2, h_factor=4, w_factor=0, h_nthread=2)],
[resnet[3], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)],
[resnet[4], Schedule(oc_factor=4, ko_factor=1, h_factor=7, h_nthread=2)],
[resnet[5], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, h_nthread=2)],
[resnet[6], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[7], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[8], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[9], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[10], Schedule(oc_factor=8, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
[resnet[11], Schedule(oc_factor=4, ko_factor=1, h_factor=7, w_factor=0, oc_nthread=2)],
]
# Perform profiling resnet_schedules = []
profile = False for i in range(begin, end):
# Whether use SMT scheds = find_schedules(resnet[i], mtOnly=multi_threaded, bestOnly=True)
use_smt = True resnet_schedules.append([i, scheds])
# Data set batch size
batch_size = 1
resnet_schedule = resnet_smt if use_smt else resnet_serial keys = ["key", "layer", "total-data", "total-gops", "total-cost", "total-insn",
"block-batch", "block-in", "block-out", "wgt-width", "inp-width",
"uop-size", "inp-size", "wgt-size", "out-size",
"oc-factor", "ic-factor", "h-factor", "w-factor",
"oc-threads", "h-threads", "threaded"]
begin = 0 if profile:
end = len(resnet_schedule) keys += ["skip-alu-gops", "skip-alu-cost",
keys = ["key", "total-gops", "total-cost",
"skip-alu-gops", "skip-alu-cost",
"gemm-gops", "gemm-cost", "alu-gops", "alu-cost", "gemm-gops", "gemm-cost", "alu-gops", "alu-cost",
"ld-inp-cost", "ld-wgt-cost", "st-out-cost", "ld-inp-cost", "ld-wgt-cost", "st-out-cost",
"ld-inp-gbits", "ld-wgt-gbits", "st-out-gbits",] "ld-inp-gbits", "ld-wgt-gbits", "st-out-gbits"]
log_frame = { log_frame = {
k : [] for k in keys k : [] for k in keys
} }
for i, x in enumerate(resnet_schedule[begin:end]):
wl, plan = x
if not wl:
continue
key = "resnet-cfg[%d]" % i
test_conv2d_chwv(key, batch_size, wl, plan, log_frame, profile)
if profile: for x in resnet_schedules:
pd.set_option('expand_frame_repr', False) l, plans = x
log_df = pd.DataFrame() for plan in plans:
for k in keys: key = "resnet-cfg[{}-{}]".format(l, plan)
test_conv2d_chwv(l, key, batch_size, resnet[l], plan, log_frame, profile)
pd.set_option('expand_frame_repr', False)
log_df = pd.DataFrame()
for k in keys:
log_df[k] = log_frame[k] log_df[k] = log_frame[k]
print(log_df) print(log_df)
log_df.to_csv("conv2d.csv")
...@@ -6,10 +6,14 @@ from tvm.contrib import rpc, util ...@@ -6,10 +6,14 @@ from tvm.contrib import rpc, util
host = "pynq" host = "pynq"
port = 9091 port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf" target = "llvm -target=armv7-none-linux-gnueabihf"
bit = "vta.bit" bit = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_100MHz_10ns.bit".format(
vta.VTA_BATCH, vta.VTA_BLOCK_IN, vta.VTA_BLOCK_OUT,
vta.VTA_INP_WIDTH, vta.VTA_WGT_WIDTH,
vta.VTA_LOG_UOP_BUFF_SIZE, vta.VTA_LOG_INP_BUFF_SIZE,
vta.VTA_LOG_WGT_BUFF_SIZE, vta.VTA_LOG_OUT_BUFF_SIZE)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
bitstream = os.path.join(curr_path, "./", bit) bitstream = os.path.join(curr_path, "../../../../vta_bitstreams/bitstreams/", bit)
def test_program_rpc(): def test_program_rpc():
assert tvm.module.enabled("rpc") assert tvm.module.enabled("rpc")
......
...@@ -480,8 +480,8 @@ if __name__ == "__main__": ...@@ -480,8 +480,8 @@ if __name__ == "__main__":
test_padded_load() test_padded_load()
print("Load/store test") print("Load/store test")
test_save_load_out() test_save_load_out()
print("GEMM test") # print("GEMM test")
test_gemm() # test_gemm()
print("Max immediate") print("Max immediate")
test_alu(tvm.max, np.maximum, use_imm=True) test_alu(tvm.max, np.maximum, use_imm=True)
print("Max") print("Max")
...@@ -492,6 +492,8 @@ if __name__ == "__main__": ...@@ -492,6 +492,8 @@ if __name__ == "__main__":
test_alu(lambda x, y: x + y) test_alu(lambda x, y: x + y)
print("Shift right immediate") print("Shift right immediate")
test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True) test_alu(lambda x, y: x >> y, np.right_shift, use_imm=True)
print("Shift left immediate")
test_alu(lambda x, y: x << y, np.left_shift, use_imm=True)
print("Relu") print("Relu")
test_relu() test_relu()
# print("Shift and scale") # print("Shift and scale")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment