Commit bc410130 by Tianqi Chen

[TOPI] Fix the CPU op perf (#56)

parent ffe1badd
@@ -63,7 +63,6 @@ doc:
clean:
	$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o
	$(RM) -rf cat.jpg quantize_graph.json quantize_params.pkl synset.txt
-include build/*.d
@@ -2,3 +2,4 @@ doxygen
modules
tutorials
_build
gen_modules
\ No newline at end of file
# some standard imports
import nnvm
import tvm
import vta
import vta.testing
import os
import numpy as np
import pickle
import json
import logging
from PIL import Image
from nnvm.compiler import graph_attr
from tvm.contrib import graph_runtime, rpc, util
from tvm.contrib.download import download
bfactor = 1
cfactor = 16
verbose = False
# only run fpga component, mark non-conv ops as nop
debug_fpga_only = False
# Obtain model files (they're too large to check-in)
# Download them into _data dir
data_dir = "_data/"
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
TEST_FILE = 'cat.jpg'
CATEG_FILE = 'synset.txt'
RESNET_GRAPH_FILE = 'resnet18_qt8.json'
RESNET_PARAMS_FILE = 'resnet18_qt8.params'
# Create data dir
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
# Download files
for file in [TEST_FILE, CATEG_FILE, RESNET_GRAPH_FILE, RESNET_PARAMS_FILE]:
    if not os.path.isfile(file):
        download(os.path.join(url, file), os.path.join(data_dir, file))
if verbose:
    logging.basicConfig(level=logging.DEBUG)
# Change to -device=vtacpu to run cpu only inference.
target = tvm.target.create("llvm -device=vta")
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
if vta.get_env().TARGET == "sim":
    target_host = "llvm"
synset = eval(open(os.path.join(data_dir, CATEG_FILE)).read())
image = Image.open(os.path.join(data_dir, TEST_FILE)).resize((224, 224))
def transform_image(image):
    image = np.array(image) - np.array([123., 117., 104.])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return image
def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
    """Helper function to mark certain op as nop

    Useful to debug performance issues.
    """
    jgraph = json.loads(graph.json())
    counter = 0
    for nid, node in enumerate(jgraph["nodes"]):
        op_name = node["op"]
        if op_name != "tvm_op":
            continue
        attrs = node["attrs"]
        node_name = node["name"]
        func_name = attrs["func_name"]
        if func_name.find("conv2d") != -1:
            if conv_layer >= 0:
                if counter != conv_layer:
                    attrs["func_name"] = "__nop"
            if counter in skip_conv_layer:
                attrs["func_name"] = "__nop"
            counter += 1
        else:
            if conv_layer >= 0:
                attrs["func_name"] = "__nop"
            attrs["func_name"] = "__nop"
        if attrs["func_name"] != "__nop":
            print("Run function %s" % func_name)
    graph = nnvm.graph.load_json(json.dumps(jgraph))
    return graph
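# A small debugging sketch (not part of the original commit): list which tvm_op
# functions remain active after mark_nop has rewritten a graph. It only reads the
# same JSON fields ("op", "attrs", "func_name") that mark_nop manipulates above.
def list_active_funcs(graph):
    """Return the func_name of every tvm_op node not marked as "__nop"."""
    jgraph = json.loads(graph.json())
    return [node["attrs"]["func_name"]
            for node in jgraph["nodes"]
            if node["op"] == "tvm_op" and node["attrs"]["func_name"] != "__nop"]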
x = transform_image(image)
print('x', x.shape)
######################################################################
# now compile the graph
import nnvm.compiler
np.random.seed(0)
sym = nnvm.graph.load_json(
    open(os.path.join(data_dir, RESNET_GRAPH_FILE)).read())
params = nnvm.compiler.load_param_dict(
    open(os.path.join(data_dir, RESNET_PARAMS_FILE), 'rb').read())
shape_dict = {"data": x.shape}
dtype_dict = {"data": 'float32'}
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
graph = nnvm.graph.create(sym)
graph_attr.set_shape_inputs(sym, shape_dict)
graph_attr.set_dtype_inputs(sym, dtype_dict)
graph = graph.apply("InferShape").apply("InferType")
dtype = "float32"
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
if target.device_name == "vta":
    sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
with nnvm.compiler.build_config(opt_level=3):
    if target.device_name != "vta":
        graph, lib, params = nnvm.compiler.build(
            sym, target_host, shape_dict, dtype_dict,
            params=params)
    else:
        with vta.build_config():
            graph, lib, params = nnvm.compiler.build(
                sym, target, shape_dict, dtype_dict,
                params=params, target_host=target_host)
assert tvm.module.enabled("rpc")
temp = util.tempdir()
lib.save(temp.relpath("graphlib.o"))
if vta.get_env().TARGET == "sim":
    remote = rpc.LocalSession()
    print("local session")
else:
    host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
    assert host
    port = os.environ.get("VTA_PYNQ_RPC_PORT", "9091")
    port = int(port)
    remote = rpc.connect(host, port)
    # Program FPGA, and build runtime if necessary
    # Overwrite bitstream with a path to your own if you built it yourself
    vta.reconfig_runtime(remote)
    vta.program_fpga(remote, bitstream=None)
remote.upload(temp.relpath("graphlib.o"))
lib = remote.load_module("graphlib.o")
ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
print("Build complete...")
def run_e2e(graph):
    """Run the end-to-end example."""
    if debug_fpga_only:
        graph = mark_nop(graph, skip_conv_layer=(0,))
    m = graph_runtime.create(graph, lib, ctx)
    # set inputs
    m.set_input('data', tvm.nd.array(x.astype("float32")))
    m.set_input(**params)
    # execute
    timer = m.module.time_evaluator("run", ctx, number=10)
    tcost = timer()
    # get outputs
    tvm_output = m.get_output(
        0, tvm.nd.empty((1000,), dtype, remote.cpu(0)))
    top = list(reversed(np.argsort(tvm_output.asnumpy())))
    for i in range(5):
        print('TVM prediction top-%d: %s' % (i, synset[top[i]]))
    print("t-cost=%g" % tcost.mean)
def run_layer(old_graph, layer_begin, layer_end):
    """Run a certain layer."""
    for layer_id in range(layer_begin, layer_end):
        print("run resnet[%d]..." % (layer_id))
        graph = mark_nop(old_graph, layer_id)
        m = graph_runtime.create(graph, lib, ctx)
        # set inputs
        m.set_input('data', tvm.nd.array(x.astype("float32")))
        m.set_input(**params)
        # execute
        timer = m.module.time_evaluator("run", ctx, number=1)
        tcost = timer()
        print("resnet[%d]: %g\n" % (layer_id, tcost.mean))
run_e2e(graph)
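# Optional per-layer profiling sketch (an addition, not in the original script):
# time each conv stage in isolation via run_layer. The stage count of 12 is an
# assumption based on the ResNet18 workload table used elsewhere in this commit.
profile_layers = False
if profile_layers:
    run_layer(graph, 0, 12)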
@@ -5,13 +5,9 @@ Used for CPU conv2d
"""
from __future__ import absolute_import as _abs
import tvm
from topi import tag
from topi.nn.conv2d import conv2d, _get_schedule
from topi.nn.conv2d import SpatialPack, Im2ColPack, Workload
from topi.nn.conv2d import _SCH_TO_DECL_FUNC
from topi.nn.conv2d import _get_workload
from topi.nn.util import infer_pad, infer_stride
from topi.rasp import conv2d as _rasp_conv2d
from topi import generic
_WORKLOADS = [
@@ -52,284 +48,8 @@ def _schedule_conv2d(wkl):
    sch = _SCHEDULES[idx]
    return sch
conv2d.register(["vtacpu", "vta"], _rasp_conv2d._declaration_conv2d)
@conv2d.register(["vtacpu", "vta"])
def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
    assert layout == 'NCHW', "only support NCHW convolution on vtacpu"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu"
    wkl = _get_workload(data, kernel, stride, padding, out_dtype)
    sch = _get_schedule(wkl)
    return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
def _schedule_spatial_conv2d(s, data, data_pad, data_vec,
                             kernel, kernel_vec,
                             conv_out, output, last):
    # no stride and padding info here
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter
    HK, WK = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    HCAT, WCAT = HK-1, WK-1

    DOPAD = (HPAD != 0 and WPAD != 0)

    VH = sch.vh
    VW = sch.vw
    VC = sch.vc
    UNROLL = sch.unroll

    A, B, C = data, kernel, last
    A0, A1 = data_pad, data_vec
    B0 = kernel_vec
    C0, C1 = conv_out, output

    CC = s.cache_write(C0, "global")
    _, co, oh, ow, vh, vw, vc = s[C0].op.axis
    if UNROLL:
        s[C0].unroll(vw)
    s[C0].vectorize(vc)

    s[CC].compute_at(s[C0], ow)
    _, co, oh, ow, vh, vw, vc = s[CC].op.axis
    ci, dh, dw = s[CC].op.reduce_axis
    s[CC].reorder(ci, dh, vh, dw, vw, vc)
    if UNROLL:
        s[CC].unroll(vw)
    s[CC].vectorize(vc)

    ##### Schedule A
    if DOPAD:
        s[A0].compute_inline()
    _, h, _, _, _, _ = s[A1].op.axis
    if sch.ba == 1:
        oaxis = h
        paxis = h
    else:
        oh, ih = s[A1].split(h, sch.ba)
        oaxis = oh
        paxis = ih
    s[A1].parallel(paxis)
    s[A1].pragma(oaxis, "parallel_launch_point")
    s[A1].pragma(paxis, "parallel_stride_pattern")
    s[A1].pragma(oaxis, "parallel_barrier_when_finish")

    ##### Schedule B
    co, _, _, _, _ = s[B0].op.axis
    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[B0].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[B0].parallel(paxis)
    s[B0].pragma(oaxis, "parallel_launch_point")
    s[B0].pragma(paxis, "parallel_stride_pattern")
    s[B0].pragma(oaxis, "parallel_barrier_when_finish")

    ##### Schedule C
    n, co, h, w = s[C].op.axis
    co, vc = s[C].split(co, VC)
    oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
    s[C].reorder(n, co, oh, ow, vh, vw, vc)
    if C != C1:
        s[C1].compute_inline()
    s[C0].compute_at(s[C], ow)

    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[C].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[C].parallel(paxis)
    s[C].pragma(oaxis, "parallel_launch_point")
    s[C].pragma(paxis, "parallel_stride_pattern")
    s[C].pragma(oaxis, "parallel_barrier_when_finish")

    return s
def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
                            kernel, kernel_vec,
                            conv_out, output, last):
    # no stride and padding info here
    padding = infer_pad(data, data_pad)
    if data_pad is None:
        stride = infer_stride(data, kernel, output)
    else:
        stride = infer_stride(data_pad, kernel, output)
    wkl = _get_workload(data, kernel, stride, padding, output.dtype)
    sch = _get_schedule(wkl)

    H, W = wkl.height, wkl.width
    CI = wkl.in_filter
    CO = wkl.out_filter
    HK, WK = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    HCAT, WCAT = HK-1, WK-1

    DOPAD = (HPAD != 0 and WPAD != 0)

    P = sch.vp
    Q = sch.vq
    UNROLL = sch.unroll

    A, B, C = data, kernel, last
    A0, A1, A2 = data_pad, data_col, data_vec
    B0 = kernel_vec
    C0, C1 = conv_out, output

    CC = s.cache_write(C0, "global")
    AA = s.cache_read(A2, "global", [CC])
    BB = s.cache_read(B0, "global", [CC])

    ##### Schedule CC
    _, co, im, vim, vco = s[C0].op.axis
    s[C0].unroll(vim)
    s[C0].vectorize(vco)

    s[CC].compute_at(s[C0], im)
    _, co, im, vim, vco = s[CC].op.axis
    ci, hk, wk = s[CC].op.reduce_axis
    s[CC].reorder(ci, hk, wk, vim, vco)
    s[CC].unroll(vim)
    s[CC].vectorize(vco)
    # s[CC].unroll(ccr)

    ### Schedule C
    _, co, h, w = s[C].op.axis
    im = s[C].fuse(h, w)
    im, vim = s[C].split(im, P)
    co, vco = s[C].split(co, Q)
    s[C].reorder(co, im, vim, vco)

    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[C].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[C].parallel(paxis)
    s[C].pragma(oaxis, "parallel_launch_point")
    s[C].pragma(paxis, "parallel_stride_pattern")
    s[C].pragma(oaxis, "parallel_barrier_when_finish")

    if C1 != C:
        s[C1].compute_inline()
    s[C0].compute_at(s[C], paxis)

    ##### Schedule A
    if DOPAD:
        s[A0].compute_inline()
    s[A1].compute_inline()
    s[AA].compute_at(s[CC], wk)
    s[AA].unroll(AA.op.axis[4])

    _, im, _, _, _, _ = s[A2].op.axis
    if sch.ba == 1:
        oaxis = im
        paxis = im
    else:
        oim, iim = s[A2].split(im, sch.ba)
        oaxis = oim
        paxis = iim
    s[A2].parallel(paxis)
    s[A2].pragma(oaxis, "parallel_launch_point")
    s[A2].pragma(paxis, "parallel_stride_pattern")
    s[A2].pragma(oaxis, "parallel_barrier_when_finish")

    ##### Schedule B
    s[BB].compute_at(s[CC], wk)
    s[BB].vectorize(BB.op.axis[4])

    co, _, _, _, _ = s[B0].op.axis
    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[B0].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[B0].parallel(paxis)
    s[B0].pragma(oaxis, "parallel_launch_point")
    s[B0].pragma(paxis, "parallel_stride_pattern")
    s[B0].pragma(oaxis, "parallel_barrier_when_finish")

    return s
@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
def schedule_conv2d(outs):
    """Create schedule for tensors"""
    s = tvm.create_schedule([x.op for x in outs])

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)

        if 'spatial_conv_output' in op.tag:
            output = op.output(0)
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[1]
            kernel = kernel_vec.op.input_tensors[0]
            data_vec = conv_out.op.input_tensors[0]
            data = data_vec.op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            _schedule_spatial_conv2d(s, data, data_pad, data_vec,
                                     kernel, kernel_vec,
                                     conv_out, output, outs[0])

        if 'im2col_conv_output' in op.tag:
            output = op.output(0)
            conv_out = op.input_tensors[0]
            kernel_vec = conv_out.op.input_tensors[1]
            kernel = kernel_vec.op.input_tensors[0]
            data_vec = conv_out.op.input_tensors[0]
            data_col = data_vec.op.input_tensors[0]
            data = data_col.op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
                                    kernel, kernel_vec,
                                    conv_out, output, outs[0])

    traverse(outs[0].op)
    return s
generic.schedule_conv2d_nchw.register(
    ["vtacpu", "vta"],
    _rasp_conv2d.schedule_conv2d_nchw)
@@ -192,7 +192,7 @@ def _build(funcs, target, target_host):
    tvm_t = tvm.target.create(target)
    if tvm_t.device_name == "vta":
        return tvm.build(funcs, target="ext_dev", target_host=target_host)
-    elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vta-cpu":
+    elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu":
        return tvm.build(funcs, target=target_host)
    return tvm.build(funcs, target=target)
@@ -20,6 +20,114 @@ def my_clip(x, a_min, a_max):
    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
    return x
def test_cpu_conv2d():
    def run_cpu_conv2d(env, remote, key, batch_size, wl, profile=True):
        data_shape = (batch_size, wl.in_filter, wl.height, wl.width)
        kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
        fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
        fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
        data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
        kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
        res_conv = topi.nn.conv2d(
            data, kernel, padding=(wl.hpad, wl.wpad),
            strides=(wl.hstride, wl.wstride),
            out_dtype="int32")
        res = topi.right_shift(res_conv, 8)
        res = my_clip(res, 0, 127)
        res = topi.cast(res, "int8")

        # To compute number of ops, use a x2 factor for FMA
        num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter

        a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
        w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
        stride = (wl.hstride, wl.wstride)
        data_dtype = data.dtype
        kernel_dtype = kernel.dtype
        acc_dtype = env.acc_dtype
        assert wl.hpad == wl.wpad
        padding = wl.hpad

        @memoize("vta.tests.test_benchmark_topi.conv2d.cpu.verify_nhwc")
        def get_ref_data():
            a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
            w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype)
            a_np = np.abs(a_np)
            w_np = np.abs(w_np)
            b_np = topi.testing.conv2d_nchw_python(
                a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
            return a_np, w_np, b_np

        def verify(s, check_correctness):
            mod = tvm.build(s, [data, kernel, res],
                            "llvm -device=vtacpu",
                            env.target_host,
                            name="conv2d")
            temp = util.tempdir()
            mod.save(temp.relpath("conv2d.o"))
            remote.upload(temp.relpath("conv2d.o"))
            f = remote.load_module("conv2d.o")
            # verify
            ctx = remote.cpu(0)
            # Data in original format
            data_orig, kernel_orig, res_ref = get_ref_data()
            res_shape = topi.util.get_const_tuple(res.shape)
            res_np = np.zeros(res_shape).astype(res.dtype)
            data_arr = tvm.nd.array(data_orig, ctx)
            kernel_arr = tvm.nd.array(kernel_orig, ctx)
            res_arr = tvm.nd.array(res_np, ctx)
            time_f = f.time_evaluator("conv2d", ctx, number=5)
            cost = time_f(data_arr, kernel_arr, res_arr)
            res_unpack = res_arr.asnumpy()
            if check_correctness:
                assert wl.hpad == wl.wpad
                stride = (wl.hstride, wl.wstride)
                padding = wl.hpad
                res_ref = res_ref >> 8
                res_ref = np.clip(res_ref, 0, 127).astype("int8")
                np.testing.assert_allclose(res_unpack, res_ref)
            return cost

        def conv_normal(print_ir):
            print("----- CONV2D CPU End-to-End Test-------")
            s = topi.generic.schedule_conv2d_nchw([res])
            if print_ir:
                print(tvm.lower(s, [data, kernel, res], simple_mode=True))
            cost = verify(s, True)
            gops = (num_ops / cost.mean) / float(10 ** 9)
            print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))

        conv_normal(False)

    def _run(env, remote):
        # ResNet18 workloads
        resnet = {
            # Workloads of resnet18 on imagenet
            0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
            1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
            2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
            3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
            4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
            5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
            6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
            7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
            8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
            9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
            10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
            11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
        }
        batch_size = 1

        for i in range(1, len(resnet)):
            wl = resnet[i]
            key = "resnet-cfg[%d]" % i
            print("key=%s" % key)
            print(wl)
            with tvm.target.create("llvm -device=vtacpu"):
                run_cpu_conv2d(env, remote, key, batch_size, wl)

    vta.testing.run(_run)
def test_vta_conv2d():
    def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
@@ -54,7 +162,7 @@ def test_vta_conv2d():
        assert wl.hpad == wl.wpad
        padding = wl.hpad

-        @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
+        @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
        def get_ref_data():
            a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
            w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype)
@@ -153,4 +261,6 @@ def test_vta_conv2d():
if __name__ == "__main__":
    test_cpu_conv2d()
    exit(0)
    test_vta_conv2d()
@@ -116,8 +116,8 @@ def generate_graph(graph_fn, params_fn, device="vta"):
    with nnvm.compiler.build_config(opt_level=3):
        if target.device_name != "vta":
            graph, lib, params = nnvm.compiler.build(
-                sym, target_host, shape_dict, dtype_dict,
-                params=params)
+                sym, target, shape_dict, dtype_dict,
+                params=params, target_host=target_host)
        else:
            with vta.build_config():
                graph, lib, params = nnvm.compiler.build(