Commit 4ba6bd50 by Thierry Moreau, committed by Tianqi Chen

[UTILS, DOC] Use TVM file downloading utility, conv2d tutorial (#48)

parent d1128ced
@@ -2,12 +2,20 @@
Follow the first two parts of the [Installation Guide](../../../docs/how_to/install.md) to make sure that the VTA python libraries are installed, and that the RPC server is running on the Pynq FPGA dev board.
We recommend leaving `config.json` at its default parameterization (you can, of course, switch the target between `"sim"` and `"pynq"`).
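For reference, a minimal sketch of switching the target programmatically (the `config.json` path below is an assumption; editing the file by hand works just as well):

```python
import json

# Hypothetical path: point this at your VTA config.json
cfg_path = "config.json"

with open(cfg_path) as f:
    cfg = json.load(f)
cfg["TARGET"] = "pynq"  # or "sim" to use the simulator
with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```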
Simply run the example program. We rely on pickle to store parameters, which currently only works with Python 2:
```bash
python2 imagenet_predict.py
```
The script will first download the following files into the `_data/` directory:
* `cat.jpg` which provides a test sample for the ImageNet classifier
* `quantize_graph.json` which describes the NNVM graph of the 8-bit ResNet-18
* `quantize_params.pkl` which contains the network parameters
* `synset.txt` which contains the ImageNet categories
Next, it will run ImageNet classification with the ResNet18 architecture on a VTA design that performs 8-bit integer inference, classifying the cat image `cat.jpg`.
The script reports the runtime measured on the Pynq board (in seconds), and the top-1 result category:
```
...
```
# some standard imports
import nnvm
import tvm
import vta
import vta.testing
import os
import numpy as np
import pickle
import json
import logging

from PIL import Image
from nnvm.compiler import graph_attr
from tvm.contrib import graph_runtime, rpc, util
from tvm.contrib.download import download
bfactor = 1
cfactor = 16

@@ -20,15 +21,20 @@ verbose = False
debug_fpga_only = False
# Obtain model and hardware files (they're too large to check in)
# Download them into the _data dir
data_dir = "_data/"
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
TEST_FILE = 'cat.jpg'
CATEG_FILE = 'synset.txt'
RESNET_GRAPH_FILE = 'quantize_graph.json'
RESNET_PARAMS_FILE = 'quantize_params.pkl'

# Create data dir
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download files (skip any file already present in _data/)
for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE]:
    if not os.path.isfile(os.path.join(data_dir, file)):
        download(os.path.join(url, file), os.path.join(data_dir, file))
if verbose:
    logging.basicConfig(level=logging.DEBUG)
@@ -40,8 +46,8 @@ target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+
if vta.get_env().TARGET == "sim":
    target_host = "llvm"
synset = eval(open(os.path.join(data_dir, CATEG_FILE)).read())
image = Image.open(os.path.join(data_dir, TEST_FILE)).resize((224, 224))
def transform_image(image):
    # Subtract the per-channel ImageNet mean (RGB)
    image = np.array(image) - np.array([123., 117., 104.])
@@ -88,9 +94,9 @@ print('x', x.shape)
import nnvm.compiler
np.random.seed(0)
sym = nnvm.graph.load_json(
    open(os.path.join(data_dir, RESNET_GRAPH_FILE)).read())
params = pickle.load(
    open(os.path.join(data_dir, RESNET_PARAMS_FILE), 'rb'))

shape_dict = {"data": x.shape}
dtype_dict = {"data": 'float32'}

...
@@ -2,9 +2,16 @@
from __future__ import absolute_import as _abs

import os
import sys

from tvm.contrib.download import download
from .environment import get_env

if sys.version_info >= (3,):
    import urllib.error as urllib2
else:
    import urllib2

# bitstream repo
BITSTREAM_URL = "https://github.com/uwsaml/vta-distro/raw/master/bitstreams/"
@@ -41,15 +48,25 @@ def download_bitstream():
    url = os.path.join(BITSTREAM_URL, env.TARGET)
    url = os.path.join(url, env.HW_VER)
    url = os.path.join(url, env.BITSTREAM)
    # Check that the bitstream is accessible from the server
    try:
        download(url, bit)
    except urllib2.HTTPError as err:
        if err.code == 404:
            # Raise error - the solution when this happens is to build your
            # own bitstream and add it to your $VTA_CACHE_PATH
            raise RuntimeError(
                "{} is not available. It appears that this configuration \
bitstream has not been cached. Please compile your own bitstream (see hardware \
compilation guide to get Xilinx toolchains setup) and add it to your \
$VTA_CACHE_PATH. Alternatively edit your config.json back to its default \
settings. You can see the list of available bitstreams under {}"
                .format(url, BITSTREAM_URL))
        else:
            # This could happen when trying to access the URL behind a proxy
            raise RuntimeError(
                "Something went wrong when trying to access {}. Check your \
internet connection or proxy settings."
                .format(url))

    return success
@@ -15,23 +15,34 @@ def run(run_func):
    """
    env = get_env()

    if env.TARGET == "sim":
        # Talk to local RPC if necessary to debug RPC server.
        # Compile vta on your host with make at the root.
        # Make sure TARGET is set to "sim" in the config.json file.
        # Then launch the RPC server on the host machine
        # with ./apps/pynq_rpc/start_rpc_server.sh
        # Set your VTA_LOCAL_SIM_RPC environment variable to
        # the port it's listening to, e.g. 9090
        local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
        if local_rpc:
            remote = rpc.connect("localhost", local_rpc)
            run_func(env, remote)
        else:
            # Make sure the simulation library exists
            # If this fails, build vta on the host (make)
            # with TARGET set to "sim" in the config.json file.
            assert simulator.enabled()
            run_func(env, rpc.LocalSession())

    elif env.TARGET == "pynq":
        # Run on PYNQ if the env variables are defined
        host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
        port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
        if host and port:
            remote = rpc.connect(host, port)
            run_func(env, remote)
        else:
            raise RuntimeError(
                "Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
@@ -18,7 +18,8 @@ def test_gemm():
        channel // env.BLOCK_OUT,
        env.BATCH,
        env.BLOCK_OUT)
    # To compute the number of ops, use a 2x factor for FMA
    num_ops = 2 * channel * channel * batch_size
    ko = tvm.reduce_axis((0, channel // env.BLOCK_IN), name='ko')
    ki = tvm.reduce_axis((0, env.BLOCK_IN), name='ki')
@@ -157,14 +158,14 @@ def test_gemm():
    def gemm_normal(print_ir):
        mock = env.mock
        print("----- GEMM GOPS End-to-End Test-------")
        def run_test(header, print_ir, check_correctness):
            cost = run_schedule(
                env.dma_copy, env.dma_copy, env.gemm, env.alu, env.dma_copy,
                print_ir, check_correctness)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
        with vta.build_config():
            run_test("NORMAL", print_ir, True)
@@ -177,7 +178,7 @@ def test_gemm():
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
        with vta.build_config():
            run_test("NORMAL", print_ir)
@@ -190,7 +191,7 @@ def test_gemm():
                print_ir, False)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
        with vta.build_config():
            run_test("NORMAL", print_ir)
        print("")
@@ -204,7 +205,7 @@ def test_gemm():
            gops = (num_ops / cost.mean) / float(10 ** 9)
            bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))
        with vta.build_config():
            run_test("NORMAL", print_ir)
@@ -219,7 +220,7 @@ def test_gemm():
            gops = (num_ops / cost.mean) / float(10 ** 9)
            bandwidth = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))
        with vta.build_config():
            run_test("NORMAL", print_ir)
@@ -235,7 +236,7 @@ def test_gemm():
            gops = (num_ops / cost.mean) / float(10 ** 9)
            bandwidth = (batch_size * channel * env.OUT_WIDTH / cost.mean) / float(10 ** 9)
            print(header)
            print("\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits" % (
                cost.mean, gops, bandwidth))
        with vta.build_config():
            run_test("NORMAL", print_ir)

...
@@ -42,6 +42,7 @@ def test_vta_conv2d():
        res = my_clip(res, 0, 127)
        res = topi.cast(res, "int8")

    # To compute the number of ops, use a 2x factor for FMA
    num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
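    # Worked example of the op count (workload values are assumptions): with
    # batch_size=1, a 56x56 output, a 3x3 kernel and 64 input/output filters,
    # num_ops = 2 * 1 * 56 * 56 * 3 * 3 * 64 * 64 ~= 231 Mops.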
    a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
@@ -118,7 +119,7 @@ def test_vta_conv2d():
            print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
        cost = verify(s, True)
        gops = (num_ops / cost.mean) / float(10 ** 9)
        print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))

    conv_normal(False)

...
@@ -91,6 +91,7 @@ elif env.TARGET == "sim":
#
# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/tensor_core.png
#      :align: center
#      :width: 480px
#
# The dimensions of that matrix-matrix multiplication are specified in
# the :code:`config.json` configuration file.
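# As an illustration, the relevant :code:`config.json` entries look like the
# following (a sketch; the field names follow VTA's log2 convention and the
# values shown are assumed defaults):
#
#   "LOG_BATCH": 0,      # BATCH = 2**0 = 1
#   "LOG_BLOCK_IN": 4,   # BLOCK_IN = 2**4 = 16
#   "LOG_BLOCK_OUT": 4,  # BLOCK_OUT = 2**4 = 16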
@@ -109,6 +110,7 @@ elif env.TARGET == "sim":
#
# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/data_tiling.png
#      :align: center
#      :width: 480px
#
# We first define the variables :code:`m`, :code:`n`, :code:`o` to represent
# the shape of the matrix multiplication. These variables are multiplicative
@@ -119,7 +121,6 @@ elif env.TARGET == "sim":
# 1 implies that our compute building block is vector-matrix multiply).
#
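# A hedged sketch of the tiled placeholder shapes this implies (the names and
# the specific values are assumptions, not the tutorial's exact code):
#
#   m, n, o = 1, 16, 16
#   # data: an (o, n) grid of (BATCH, BLOCK_IN) tiles
#   data_shape = (o, n, env.BATCH, env.BLOCK_IN)
#   # weight: an (m, n) grid of (BLOCK_OUT, BLOCK_IN) tiles
#   weight_shape = (m, n, env.BLOCK_OUT, env.BLOCK_IN)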
######################################################################
# .. note::
#

...
@@ -66,7 +66,7 @@ elif env.TARGET == "sim":
# :code:`BATCH`, :code:`BLOCK_IN`, and :code:`BLOCK_OUT` respectively.
#
# We've added extra operators to the matrix multiplication that apply
# shifting and clipping to the output in order to mimic a fixed-point
# matrix multiplication followed by a rectified linear activation.
# We describe the TVM dataflow graph of the fully connected layer below:
#
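# A hedged sketch of those extra stages (the lambda bodies are assumptions;
# the stage names match the schedule code further below):
#
#   res_shr = tvm.compute(output_shape,
#                         lambda *i: res_gemm(*i) >> env.INP_WIDTH,
#                         name="res_shr")
#   res_max = tvm.compute(output_shape,
#                         lambda *i: tvm.max(res_shr(*i), 0),
#                         name="res_max")
#   res_min = tvm.compute(output_shape,
#                         lambda *i: tvm.min(res_max(*i), inp_max),
#                         name="res_min")
#   res = tvm.compute(output_shape,
#                     lambda *i: res_min(*i).astype(env.inp_dtype),
#                     name="res")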
@@ -152,7 +152,7 @@ res = tvm.compute(output_shape,
# Those include:
#
# - Computation blocking
# - Lowering to VTA hardware intrinsics

# Create TVM schedule
@@ -161,8 +161,8 @@ s = tvm.create_schedule(res.op)
print(tvm.lower(s, [data, weight, res], simple_mode=True))

######################################################################
# Blocking the Computation
# ~~~~~~~~~~~~~~~~~~~~~~~~
# The matrix multiplication is by default too large for activations or weights
# to fit on VTA's on-chip buffers all at once.
# We block the (1, 1024) by (1024, 1024) matrix multiplication into
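# A hedged sketch of that blocking step (the block sizes are assumptions; the
# axis names match the schedule code below):
#
#   b, oc, b_tns, oc_tns = s[res].op.axis
#   b_out, b_inn = s[res].split(b, factor=1)
#   oc_out, oc_inn = s[res].split(oc, factor=256 // env.BLOCK_OUT)
#   s[res].reorder(b_out, oc_out, b_inn, oc_inn)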
@@ -180,8 +180,7 @@ print(tvm.lower(s, [data, weight, res], simple_mode=True))
#
# .. image:: https://raw.githubusercontent.com/uwsaml/web-data/master/vta/tutorial/blocking.png
#      :align: center
#      :width: 480px
#
# .. note::
#
@@ -236,7 +235,7 @@ s[res_shr].compute_at(s[res], oc_out)
s[res_max].compute_at(s[res], oc_out)
s[res_min].compute_at(s[res], oc_out)

# Apply additional loop split along reduction axis (input channel)
b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis
ic_out, ic_inn = s[res_gemm].split(ic, i_block)
@@ -273,6 +272,8 @@ s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)
s[weight_buf].pragma(s[weight_buf].op.axis[0], env.dma_copy)

# Use DMA copy pragma on SRAM->DRAM operation
# (this implies that these copies should be performed along b_inn,
# or result axis 2)
s[res].pragma(s[res].op.axis[2], env.dma_copy)

######################################################################
@@ -313,21 +314,21 @@ f = remote.load_module("gemm.o")
# Get the remote device context
ctx = remote.ext_dev(0)

# Initialize the data and weight arrays randomly in the int range [-128, 128)
data_np = np.random.randint(
    -128, 128, size=(batch_size, in_channels)).astype(data.dtype)
weight_np = np.random.randint(
    -128, 128, size=(out_channels, in_channels)).astype(weight.dtype)

# Apply packing to the data and weight arrays from a 2D to a 4D packed layout
data_packed = data_np.reshape(batch_size // env.BATCH,
                              env.BATCH,
                              in_channels // env.BLOCK_IN,
                              env.BLOCK_IN).transpose((0, 2, 1, 3))
weight_packed = weight_np.reshape(out_channels // env.BLOCK_OUT,
                                  env.BLOCK_OUT,
                                  in_channels // env.BLOCK_IN,
                                  env.BLOCK_IN).transpose((0, 2, 1, 3))

# Format the input/output arrays with tvm.nd.array to the DLPack standard
data_nd = tvm.nd.array(data_packed, ctx)
@@ -338,8 +339,8 @@ res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)
f(data_nd, weight_nd, res_nd)

# Verify against numpy implementation
res_ref = np.dot(data_np.astype(env.acc_dtype),
                 weight_np.T.astype(env.acc_dtype))
res_ref = res_ref >> env.INP_WIDTH
res_ref = np.clip(res_ref, 0, inp_max)
res_ref = res_ref.astype(res.dtype)
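# A hedged sketch of the final check (the unpacking transpose/reshape is an
# assumption mirroring the packing above):
#
#   res_unpacked = res_nd.asnumpy().transpose((0, 2, 1, 3)).reshape(
#       batch_size, out_channels)
#   np.testing.assert_equal(res_ref, res_unpacked)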
...