Commit 5c5806ba by Tianqi Chen

[INFRASTRUCTURE] Migrate to json based config. Move gemm test to integration. (#28)

* Migrate to json based config. Move gemm test to integration.

* temp checkin

* Check in example json
parent 666f32d6
# Absolute path of the directory make was invoked from (the project root).
ROOTDIR = $(CURDIR)

# Pick the build configuration: `make config=path/to/cfg.mk` overrides;
# otherwise a user-local ./config.mk is preferred over the checked-in
# default template in make/config.mk.
ifndef config
ifneq ("$(wildcard ./config.mk)", "")
config = config.mk
else
config = make/config.mk
endif
endif
include $(config)
# Base link and compile flags; exported so any sub-make sees the same values.
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC

# Header search paths for NNVM, TVM and dmlc-core.  NNVM_PATH / TVM_PATH may
# be preset by the user or environment; otherwise they default to the
# in-tree checkout locations.
ifndef NNVM_PATH
NNVM_PATH = $(ROOTDIR)/nnvm
endif
CFLAGS += -I$(NNVM_PATH)/include

ifndef TVM_PATH
TVM_PATH = $(NNVM_PATH)/tvm
endif
CFLAGS += -I$(TVM_PATH)/include -I$(TVM_PATH)/dlpack/include -I$(TVM_PATH)/HalideIR/src

ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include
else
CFLAGS += -I$(NNVM_PATH)/dmlc-core/include
endif

# Extra user-supplied flags from config.mk ("NONE" means nothing to add).
ifneq ($(ADD_CFLAGS), NONE)
CFLAGS += $(ADD_CFLAGS)
endif
ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif
# Helper script that derives compiler/linker flags from the JSON config.
VTA_CONFIG := python make/vta_config.py

# Query the config script once at parse time.  The previous version used
# backquotes inside recursively-expanded variables, which re-ran python in
# every recipe shell / on every variable expansion.
VTA_PKG_CFLAGS := $(shell ${VTA_CONFIG} --cflags)
VTA_PKG_LDFLAGS := $(shell ${VTA_CONFIG} --ldflags)
CFLAGS += $(VTA_PKG_CFLAGS)
LDFLAGS += $(VTA_PKG_LDFLAGS)
VTA_TARGET := $(shell ${VTA_CONFIG} --target)

UNAME_S := $(shell uname -s)
......@@ -53,29 +21,30 @@ else
NO_WHOLE_ARCH= --no-whole-archive
endif
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
ifeq ($(VTA_TARGET), pynq)
VTA_LIB_SRC = $(wildcard src/*.cc)
ifeq (${VTA_TARGET}, pynq)
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
LDFLAGS += -L/usr/lib -lsds_lib
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
LDFLAGS += -l:libdma.so
endif
ifeq ($(VTA_TARGET), sim)
ifeq (${VTA_TARGET}, sim)
VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
all: lib/libvta.so
all: lib/libvta.so lib/libvta.so.json
# Compile a runtime object file; also emit a .d dependency file via -MM so
# header changes trigger rebuilds.
build/%.o: src/%.cc
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
	$(CXX) -c $(CFLAGS) $< -o $@
# (fixed: the compile line previously passed `-c` twice)

# Record the JSON configuration the library was built with next to it, so a
# running server can compare against a requested reconfiguration.
lib/libvta.so.json: lib/libvta.so
	@mkdir -p $(@D)
	${VTA_CONFIG} --cfg-json > $@

lib/libvta.so: $(VTA_LIB_OBJ)
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
......
# VTA Configuration
Each VTA runtime/hardware configuration is specified by a config.json file.
You can copy config.json to the project root and modify the configuration
before you type make.
The config affects the behavior of the Python package as well as
the hardware runtime build.
\ No newline at end of file
{
"TARGET" : "pynq",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_OUT_WIDTH" : 3,
"LOG_BATCH" : 0,
"LOG_BLOCK_IN" : 4,
"LOG_BLOCK_OUT" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 15,
"LOG_ACC_BUFF_SIZE" : 17
}
#-------------------------------------------------------------------------------
# Template configuration for compiling VTA runtime.
#
# If you want to change the configuration, please use the following
# steps. Assume you are in the project root directory. First copy this
# file so that any local changes will be ignored by git
#
# $ cp make/config.mk .
#
# Next modify the according entries, and then compile by
#
# $ make
#
# or build in parallel with 8 threads
#
# $ make -j8
#-------------------------------------------------------------------------------

#---------------------
# choice of compiler
#--------------------

# the additional link flags you want to add
ADD_LDFLAGS=
# the additional compile flags you want to add
ADD_CFLAGS=
# the hardware target, can be [sim, pynq]
VTA_TARGET = pynq

#---------------------
# VTA hardware parameters
#--------------------
# Log of input/activation width in bits (default 3 -> 8 bits)
VTA_LOG_INP_WIDTH = 3
# Log of kernel weight width in bits (default 3 -> 8 bits)
VTA_LOG_WGT_WIDTH = 3
# Log of accum width in bits (default 5 -> 32 bits)
VTA_LOG_ACC_WIDTH = 5
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BATCH = 0
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_IN = 4
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_OUT = 4
# Log of uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = 15
# Log of inp buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = 15
# Log of wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = 15
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17
#---------------------
# Derived VTA hardware parameters
#--------------------
# Everything below is computed from the VTA_LOG_* knobs above.  Simply-
# expanded (:=) assignment runs each $(shell ...) arithmetic exactly once at
# parse time; the previous recursive `=` re-forked a shell on every
# expansion of these variables.
# Input width in bits
VTA_INP_WIDTH := $(shell echo "$$(( 1 << $(VTA_LOG_INP_WIDTH) ))" )
# Weight width in bits
VTA_WGT_WIDTH := $(shell echo "$$(( 1 << $(VTA_LOG_WGT_WIDTH) ))" )
# Log of output width in bits (output width always tracks input width)
VTA_LOG_OUT_WIDTH := $(VTA_LOG_INP_WIDTH)
# Output width in bits
VTA_OUT_WIDTH := $(shell echo "$$(( 1 << $(VTA_LOG_OUT_WIDTH) ))" )
# Tensor batch size
VTA_BATCH := $(shell echo "$$(( 1 << $(VTA_LOG_BATCH) ))" )
# Tensor inner block size (comment previously said "outer" — swapped)
VTA_IN_BLOCK := $(shell echo "$$(( 1 << $(VTA_LOG_BLOCK_IN) ))" )
# Tensor outer block size
VTA_OUT_BLOCK := $(shell echo "$$(( 1 << $(VTA_LOG_BLOCK_OUT) ))" )
# Uop buffer size in Bytes
VTA_UOP_BUFF_SIZE := $(shell echo "$$(( 1 << $(VTA_LOG_UOP_BUFF_SIZE) ))" )
# Inp buffer size in Bytes
VTA_INP_BUFF_SIZE := $(shell echo "$$(( 1 << $(VTA_LOG_INP_BUFF_SIZE) ))" )
# Wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE := $(shell echo "$$(( 1 << $(VTA_LOG_WGT_BUFF_SIZE) ))" )
# Acc buffer size in Bytes
VTA_ACC_BUFF_SIZE := $(shell echo "$$(( 1 << $(VTA_LOG_ACC_BUFF_SIZE) ))" )
# Log of out buffer size in Bytes: acc buffer log size scaled by the
# out/acc width ratio (in log space: + LOG_OUT_WIDTH - LOG_ACC_WIDTH)
VTA_LOG_OUT_BUFF_SIZE := \
$(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_ACC_WIDTH) ))" )
# Out buffer size in Bytes
VTA_OUT_BUFF_SIZE := $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
# Update ADD_CFLAGS: expose every hardware parameter to the C++ runtime as a
# -DVTA_* macro.  The trailing backslash after `+=` is required — without it
# the -D lines below are orphaned and never become part of the assignment
# (fixes a missing line continuation).
ADD_CFLAGS += \
 -DVTA_TARGET=$(VTA_TARGET)\
 -DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
 -DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
 -DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
 -DVTA_LOG_BLOCK_IN=$(VTA_LOG_BLOCK_IN) -DVTA_LOG_BLOCK_OUT=$(VTA_LOG_BLOCK_OUT) \
 -DVTA_LOG_UOP_BUFF_SIZE=$(VTA_LOG_UOP_BUFF_SIZE) -DVTA_LOG_INP_BUFF_SIZE=$(VTA_LOG_INP_BUFF_SIZE) \
 -DVTA_LOG_WGT_BUFF_SIZE=$(VTA_LOG_WGT_BUFF_SIZE) -DVTA_LOG_ACC_BUFF_SIZE=$(VTA_LOG_ACC_BUFF_SIZE) \
 -DVTA_LOG_OUT_BUFF_SIZE=$(VTA_LOG_OUT_BUFF_SIZE)
"""VTA config tool"""
import os
import sys
import json
import argparse
def get_pkg_config(cfg):
    """Construct a PkgConfig object for ``cfg``.

    Loads python/vta/pkg_config.py by file path (keeping this standalone
    tool free of package imports) and instantiates its PkgConfig class
    with the config dict and the project root.
    """
    tool_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    root = os.path.abspath(os.path.join(tool_dir, "../"))
    pkg_config_py = os.path.join(root, "python/vta/pkg_config.py")
    # Execute the module source in an isolated namespace keyed by its path.
    namespace = {"__file__": pkg_config_py}
    with open(pkg_config_py, "rb") as src:
        exec(compile(src.read(), pkg_config_py, "exec"), namespace, namespace)
    return namespace["PkgConfig"](cfg, root)
def main():
    """Main function: CLI entry point that prints pieces of the VTA config."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cflags", action="store_true",
                        help="print the cflags")
    # NOTE(review): --update is accepted but never acted on below — confirm intent.
    parser.add_argument("--update", action="store_true",
                        help="Print out the json option.")
    # NOTE(review): help text says "cflags" but this flag prints ldflags —
    # looks like a copy-paste; confirm before changing the string.
    parser.add_argument("--ldflags", action="store_true",
                        help="print the cflags")
    parser.add_argument("--cfg-json", action="store_true",
                        help="print all the config json")
    parser.add_argument("--target", action="store_true",
                        help="print the target")
    args = parser.parse_args()

    # With no flags at all, show usage instead of silently doing nothing.
    if len(sys.argv) == 1:
        parser.print_help()
        return

    # Locate the config file: a user-local config.json at the project root
    # takes precedence over the checked-in default in make/config.json.
    curr_path = os.path.dirname(
        os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(curr_path, "../"))
    path_list = [
        os.path.join(proj_root, "config.json"),
        os.path.join(proj_root, "make/config.json")
    ]
    ok_path_list = [p for p in path_list if os.path.exists(p)]
    if not ok_path_list:
        raise RuntimeError("Cannot find config in %s" % str(path_list))
    cfg = json.load(open(ok_path_list[0]))
    # Output width is derived, not user-settable: it always tracks input width.
    cfg["LOG_OUT_WIDTH"] = cfg["LOG_INP_WIDTH"]
    pkg = get_pkg_config(cfg)

    if args.target:
        print(pkg.target)
    if args.cflags:
        print(" ".join(pkg.cflags))
    if args.ldflags:
        print(" ".join(pkg.ldflags))
    if args.cfg_json:
        print(pkg.cfg_json)


if __name__ == "__main__":
    main()
"""TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs
import sys
from .environment import get_env, Environment
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
try:
# allow optional import in config mode.
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
from . import graph
except (ImportError, RuntimeError):
pass
......@@ -3,10 +3,11 @@
from __future__ import absolute_import as _abs
import os
import glob
import json
import copy
import tvm
from . import intrin
from .pkg_config import PkgConfig
class DevContext(object):
......@@ -61,45 +62,6 @@ class DevContext(object):
return 1 if self.DEBUG_NO_SYNC else qid
class PkgConfig(object):
"""Simple package config tool for VTA.
This is used to provide runtime specific configurations.
"""
def __init__(self, env):
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
# include path
self.include_path = [
"-I%s/include" % proj_root,
"-I%s/nnvm/tvm/include" % proj_root,
"-I%s/nnvm/tvm/dlpack/include" % proj_root,
"-I%s/nnvm/dmlc-core/include" % proj_root
]
# List of source files that can be used to build standalone library.
self.lib_source = []
self.lib_source += glob.glob("%s/src/*.cc" % proj_root)
self.lib_source += glob.glob("%s/src/%s/*.cc" % (proj_root, env.TARGET))
# macro keys
self.macro_defs = []
for key in env.cfg_keys:
self.macro_defs.append("-DVTA_%s=%s" % (key, str(getattr(env, key))))
if env.TARGET == "pynq":
self.ldflags = [
"-L/usr/lib",
"-lsds_lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libdma.so"]
else:
self.ldflags = []
@property
def cflags(self):
return self.include_path + self.macro_defs
class Environment(object):
"""Hareware configuration object.
......@@ -123,19 +85,6 @@ class Environment(object):
env = vta.get_env()
"""
current = None
cfg_keys = [
"TARGET",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
"LOG_BATCH",
"LOG_BLOCK_IN",
"LOG_BLOCK_OUT",
"LOG_UOP_BUFF_SIZE",
"LOG_INP_BUFF_SIZE",
"LOG_WGT_BUFF_SIZE",
"LOG_ACC_BUFF_SIZE",
]
# constants
MAX_XFER = 1 << 22
# debug flags
......@@ -152,10 +101,9 @@ class Environment(object):
def __init__(self, cfg):
# Log of input/activation width in bits
self.__dict__.update(cfg)
for key in self.cfg_keys:
for key in PkgConfig.cfg_keys:
if key not in cfg:
raise ValueError("Expect key %s in cfg" % key)
self.LOG_OUT_WIDTH = self.LOG_INP_WIDTH
self.LOG_OUT_BUFF_SIZE = (
self.LOG_ACC_BUFF_SIZE +
self.LOG_OUT_WIDTH -
......@@ -195,7 +143,12 @@ class Environment(object):
self.mock_mode = False
self._mock_env = None
self._dev_ctx = None
self._pkg_config = None
def pkg_config(self):
"""PkgConfig instance"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
return PkgConfig(self.__dict__, proj_root)
@property
def dev(self):
......@@ -205,13 +158,6 @@ class Environment(object):
return self._dev_ctx
@property
def pkg_config(self):
"""PkgConfig instance"""
if self._pkg_config is None:
self._pkg_config = PkgConfig(self)
return self._pkg_config
@property
def mock(self):
"""A mock version of the Environment
......@@ -327,30 +273,19 @@ def coproc_dep_pop(op):
def _init_env():
"""Iniitalize the default global env"""
python_vta_dir = os.path.dirname(__file__)
filename = os.path.join(python_vta_dir, '../../config.mk')
keys = set()
for k in Environment.cfg_keys:
keys.add("VTA_" + k)
if not os.path.isfile(filename):
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
path_list = [
os.path.join(curr_path, "config.json"),
os.path.join(proj_root, "config.json"),
os.path.join(proj_root, "make/config.json")
]
path_list = [p for p in path_list if os.path.exists(p)]
if not path_list:
raise RuntimeError(
"Error: {} not found.make sure you have config.mk in your vta root"
"Error: {} not found.make sure you have config.json in your vta root"
.format(filename))
cfg = {}
with open(filename) as f:
for line in f:
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
if k.startswith("VTA_"):
k = k[4:]
try:
cfg[k] = int(val)
except ValueError:
cfg[k] = val
return Environment(cfg)
return Environment(json.load(open(path_list[0])))
Environment.current = _init_env()
......@@ -8,11 +8,13 @@ import logging
import argparse
import os
import ctypes
import json
import tvm
from tvm._ffi.base import c_str
from tvm.contrib import rpc, cc
from ..environment import get_env
from ..pkg_config import PkgConfig
@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
......@@ -21,8 +23,9 @@ def server_start():
# pylint: disable=unused-variable
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath(
os.path.join(curr_path, "../../../lib/libvta.so"))
proj_root = os.path.abspath(os.path.join(curr_path, "../../.."))
dll_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))
cfg_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so.json"))
runtime_dll = []
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")
......@@ -56,26 +59,36 @@ def server_start():
runtime_dll.pop()
@tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True)
def reconfig_runtime(cflags):
def reconfig_runtime(cfg_json):
"""Rebuild and reload runtime with new configuration.
Parameters
----------
cfg_json : str
JSON string used for configurations.
JSON string used for configurations.
"""
if runtime_dll:
raise RuntimeError("Can only reconfig in the beginning of session...")
cflags = cflags.split()
env = get_env()
cfg = json.loads(cfg_json)
cfg["TARGET"] = env.TARGET
pkg = PkgConfig(cfg, proj_root)
# check if the configuration is already the same
if os.path.isfile(cfg_path):
old_cfg = json.load(open(cfg_path))
if pkg.same_config(old_cfg):
logging.info("Skip reconfiguration because runtime config is the same")
return
cflags += ["-O2", "-std=c++11"]
cflags += env.pkg_config.include_path
ldflags = env.pkg_config.ldflags
cflags += pkg.cflags
ldflags = pkg.ldflags
lib_name = dll_path
source = env.pkg_config.lib_source
logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
dll_path, str(cflags), str(source), str(ldflags))
cc.create_shared(lib_name, source, cflags + ldflags)
with open(cfg_path, "w") as outputfile:
json.dump(pkg.cfg_json, outputfile)
def main():
......
"""VTA Package configuration module
This module is dependency free and can be used to configure package.
"""
from __future__ import absolute_import as _abs
import json
import glob
class PkgConfig(object):
    """Simple package config tool for VTA.

    This is used to provide runtime specific configurations.

    Parameters
    ----------
    cfg : dict
        The config dictionary

    proj_root : str
        Path to the project root
    """
    # Keys every configuration dict must supply; also the order in which
    # the -DVTA_* macros are emitted.
    cfg_keys = [
        "TARGET",
        "LOG_INP_WIDTH",
        "LOG_WGT_WIDTH",
        "LOG_ACC_WIDTH",
        "LOG_OUT_WIDTH",
        "LOG_BATCH",
        "LOG_BLOCK_IN",
        "LOG_BLOCK_OUT",
        "LOG_UOP_BUFF_SIZE",
        "LOG_INP_BUFF_SIZE",
        "LOG_WGT_BUFF_SIZE",
        "LOG_ACC_BUFF_SIZE",
    ]

    def __init__(self, cfg, proj_root):
        # Compiler include search paths for building the runtime in-tree.
        self.include_path = [
            "-I%s/include" % proj_root,
            "-I%s/nnvm/tvm/include" % proj_root,
            "-I%s/nnvm/tvm/dlpack/include" % proj_root,
            "-I%s/nnvm/dmlc-core/include" % proj_root
        ]
        # Sources for a standalone library build: common files plus the
        # target-specific backend directory (e.g. src/pynq or src/sim).
        common_src = glob.glob("%s/src/*.cc" % proj_root)
        backend_src = glob.glob("%s/src/%s/*.cc" % (proj_root, cfg["TARGET"]))
        self.lib_source = common_src + backend_src
        # Snapshot the hardware parameters and expose each as a -DVTA_* macro.
        self.cfg_dict = {key: cfg[key] for key in self.cfg_keys}
        self.macro_defs = ["-DVTA_%s=%s" % (key, str(cfg[key]))
                           for key in self.cfg_keys]
        self.target = cfg["TARGET"]
        # Only the pynq target needs extra driver libraries at link time.
        if self.target == "pynq":
            self.ldflags = [
                "-L/usr/lib",
                "-lsds_lib",
                "-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
                "-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
                "-l:libdma.so"]
        else:
            self.ldflags = []

    @property
    def cflags(self):
        """All compile flags: include paths followed by macro definitions."""
        return self.include_path + self.macro_defs

    @property
    def cfg_json(self):
        """The configuration serialized as a pretty-printed JSON string."""
        return json.dumps(self.cfg_dict, indent=2)

    def same_config(self, cfg):
        """Compare if cfg is same as current config.

        Parameters
        ----------
        cfg : dict
            The configuration to compare against.

        Returns
        -------
        equal : bool
            Whether every key of the current config is present in ``cfg``
            with the same value.
        """
        return all(key in cfg and cfg[key] == val
                   for key, val in self.cfg_dict.items())
......@@ -12,23 +12,8 @@ def reconfig_runtime(remote):
The TVM RPC session
"""
env = get_env()
keys = ["VTA_LOG_WGT_WIDTH",
"VTA_LOG_INP_WIDTH",
"VTA_LOG_ACC_WIDTH",
"VTA_LOG_OUT_WIDTH",
"VTA_LOG_BATCH",
"VTA_LOG_BLOCK_IN",
"VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE",
"VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"]
cflags = []
for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
freconfig(" ".join(cflags))
freconfig(env.pkg_config().cfg_json)
def program_fpga(remote, bitstream):
......
......@@ -109,9 +109,12 @@ class DRAM {
* \return The true virtual address;
*/
void* GetAddr(uint64_t phy_addr) {
CHECK_NE(phy_addr, 0)
<< "trying to get address that is nullptr";
std::lock_guard<std::mutex> lock(mutex_);
uint64_t loc = (phy_addr >> kPageBits) - 1;
CHECK_LT(loc, ptable_.size());
CHECK_LT(loc, ptable_.size())
<< "phy_addr=" << phy_addr;
Page* p = ptable_[loc];
CHECK(p != nullptr);
size_t offset = (loc - p->ptable_begin) << kPageBits;
......@@ -173,7 +176,7 @@ class DRAM {
private:
// The bits in page table
static constexpr vta_phy_addr_t kPageBits = 16;
static constexpr vta_phy_addr_t kPageBits = 12;
// page size, also the maximum allocable size 16 K
static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits;
/*! \brief A page in the DRAM */
......@@ -388,6 +391,7 @@ class Device {
}
void RunStore(const VTAMemInsn* op) {
if (op->x_size == 0) return;
if (op->memory_type == VTA_MEM_ID_ACC ||
op->memory_type == VTA_MEM_ID_UOP) {
prof_->out_store_nbytes += (
......
import os
import tvm
import vta
import numpy as np
import time
from tvm.contrib import rpc, util
from tvm.contrib import util
import vta.testing
host = "pynq"
port = 9091
target = "llvm -target=armv7-none-linux-gnueabihf"
def test_gemm():
def run_gemm_packed(env, remote, batch_size, channel, block):
data_shape = (batch_size // env.BATCH,
channel // env.BLOCK_IN,
env.BATCH,
env.BLOCK_IN)
weight_shape = (channel // env.BLOCK_OUT,
channel // env.BLOCK_IN,
env.BLOCK_OUT,
env.BLOCK_IN)
res_shape = (batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
num_ops = channel * channel * batch_size
def test_gemm_packed(batch_size, channel, block):
env = vta.get_env()
data_shape = (batch_size // env.BATCH,
channel // env.BLOCK_IN,
env.BATCH,
env.BLOCK_IN)
weight_shape = (channel // env.BLOCK_OUT,
channel // env.BLOCK_IN,
env.BLOCK_OUT,
env.BLOCK_IN)
res_shape = (batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
num_ops = channel * channel * batch_size
ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
data = tvm.placeholder(data_shape,
name="data",
dtype=env.inp_dtype)
weight = tvm.placeholder(weight_shape,
name="weight",
dtype=env.wgt_dtype)
data_buf = tvm.compute(data_shape,
lambda *i: data(*i),
"data_buf")
weight_buf = tvm.compute(weight_shape,
lambda *i: weight(*i),
"weight_buf")
res_gem = tvm.compute(res_shape,
lambda bo, co, bi, ci: tvm.sum(
data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]),
name="res_gem")
res_shf = tvm.compute(res_shape,
lambda *i: res_gem(*i)>>8,
name="res_shf")
res_max = tvm.compute(res_shape,
lambda *i: tvm.max(res_shf(*i), 0),
"res_max") #relu
res_min = tvm.compute(res_shape,
lambda *i: tvm.min(res_max(*i), (1<<(env.INP_WIDTH-1))-1),
"res_min") #relu
res = tvm.compute(res_shape,
lambda *i: res_min(*i).astype(env.inp_dtype),
name="res")
data = tvm.placeholder(data_shape,
name="data",
dtype=env.inp_dtype)
weight = tvm.placeholder(weight_shape,
name="weight",
dtype=env.wgt_dtype)
data_buf = tvm.compute(data_shape,
lambda *i: data(*i),
"data_buf")
weight_buf = tvm.compute(weight_shape,
lambda *i: weight(*i),
"weight_buf")
res_gem = tvm.compute(res_shape,
lambda bo, co, bi, ci: tvm.sum(
data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
axis=[ko, ki]),
name="res_gem")
res_shf = tvm.compute(res_shape,
lambda *i: res_gem(*i)>>8,
name="res_shf")
res_max = tvm.compute(res_shape,
lambda *i: tvm.max(res_shf(*i), 0),
"res_max") #relu
res_min = tvm.compute(res_shape,
lambda *i: tvm.min(res_max(*i), (1<<(env.INP_WIDTH-1))-1),
"res_min") #relu
res = tvm.compute(res_shape,
lambda *i: res_min(*i).astype(env.inp_dtype),
name="res")
def verify(s, check_correctness=True):
mod = vta.build(s, [data, weight, res],
"ext_dev", env.target_host, name="gemm")
temp = util.tempdir()
mod.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig = np.random.randint(
-128, 128, size=(batch_size, channel)).astype(data.dtype)
weight_orig = np.random.randint(
-128, 128, size=(channel, channel)).astype(weight.dtype)
data_packed = data_orig.reshape(
batch_size // env.BATCH, env.BATCH,
channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_orig.reshape(
channel // env.BLOCK_OUT, env.BLOCK_OUT,
channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
weight_arr = tvm.nd.array(weight_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
res_ref = np.zeros(res_shape).astype(env.acc_dtype)
for b in range(batch_size // env.BATCH):
for i in range(channel // env.BLOCK_OUT):
for j in range(channel // env.BLOCK_IN):
res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(env.acc_dtype),
weight_packed[i,j].T.astype(env.acc_dtype))
res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20)
cost = time_f(data_arr, weight_arr, res_arr)
res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
if check_correctness:
np.testing.assert_allclose(res_unpack, res_ref)
return cost
def verify(s, check_correctness=True):
mod = vta.build(s, [data, weight, res], "ext_dev", target, name="gemm")
temp = util.tempdir()
remote = rpc.connect(host, port)
mod.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")
# verify
ctx = remote.ext_dev(0)
# Data in original format
data_orig = np.random.randint(
-128, 128, size=(batch_size, channel)).astype(data.dtype)
weight_orig = np.random.randint(
-128, 128, size=(channel, channel)).astype(weight.dtype)
data_packed = data_orig.reshape(
batch_size // env.BATCH, env.BATCH,
channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_orig.reshape(
channel // env.BLOCK_OUT, env.BLOCK_OUT,
channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
res_np = np.zeros(res_shape).astype(res.dtype)
data_arr = tvm.nd.array(data_packed, ctx)
weight_arr = tvm.nd.array(weight_packed, ctx)
res_arr = tvm.nd.array(res_np, ctx)
res_ref = np.zeros(res_shape).astype(env.acc_dtype)
for b in range(batch_size // env.BATCH):
for i in range(channel // env.BLOCK_OUT):
for j in range(channel // env.BLOCK_IN):
res_ref[b,i,:] += np.dot(data_packed[b,j,:].astype(env.acc_dtype),
weight_packed[i,j].T.astype(env.acc_dtype))
res_ref = np.right_shift(res_ref, 8)
res_ref = np.clip(res_ref, 0, (1<<(env.INP_WIDTH-1))-1).astype(res.dtype)
time_f = f.time_evaluator("gemm", ctx, number=20)
cost = time_f(data_arr, weight_arr, res_arr)
res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
channel // env.BLOCK_OUT,
env.BATCH,
env.BLOCK_OUT)
if check_correctness:
np.testing.assert_allclose(res_unpack, res_ref)
return cost
def run_schedule(load_inp,
load_wgt,
gemm,
alu,
store_out,
print_ir,
check_correctness):
s = tvm.create_schedule(res.op)
s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(env.wgt_scope)
s[res_gem].set_scope(env.acc_scope)
s[res_shf].set_scope(env.acc_scope)
s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(env.acc_scope)
def run_schedule(load_inp,
load_wgt,
gemm,
alu,
store_out,
print_ir,
check_correctness):
s = tvm.create_schedule(res.op)
s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(env.wgt_scope)
s[res_gem].set_scope(env.acc_scope)
s[res_shf].set_scope(env.acc_scope)
s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(env.acc_scope)
if block:
bblock = block // env.BATCH
iblock = block // env.BLOCK_IN
oblock = block // env.BLOCK_OUT
xbo, xco, xbi, xci = s[res].op.axis
xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
store_pt = xb2
if block:
bblock = block // env.BATCH
iblock = block // env.BLOCK_IN
oblock = block // env.BLOCK_OUT
xbo, xco, xbi, xci = s[res].op.axis
xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
store_pt = xb2
s[res_gem].compute_at(s[res], xco1)
s[res_shf].compute_at(s[res], xco1)
s[res_min].compute_at(s[res], xco1)
s[res_max].compute_at(s[res], xco1)
s[res_gem].compute_at(s[res], xco1)
s[res_shf].compute_at(s[res], xco1)
s[res_min].compute_at(s[res], xco1)
s[res_max].compute_at(s[res], xco1)
xbo, xco, xbi, xci = s[res_gem].op.axis
# Compute one line at a time
ko1, ko2 = s[res_gem].split(ko, iblock)
s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)
s[data_buf].compute_at(s[res_gem], ko1)
s[weight_buf].compute_at(s[res_gem], ko1)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
s[res_gem].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res_min].pragma(s[res_min].op.axis[0], alu)
s[res_max].pragma(s[res_max].op.axis[0], alu)
s[res].pragma(store_pt, store_out)
else:
xbo, xco, xbi, xci = s[res_gem].op.axis
s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
s[res_gem].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res_min].pragma(s[res_min].op.axis[0], alu)
s[res_max].pragma(s[res_max].op.axis[0], alu)
s[res].pragma(s[res].op.axis[0], store_out)
xbo, xco, xbi, xci = s[res_gem].op.axis
# Compute one line at a time
ko1, ko2 = s[res_gem].split(ko, iblock)
s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)
s[data_buf].compute_at(s[res_gem], ko1)
s[weight_buf].compute_at(s[res_gem], ko1)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
s[res_gem].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res_min].pragma(s[res_min].op.axis[0], alu)
s[res_max].pragma(s[res_max].op.axis[0], alu)
s[res].pragma(store_pt, store_out)
else:
xbo, xco, xbi, xci = s[res_gem].op.axis
s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)
# Use VTA instructions
s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
s[res_gem].tensorize(xbi, gemm)
s[res_shf].pragma(s[res_shf].op.axis[0], alu)
s[res_min].pragma(s[res_min].op.axis[0], alu)
s[res_max].pragma(s[res_max].op.axis[0], alu)
s[res].pragma(s[res].op.axis[0], store_out)
if print_ir:
print(tvm.lower(s, [data, weight, res], simple_mode=True))
return verify(s, check_correctness)
if print_ir:
print(tvm.lower(s, [data, weight, res], simple_mode=True))
return verify(s, check_correctness)
def gemm_normal(print_ir):
mock = env.mock
print("----- GEMM GFLOPS End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
cost = run_schedule(
env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir, True)
def gemm_normal(print_ir):
mock = env.mock
print("----- GEMM GFLOPS End-to-End Test-------")
def run_test(header, print_ir, check_correctness):
cost = run_schedule(
env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
print_ir, check_correctness)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir, True)
print("")
def gevm_unittest(print_ir):
mock = env.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
def gevm_unittest(print_ir):
mock = env.mock
print("----- GEMM Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def alu_unittest(print_ir):
mock = env.mock
print("----- ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def alu_unittest(print_ir):
mock = env.mock
print("----- ALU Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_inp_unittest(print_ir):
mock = env.mock
print("----- LoadInp Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_inp_unittest(print_ir):
mock = env.mock
print("----- LoadInp Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_wgt_unittest(print_ir):
mock = env.mock
print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def load_wgt_unittest(print_ir):
mock = env.mock
print("----- LoadWgt Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy, print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def store_out_unittest(print_ir):
mock = env.mock
print("----- StoreOut Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
def store_out_unittest(print_ir):
mock = env.mock
print("----- StoreOut Unit Test-------")
def run_test(header, print_ir):
cost = run_schedule(
mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
print_ir, False)
gops = (num_ops / cost.mean) / float(10 ** 9)
bandwith = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
print(header)
print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
cost.mean, gops, bandwith))
with vta.build_config():
run_test("NORMAL", print_ir)
print("")
gemm_normal(False)
gevm_unittest(False)
alu_unittest(False)
# FIXME: report time that is too short
# load_inp_unittest(False)
# load_wgt_unittest(False)
# store_out_unittest(False)
gemm_normal(False)
gevm_unittest(False)
alu_unittest(False)
print("========GEMM 128=========")
test_gemm_packed(128, 128, 128)
def _run(env, remote):
print("========GEMM 128=========")
run_gemm_packed(env, remote, 128, 128, 128)
# FIXME: hanging run
# print("========GEMM 1024========")
# test_gemm_packed(1024, 1024, 128)
vta.testing.run(_run)
if __name__ == "__main__":
test_gemm()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment