Commit 5c5806ba by Tianqi Chen

[INFRASTRUCTURE] Migrate to json based config. Move gemm test to integration. (#28)

* Migrate to json based config. Move gemm test to integration.

* temp checkin

* checkin  example json
parent 666f32d6
ROOTDIR = $(CURDIR)

ifndef config
ifneq ("$(wildcard ./config.mk)", "")
config = config.mk
else
config = make/config.mk
endif
endif
include $(config)

export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC

ifdef NNVM_PATH
CFLAGS += -I$(NNVM_PATH)/include
else
NNVM_PATH = $(ROOTDIR)/nnvm
CFLAGS += -I$(NNVM_PATH)/include
endif

ifdef TVM_PATH
CFLAGS += -I$(TVM_PATH)/include -I$(TVM_PATH)/dlpack/include -I$(TVM_PATH)/HalideIR/src
else
TVM_PATH = $(NNVM_PATH)/tvm
CFLAGS += -I$(TVM_PATH)/include -I$(TVM_PATH)/dlpack/include -I$(TVM_PATH)/HalideIR/src
endif

ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include
else
CFLAGS += -I$(NNVM_PATH)/dmlc-core/include
endif

ifneq ($(ADD_CFLAGS), NONE)
CFLAGS += $(ADD_CFLAGS)
endif
ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif

VTA_CONFIG = python make/vta_config.py
CFLAGS += `${VTA_CONFIG} --cflags`
LDFLAGS += `${VTA_CONFIG} --ldflags`
VTA_TARGET := $(shell ${VTA_CONFIG} --target)

UNAME_S := $(shell uname -s)
...@@ -53,29 +21,30 @@ else
NO_WHOLE_ARCH= --no-whole-archive
endif

VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
VTA_LIB_SRC = $(wildcard src/*.cc)
ifeq (${VTA_TARGET}, pynq)
	VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
	LDFLAGS += -L/usr/lib -lsds_lib
	LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
	LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
	LDFLAGS += -l:libdma.so
endif
ifeq (${VTA_TARGET}, sim)
	VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))

all: lib/libvta.so lib/libvta.so.json

build/%.o: src/%.cc
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
	$(CXX) -c $(CFLAGS) -c $< -o $@

lib/libvta.so.json: lib/libvta.so
	@mkdir -p $(@D)
	${VTA_CONFIG} --cfg-json > $@

lib/libvta.so: $(VTA_LIB_OBJ)
	@mkdir -p $(@D)
	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
...
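A side effect of the new `lib/libvta.so.json` target is that the runtime library always sits next to a snapshot of the configuration it was built with. A minimal sketch (paths assumed relative to the VTA root) of using that snapshot to decide whether a rebuild is needed:

```python
import json

# Compare the snapshot written by `make` against the live configuration.
# Paths are assumptions for illustration; run from the VTA project root.
with open("lib/libvta.so.json") as f:
    built_cfg = json.load(f)
with open("config.json") as f:
    wanted_cfg = json.load(f)

if all(built_cfg.get(k) == v for k, v in wanted_cfg.items()):
    print("runtime library already matches config.json")
else:
    print("config changed: rerun make, or call reconfig_runtime over RPC")
```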
# VTA Configuration

Each VTA runtime/hardware configuration is specified by a config.json file.
You can copy the default config.json to the project root and modify it
before you type make.
The config affects the behavior of the Python package as well as the
hardware runtime build.
{
"TARGET" : "pynq",
"LOG_INP_WIDTH" : 3,
"LOG_WGT_WIDTH" : 3,
"LOG_ACC_WIDTH" : 5,
"LOG_OUT_WIDTH" : 3,
"LOG_BATCH" : 0,
"LOG_BLOCK_IN" : 4,
"LOG_BLOCK_OUT" : 4,
"LOG_UOP_BUFF_SIZE" : 15,
"LOG_INP_BUFF_SIZE" : 15,
"LOG_WGT_BUFF_SIZE" : 15,
"LOG_ACC_BUFF_SIZE" : 17
}
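Every LOG_* field above is a base-2 exponent. A small sketch (assuming the JSON above is saved as config.json in the project root) of how those fields translate into concrete widths and sizes:

```python
import json

cfg = json.load(open("config.json"))  # path assumed for illustration

inp_width = 1 << cfg["LOG_INP_WIDTH"]      # 8-bit activations
wgt_width = 1 << cfg["LOG_WGT_WIDTH"]      # 8-bit weights
acc_width = 1 << cfg["LOG_ACC_WIDTH"]      # 32-bit accumulators
gemm_tile = (1 << cfg["LOG_BATCH"],        # (1, 16, 16) GEMM tile
             1 << cfg["LOG_BLOCK_IN"],
             1 << cfg["LOG_BLOCK_OUT"])
acc_buf_bytes = 1 << cfg["LOG_ACC_BUFF_SIZE"]  # 131072 bytes = 128 KiB

print(inp_width, wgt_width, acc_width, gemm_tile, acc_buf_bytes)
```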
#-------------------------------------------------------------------------------
# Template configuration for compiling VTA runtime.
#
# If you want to change the configuration, please use the following
# steps. Assume you are in the root directory of the project. First copy this
# file so that any local changes will be ignored by git
#
# $ cp make/config.mk .
#
# Next modify the relevant entries, and then compile by
#
# $ make
#
# or build in parallel with 8 threads
#
# $ make -j8
#-------------------------------------------------------------------------------
#---------------------
# additional build flags and target
#--------------------
# the additional link flags you want to add
ADD_LDFLAGS=
# the additional compile flags you want to add
ADD_CFLAGS=
# the hardware target, can be [sim, pynq]
VTA_TARGET = pynq
#---------------------
# VTA hardware parameters
#--------------------
# Log of input/activation width in bits (default 3 -> 8 bits)
VTA_LOG_INP_WIDTH = 3
# Log of kernel weight width in bits (default 3 -> 8 bits)
VTA_LOG_WGT_WIDTH = 3
# Log of accum width in bits (default 5 -> 32 bits)
VTA_LOG_ACC_WIDTH = 5
# Log of tensor batch size (A in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BATCH = 0
# Log of tensor inner block size (B in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_IN = 4
# Log of tensor outer block size (C in (A,B)x(B,C) matrix multiplication)
VTA_LOG_BLOCK_OUT = 4
# Log of uop buffer size in Bytes
VTA_LOG_UOP_BUFF_SIZE = 15
# Log of inp buffer size in Bytes
VTA_LOG_INP_BUFF_SIZE = 15
# Log of wgt buffer size in Bytes
VTA_LOG_WGT_BUFF_SIZE = 15
# Log of acc buffer size in Bytes
VTA_LOG_ACC_BUFF_SIZE = 17
#---------------------
# Derived VTA hardware parameters
#--------------------
# Input width in bits
VTA_INP_WIDTH = $(shell echo "$$(( 1 << $(VTA_LOG_INP_WIDTH) ))" )
# Weight width in bits
VTA_WGT_WIDTH = $(shell echo "$$(( 1 << $(VTA_LOG_WGT_WIDTH) ))" )
# Log of output width in bits
VTA_LOG_OUT_WIDTH = $(VTA_LOG_INP_WIDTH)
# Output width in bits
VTA_OUT_WIDTH = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_WIDTH) ))" )
# Tensor batch size
VTA_BATCH = $(shell echo "$$(( 1 << $(VTA_LOG_BATCH) ))" )
# Tensor inner block size
VTA_IN_BLOCK = $(shell echo "$$(( 1 << $(VTA_LOG_BLOCK_IN) ))" )
# Tensor outer block size
VTA_OUT_BLOCK = $(shell echo "$$(( 1 << $(VTA_LOG_BLOCK_OUT) ))" )
# Uop buffer size in Bytes
VTA_UOP_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_UOP_BUFF_SIZE) ))" )
# Inp buffer size in Bytes
VTA_INP_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_INP_BUFF_SIZE) ))" )
# Wgt buffer size in Bytes
VTA_WGT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_WGT_BUFF_SIZE) ))" )
# Acc buffer size in Bytes
VTA_ACC_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_ACC_BUFF_SIZE) ))" )
# Log of out buffer size in Bytes
VTA_LOG_OUT_BUFF_SIZE = \
$(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_ACC_WIDTH) ))" )
# Out buffer size in Bytes
VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
# Update ADD_CFLAGS
ADD_CFLAGS +=
-DVTA_TARGET=$(VTA_TARGET)\
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
-DVTA_LOG_BLOCK_IN=$(VTA_LOG_BLOCK_IN) -DVTA_LOG_BLOCK_OUT=$(VTA_LOG_BLOCK_OUT) \
-DVTA_LOG_UOP_BUFF_SIZE=$(VTA_LOG_UOP_BUFF_SIZE) -DVTA_LOG_INP_BUFF_SIZE=$(VTA_LOG_INP_BUFF_SIZE) \
-DVTA_LOG_WGT_BUFF_SIZE=$(VTA_LOG_WGT_BUFF_SIZE) -DVTA_LOG_ACC_BUFF_SIZE=$(VTA_LOG_ACC_BUFF_SIZE) \
-DVTA_LOG_OUT_BUFF_SIZE=$(VTA_LOG_OUT_BUFF_SIZE)
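As a quick sanity check of the derived-parameter arithmetic in this template (using its default values):

```python
# Worked example of VTA_LOG_OUT_BUFF_SIZE with the template defaults above.
LOG_ACC_BUFF_SIZE = 17   # 128 KiB accumulator buffer
LOG_OUT_WIDTH = 3        # output width tracks input width (8 bits)
LOG_ACC_WIDTH = 5        # 32-bit accumulators

log_out_buff_size = LOG_ACC_BUFF_SIZE + LOG_OUT_WIDTH - LOG_ACC_WIDTH  # 15
out_buff_bytes = 1 << log_out_buff_size                                # 32 KiB

# Same number of elements as the accumulator buffer, stored at 8 bits
# instead of 32 bits, hence a quarter of the bytes.
assert out_buff_bytes * 4 == 1 << LOG_ACC_BUFF_SIZE
```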
"""VTA config tool"""
import os
import sys
import json
import argparse
def get_pkg_config(cfg):
"""Get the pkg config object."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../"))
pkg_config_py = os.path.join(proj_root, "python/vta/pkg_config.py")
libpkg = {"__file__": pkg_config_py}
exec(compile(open(pkg_config_py, "rb").read(), pkg_config_py, "exec"), libpkg, libpkg)
PkgConfig = libpkg["PkgConfig"]
return PkgConfig(cfg, proj_root)
def main():
"""Main funciton"""
parser = argparse.ArgumentParser()
parser.add_argument("--cflags", action="store_true",
help="print the cflags")
parser.add_argument("--update", action="store_true",
help="Print out the json option.")
parser.add_argument("--ldflags", action="store_true",
help="print the cflags")
parser.add_argument("--cfg-json", action="store_true",
help="print all the config json")
parser.add_argument("--target", action="store_true",
help="print the target")
args = parser.parse_args()
if len(sys.argv) == 1:
parser.print_help()
return
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../"))
path_list = [
os.path.join(proj_root, "config.json"),
os.path.join(proj_root, "make/config.json")
]
ok_path_list = [p for p in path_list if os.path.exists(p)]
if not ok_path_list:
raise RuntimeError("Cannot find config in %s" % str(path_list))
cfg = json.load(open(ok_path_list[0]))
cfg["LOG_OUT_WIDTH"] = cfg["LOG_INP_WIDTH"]
pkg = get_pkg_config(cfg)
if args.target:
print(pkg.target)
if args.cflags:
print(" ".join(pkg.cflags))
if args.ldflags:
print(" ".join(pkg.ldflags))
if args.cfg_json:
print(pkg.cfg_json)
if __name__ == "__main__":
main()
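The Makefile consumes this tool through shell expansion of `python make/vta_config.py --cflags` and friends. A minimal sketch of invoking it the same way from Python, assuming it is run from the VTA project root with a config.json present:

```python
import subprocess

# Query the config tool exactly as the Makefile does (run from the VTA root).
cflags = subprocess.check_output(
    ["python", "make/vta_config.py", "--cflags"]).decode().split()
target = subprocess.check_output(
    ["python", "make/vta_config.py", "--target"]).decode().strip()

# cflags holds include paths plus one -DVTA_* define per config key,
# e.g. ... -DVTA_LOG_INP_WIDTH=3 -DVTA_LOG_ACC_BUFF_SIZE=17 ...
print(target)
print(" ".join(cflags))
```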
"""TVM-based VTA Compiler Toolchain""" """TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import sys
from .environment import get_env, Environment from .environment import get_env, Environment
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
try: try:
# allow optional import in config mode.
from . import arm_conv2d, vta_conv2d
from .build_module import build_config, lower, build
from .rpc_client import reconfig_runtime, program_fpga
from . import graph from . import graph
except (ImportError, RuntimeError): except (ImportError, RuntimeError):
pass pass
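The try/except above keeps the core package importable when the optional compiler pieces (arm_conv2d, build_module, graph, ...) are unavailable, e.g. on a device that only runs the RPC server. A minimal sketch of that usage, assuming a valid config.json is in place:

```python
# Even without the optional compiler modules, the environment still loads.
import vta

env = vta.get_env()                # reads config.json / make/config.json
print(env.TARGET)                  # e.g. "pynq" or "sim"
print(env.pkg_config().cfg_json)   # JSON string describing this configuration
```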
...@@ -3,10 +3,11 @@
from __future__ import absolute_import as _abs
import os
import json
import copy
import tvm
from . import intrin
from .pkg_config import PkgConfig


class DevContext(object):
...@@ -61,45 +62,6 @@ class DevContext(object):
        return 1 if self.DEBUG_NO_SYNC else qid
class PkgConfig(object):
"""Simple package config tool for VTA.
This is used to provide runtime specific configurations.
"""
def __init__(self, env):
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
# include path
self.include_path = [
"-I%s/include" % proj_root,
"-I%s/nnvm/tvm/include" % proj_root,
"-I%s/nnvm/tvm/dlpack/include" % proj_root,
"-I%s/nnvm/dmlc-core/include" % proj_root
]
# List of source files that can be used to build standalone library.
self.lib_source = []
self.lib_source += glob.glob("%s/src/*.cc" % proj_root)
self.lib_source += glob.glob("%s/src/%s/*.cc" % (proj_root, env.TARGET))
# macro keys
self.macro_defs = []
for key in env.cfg_keys:
self.macro_defs.append("-DVTA_%s=%s" % (key, str(getattr(env, key))))
if env.TARGET == "pynq":
self.ldflags = [
"-L/usr/lib",
"-lsds_lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libdma.so"]
else:
self.ldflags = []
@property
def cflags(self):
return self.include_path + self.macro_defs
class Environment(object):
    """Hardware configuration object.
...@@ -123,19 +85,6 @@ class Environment(object):
        env = vta.get_env()
    """
    current = None
cfg_keys = [
"TARGET",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
"LOG_BATCH",
"LOG_BLOCK_IN",
"LOG_BLOCK_OUT",
"LOG_UOP_BUFF_SIZE",
"LOG_INP_BUFF_SIZE",
"LOG_WGT_BUFF_SIZE",
"LOG_ACC_BUFF_SIZE",
]
    # constants
    MAX_XFER = 1 << 22
    # debug flags
...@@ -152,10 +101,9 @@ class Environment(object):
    def __init__(self, cfg):
        # Log of input/activation width in bits
        self.__dict__.update(cfg)
        for key in PkgConfig.cfg_keys:
            if key not in cfg:
                raise ValueError("Expect key %s in cfg" % key)
        self.LOG_OUT_WIDTH = self.LOG_INP_WIDTH
        self.LOG_OUT_BUFF_SIZE = (
            self.LOG_ACC_BUFF_SIZE +
            self.LOG_OUT_WIDTH -
...@@ -195,7 +143,12 @@ class Environment(object):
        self.mock_mode = False
        self._mock_env = None
        self._dev_ctx = None
        self._pkg_config = None

    def pkg_config(self):
        """PkgConfig instance"""
        curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
        proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
        return PkgConfig(self.__dict__, proj_root)

    @property
    def dev(self):
...@@ -205,13 +158,6 @@
        return self._dev_ctx

    @property
def pkg_config(self):
"""PkgConfig instance"""
if self._pkg_config is None:
self._pkg_config = PkgConfig(self)
return self._pkg_config
@property
    def mock(self):
        """A mock version of the Environment
...@@ -327,30 +273,19 @@ def coproc_dep_pop(op):
def _init_env():
    """Initialize the default global env"""
    curr_path = os.path.dirname(
        os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
    path_list = [
        os.path.join(curr_path, "config.json"),
        os.path.join(proj_root, "config.json"),
        os.path.join(proj_root, "make/config.json")
    ]
    ok_path_list = [p for p in path_list if os.path.exists(p)]
    if not ok_path_list:
        raise RuntimeError(
            "Error: config.json not found in %s. "
            "Make sure you have config.json in your vta root" % str(path_list))
    return Environment(json.load(open(ok_path_list[0])))
cfg = {}
with open(filename) as f:
for line in f:
for k in keys:
if k +" =" in line:
val = line.split("=")[1].strip()
if k.startswith("VTA_"):
k = k[4:]
try:
cfg[k] = int(val)
except ValueError:
cfg[k] = val
return Environment(cfg)
Environment.current = _init_env()
...@@ -8,11 +8,13 @@ import logging
import argparse
import os
import ctypes
import json
import tvm
from tvm._ffi.base import c_str
from tvm.contrib import rpc, cc

from ..environment import get_env
from ..pkg_config import PkgConfig


@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
...@@ -21,8 +23,9 @@ def server_start():
    # pylint: disable=unused-variable
    curr_path = os.path.dirname(
        os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(curr_path, "../../.."))
    dll_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))
    cfg_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so.json"))

    runtime_dll = []
    _load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")
...@@ -56,26 +59,36 @@ def server_start():
        runtime_dll.pop()

    @tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True)
    def reconfig_runtime(cfg_json):
        """Rebuild and reload runtime with new configuration.

        Parameters
        ----------
        cfg_json : str
            JSON string used for configurations.
        """
        if runtime_dll:
            raise RuntimeError("Can only reconfig at the beginning of a session...")
        env = get_env()
        cfg = json.loads(cfg_json)
        cfg["TARGET"] = env.TARGET
        pkg = PkgConfig(cfg, proj_root)
        # check if the configuration is already the same
        if os.path.isfile(cfg_path):
            old_cfg = json.load(open(cfg_path))
            if pkg.same_config(old_cfg):
                logging.info("Skip reconfiguration because runtime config is the same")
                return
        cflags = ["-O2", "-std=c++11"]
        cflags += pkg.cflags
        ldflags = pkg.ldflags
        lib_name = dll_path
        source = pkg.lib_source
        logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
                     dll_path, str(cflags), str(source), str(ldflags))
        cc.create_shared(lib_name, source, cflags + ldflags)
        with open(cfg_path, "w") as outputfile:
            outputfile.write(pkg.cfg_json)

def main():
...
"""VTA Package configuration module
This module is dependency-free and can be used to configure the package.
"""
from __future__ import absolute_import as _abs
import json
import glob
class PkgConfig(object):
"""Simple package config tool for VTA.
This is used to provide runtime specific configurations.
Parameters
----------
cfg : dict
The config dictionary
proj_root : str
Path to the project root
"""
cfg_keys = [
"TARGET",
"LOG_INP_WIDTH",
"LOG_WGT_WIDTH",
"LOG_ACC_WIDTH",
"LOG_OUT_WIDTH",
"LOG_BATCH",
"LOG_BLOCK_IN",
"LOG_BLOCK_OUT",
"LOG_UOP_BUFF_SIZE",
"LOG_INP_BUFF_SIZE",
"LOG_WGT_BUFF_SIZE",
"LOG_ACC_BUFF_SIZE",
]
def __init__(self, cfg, proj_root):
# include path
self.include_path = [
"-I%s/include" % proj_root,
"-I%s/nnvm/tvm/include" % proj_root,
"-I%s/nnvm/tvm/dlpack/include" % proj_root,
"-I%s/nnvm/dmlc-core/include" % proj_root
]
# List of source files that can be used to build standalone library.
self.lib_source = []
self.lib_source += glob.glob("%s/src/*.cc" % proj_root)
self.lib_source += glob.glob("%s/src/%s/*.cc" % (proj_root, cfg["TARGET"]))
# macro keys
self.macro_defs = []
self.cfg_dict = {}
for key in self.cfg_keys:
self.macro_defs.append("-DVTA_%s=%s" % (key, str(cfg[key])))
self.cfg_dict[key] = cfg[key]
self.target = cfg["TARGET"]
if self.target == "pynq":
self.ldflags = [
"-L/usr/lib",
"-lsds_lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libdma.so"]
else:
self.ldflags = []
@property
def cflags(self):
return self.include_path + self.macro_defs
@property
def cfg_json(self):
return json.dumps(self.cfg_dict, indent=2)
def same_config(self, cfg):
"""Compare if cfg is same as current config.
Parameters
----------
cfg : dict
    The configuration to compare against.
Returns
-------
equal : bool
Whether the configuration is the same.
"""
for k, v in self.cfg_dict.items():
if k not in cfg:
return False
if cfg[k] != v:
return False
return True
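A short usage sketch of PkgConfig.same_config (paths and values assumed), mirroring how the RPC server decides whether the runtime must be rebuilt:

```python
import json
from vta.pkg_config import PkgConfig

proj_root = "/path/to/vta"                     # assumed location
cfg = json.load(open(proj_root + "/config.json"))
pkg = PkgConfig(cfg, proj_root)

try:
    old_cfg = json.load(open(proj_root + "/lib/libvta.so.json"))
    needs_rebuild = not pkg.same_config(old_cfg)
except IOError:
    needs_rebuild = True                       # no snapshot yet: first build
print("rebuild runtime:", needs_rebuild)
```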
...@@ -12,23 +12,8 @@ def reconfig_runtime(remote):
        The TVM RPC session
    """
    env = get_env()
keys = ["VTA_LOG_WGT_WIDTH",
"VTA_LOG_INP_WIDTH",
"VTA_LOG_ACC_WIDTH",
"VTA_LOG_OUT_WIDTH",
"VTA_LOG_BATCH",
"VTA_LOG_BLOCK_IN",
"VTA_LOG_BLOCK_OUT",
"VTA_LOG_UOP_BUFF_SIZE",
"VTA_LOG_INP_BUFF_SIZE",
"VTA_LOG_WGT_BUFF_SIZE",
"VTA_LOG_ACC_BUFF_SIZE",
"VTA_LOG_OUT_BUFF_SIZE"]
cflags = []
for k in keys:
cflags += ["-D%s=%s" % (k, str(getattr(env, k[4:])))]
    freconfig = remote.get_function("tvm.contrib.vta.reconfig_runtime")
    freconfig(env.pkg_config().cfg_json)
def program_fpga(remote, bitstream):
...
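On the host side, reconfig_runtime now just ships the environment's JSON config over RPC instead of a hand-built -D flag string. A minimal sketch of driving it (board hostname and port assumed, matching the values the old pynq test used):

```python
from tvm.contrib import rpc
from vta.rpc_client import reconfig_runtime

# Hostname/port are assumptions for illustration.
remote = rpc.connect("pynq", 9091)
reconfig_runtime(remote)   # sends get_env().pkg_config().cfg_json to the board
```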
...@@ -109,9 +109,12 @@ class DRAM {
   * \return The true virtual address;
   */
  void* GetAddr(uint64_t phy_addr) {
    CHECK_NE(phy_addr, 0)
        << "trying to get address that is nullptr";
    std::lock_guard<std::mutex> lock(mutex_);
    uint64_t loc = (phy_addr >> kPageBits) - 1;
    CHECK_LT(loc, ptable_.size())
        << "phy_addr=" << phy_addr;
    Page* p = ptable_[loc];
    CHECK(p != nullptr);
    size_t offset = (loc - p->ptable_begin) << kPageBits;
...@@ -173,7 +176,7 @@ class DRAM {
 private:
  // The bits in page table
  static constexpr vta_phy_addr_t kPageBits = 12;
  // page size, also the maximum allocable size 4 K
  static constexpr vta_phy_addr_t kPageSize = 1 << kPageBits;
  /*! \brief A page in the DRAM */
...@@ -388,6 +391,7 @@ class Device {
  }

  void RunStore(const VTAMemInsn* op) {
    if (op->x_size == 0) return;
    if (op->memory_type == VTA_MEM_ID_ACC ||
        op->memory_type == VTA_MEM_ID_UOP) {
      prof_->out_store_nbytes += (
...
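For reference, the simulator's DRAM hands out "physical" addresses that are really 1-based page indices shifted by kPageBits; a small Python sketch of the translation checked above (field names are assumptions mirroring the C++):

```python
K_PAGE_BITS = 12   # kPageBits after this change: 4 KiB pages

def get_addr(phy_addr, ptable):
    """Sketch of DRAM::GetAddr: map a fake physical address to its page."""
    assert phy_addr != 0, "trying to get address that is nullptr"
    loc = (phy_addr >> K_PAGE_BITS) - 1      # pages are numbered from 1
    assert loc < len(ptable), "phy_addr=%d" % phy_addr
    page = ptable[loc]
    assert page is not None
    offset = (loc - page["ptable_begin"]) << K_PAGE_BITS
    return page["base_addr"] + offset        # hypothetical host-side address
```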
import tvm
import numpy as np
from tvm.contrib import util
import vta
import vta.testing


def test_gemm():
    def run_gemm_packed(env, remote, batch_size, channel, block):
        data_shape = (batch_size // env.BATCH,
                      channel // env.BLOCK_IN,
                      env.BATCH,
                      env.BLOCK_IN)
        weight_shape = (channel // env.BLOCK_OUT,
                        channel // env.BLOCK_IN,
                        env.BLOCK_OUT,
                        env.BLOCK_IN)
        res_shape = (batch_size // env.BATCH,
                     channel // env.BLOCK_OUT,
                     env.BATCH,
                     env.BLOCK_OUT)
        num_ops = channel * channel * batch_size

        ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
        ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')

        data = tvm.placeholder(data_shape,
                               name="data",
                               dtype=env.inp_dtype)
        weight = tvm.placeholder(weight_shape,
                                 name="weight",
                                 dtype=env.wgt_dtype)
        data_buf = tvm.compute(data_shape,
                               lambda *i: data(*i),
                               "data_buf")
        weight_buf = tvm.compute(weight_shape,
                                 lambda *i: weight(*i),
                                 "weight_buf")
        res_gem = tvm.compute(res_shape,
                              lambda bo, co, bi, ci: tvm.sum(
                                  data_buf[bo, ko, bi, ki].astype(env.acc_dtype) *
                                  weight_buf[co, ko, ci, ki].astype(env.acc_dtype),
                                  axis=[ko, ki]),
                              name="res_gem")
        res_shf = tvm.compute(res_shape,
                              lambda *i: res_gem(*i) >> 8,
                              name="res_shf")
        res_max = tvm.compute(res_shape,
                              lambda *i: tvm.max(res_shf(*i), 0),
                              "res_max")  # relu
        res_min = tvm.compute(res_shape,
                              lambda *i: tvm.min(res_max(*i), (1 << (env.INP_WIDTH - 1)) - 1),
                              "res_min")  # relu
        res = tvm.compute(res_shape,
                          lambda *i: res_min(*i).astype(env.inp_dtype),
                          name="res")

        def verify(s, check_correctness=True):
            mod = vta.build(s, [data, weight, res],
                            "ext_dev", env.target_host, name="gemm")
            temp = util.tempdir()
            mod.save(temp.relpath("gemm.o"))
            remote.upload(temp.relpath("gemm.o"))
            f = remote.load_module("gemm.o")
            # verify
            ctx = remote.ext_dev(0)
            # Data in original format
            data_orig = np.random.randint(
                -128, 128, size=(batch_size, channel)).astype(data.dtype)
            weight_orig = np.random.randint(
                -128, 128, size=(channel, channel)).astype(weight.dtype)
            data_packed = data_orig.reshape(
                batch_size // env.BATCH, env.BATCH,
                channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
            weight_packed = weight_orig.reshape(
                channel // env.BLOCK_OUT, env.BLOCK_OUT,
                channel // env.BLOCK_IN, env.BLOCK_IN).transpose((0, 2, 1, 3))
            res_np = np.zeros(res_shape).astype(res.dtype)
            data_arr = tvm.nd.array(data_packed, ctx)
            weight_arr = tvm.nd.array(weight_packed, ctx)
            res_arr = tvm.nd.array(res_np, ctx)
            res_ref = np.zeros(res_shape).astype(env.acc_dtype)
            for b in range(batch_size // env.BATCH):
                for i in range(channel // env.BLOCK_OUT):
                    for j in range(channel // env.BLOCK_IN):
                        res_ref[b, i, :] += np.dot(
                            data_packed[b, j, :].astype(env.acc_dtype),
                            weight_packed[i, j].T.astype(env.acc_dtype))
            res_ref = np.right_shift(res_ref, 8)
            res_ref = np.clip(res_ref, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(res.dtype)
            time_f = f.time_evaluator("gemm", ctx, number=20)
            cost = time_f(data_arr, weight_arr, res_arr)
            res_unpack = res_arr.asnumpy().reshape(batch_size // env.BATCH,
                                                   channel // env.BLOCK_OUT,
                                                   env.BATCH,
                                                   env.BLOCK_OUT)
            if check_correctness:
                np.testing.assert_allclose(res_unpack, res_ref)
            return cost

        def run_schedule(load_inp,
                         load_wgt,
                         gemm,
                         alu,
                         store_out,
                         print_ir,
                         check_correctness):
            s = tvm.create_schedule(res.op)
            s[data_buf].set_scope(env.inp_scope)
            s[weight_buf].set_scope(env.wgt_scope)
            s[res_gem].set_scope(env.acc_scope)
            s[res_shf].set_scope(env.acc_scope)
            s[res_min].set_scope(env.acc_scope)
            s[res_max].set_scope(env.acc_scope)

            if block:
                bblock = block // env.BATCH
                iblock = block // env.BLOCK_IN
                oblock = block // env.BLOCK_OUT
                xbo, xco, xbi, xci = s[res].op.axis
                xb1, xco1, xb2, xco2 = s[res].tile(xbo, xco, bblock, oblock)
                store_pt = xb2

                s[res_gem].compute_at(s[res], xco1)
                s[res_shf].compute_at(s[res], xco1)
                s[res_min].compute_at(s[res], xco1)
                s[res_max].compute_at(s[res], xco1)

                xbo, xco, xbi, xci = s[res_gem].op.axis
                # Compute one line at a time
                ko1, ko2 = s[res_gem].split(ko, iblock)
                s[res_gem].reorder(ko1, ko2, xbo, xco, xbi, xci, ki)
                s[data_buf].compute_at(s[res_gem], ko1)
                s[weight_buf].compute_at(s[res_gem], ko1)

                # Use VTA instructions
                s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
                s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
                s[res_gem].tensorize(xbi, gemm)
                s[res_shf].pragma(s[res_shf].op.axis[0], alu)
                s[res_min].pragma(s[res_min].op.axis[0], alu)
                s[res_max].pragma(s[res_max].op.axis[0], alu)
                s[res].pragma(store_pt, store_out)
            else:
                xbo, xco, xbi, xci = s[res_gem].op.axis
                s[res_gem].reorder(ko, xbo, xco, xbi, xci, ki)
                # Use VTA instructions
                s[data_buf].pragma(s[data_buf].op.axis[0], load_inp)
                s[weight_buf].pragma(s[weight_buf].op.axis[0], load_wgt)
                s[res_gem].tensorize(xbi, gemm)
                s[res_shf].pragma(s[res_shf].op.axis[0], alu)
                s[res_min].pragma(s[res_min].op.axis[0], alu)
                s[res_max].pragma(s[res_max].op.axis[0], alu)
                s[res].pragma(s[res].op.axis[0], store_out)

            if print_ir:
                print(tvm.lower(s, [data, weight, res], simple_mode=True))
            return verify(s, check_correctness)

        def gemm_normal(print_ir):
            mock = env.mock
            print("----- GEMM GFLOPS End-to-End Test-------")
            def run_test(header, print_ir, check_correctness):
                cost = run_schedule(
                    env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
                    print_ir, check_correctness)
                gops = (num_ops / cost.mean) / float(10 ** 9)
                print(header)
                print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
            with vta.build_config():
                run_test("NORMAL", print_ir, True)

        def gevm_unittest(print_ir):
            mock = env.mock
            print("----- GEMM Unit Test-------")
            def run_test(header, print_ir):
                cost = run_schedule(
                    mock.dma_copy, mock.dma_copy, env.gemm, mock.alu, mock.dma_copy,
                    print_ir, False)
                gops = (num_ops / cost.mean) / float(10 ** 9)
                print(header)
                print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
            with vta.build_config():
                run_test("NORMAL", print_ir)
            print("")

        def alu_unittest(print_ir):
            mock = env.mock
            print("----- ALU Unit Test-------")
            def run_test(header, print_ir):
                cost = run_schedule(
                    mock.dma_copy, mock.dma_copy, mock.gemm, env.alu, mock.dma_copy,
                    print_ir, False)
                gops = (num_ops / cost.mean) / float(10 ** 9)
                print(header)
                print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
            with vta.build_config():
                run_test("NORMAL", print_ir)
            print("")

        def load_inp_unittest(print_ir):
            mock = env.mock
            print("----- LoadInp Unit Test-------")
            def run_test(header, print_ir):
                cost = run_schedule(
                    env.dma_copy, mock.dma_copy, mock.gemm, mock.alu, mock.dma_copy,
                    print_ir, False)
                gops = (num_ops / cost.mean) / float(10 ** 9)
                bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
                print(header)
                print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
                    cost.mean, gops, bandwidth))
            with vta.build_config():
                run_test("NORMAL", print_ir)
            print("")

        def load_wgt_unittest(print_ir):
            mock = env.mock
            print("----- LoadWgt Unit Test-------")
            def run_test(header, print_ir):
                cost = run_schedule(
                    mock.dma_copy, env.dma_copy, mock.gemm, mock.alu, mock.dma_copy,
                    print_ir, False)
                gops = (num_ops / cost.mean) / float(10 ** 9)
                bandwidth = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
                print(header)
                print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
                    cost.mean, gops, bandwidth))
            with vta.build_config():
                run_test("NORMAL", print_ir)
            print("")

        def store_out_unittest(print_ir):
            mock = env.mock
            print("----- StoreOut Unit Test-------")
            def run_test(header, print_ir):
                cost = run_schedule(
                    mock.dma_copy, mock.dma_copy, mock.gemm, mock.alu, env.dma_copy,
                    print_ir, False)
                gops = (num_ops / cost.mean) / float(10 ** 9)
                bandwidth = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
                print(header)
                print("\tTime cost = %g sec/op, %g GFLOPS, bandwidth=%g Gbits" % (
                    cost.mean, gops, bandwidth))
            with vta.build_config():
                run_test("NORMAL", print_ir)
            print("")

        gemm_normal(False)
        gevm_unittest(False)
        alu_unittest(False)
        # FIXME: report time that is too short
        # load_inp_unittest(False)
        # load_wgt_unittest(False)
        # store_out_unittest(False)

    def _run(env, remote):
        print("========GEMM 128=========")
        run_gemm_packed(env, remote, 128, 128, 128)

    vta.testing.run(_run)

if __name__ == "__main__":
    test_gemm()
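The reference check in the test packs plain (batch, channel) matrices into VTA's blocked layout before uploading them; a standalone sketch of that reshape/transpose with toy sizes (values assumed):

```python
import numpy as np

BATCH, BLOCK_IN = 1, 16        # defaults: LOG_BATCH=0, LOG_BLOCK_IN=4
batch_size, channel = 4, 32    # toy sizes for illustration

data = np.arange(batch_size * channel, dtype="int8").reshape(batch_size, channel)
packed = data.reshape(batch_size // BATCH, BATCH,
                      channel // BLOCK_IN, BLOCK_IN).transpose((0, 2, 1, 3))
assert packed.shape == (4, 2, 1, 16)

# Element (b, c) of the original matrix lands at
# packed[b // BATCH, c // BLOCK_IN, b % BATCH, c % BLOCK_IN].
b, c = 3, 20
assert packed[b // BATCH, c // BLOCK_IN, b % BATCH, c % BLOCK_IN] == data[b, c]
```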