Commit 6c62dac3 by Tianqi Chen

[TVM] Upgrade TVM Support

parent 3ae9e155
......@@ -59,31 +59,7 @@ In the 'config.mk' file, make sure that:
For the *Python Package Installation*, we recommend updating your `~/.bashrc` file to extend your `PYTHONPATH` with the TVM Python libraries.
```bash
export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:${PYTHONPATH}
```
#### NNVM Installation
Clone the NNVM repository from `tqchen` into a directory of your choosing:
```bash
git clone git@github.com:tqchen/nnvm.git --recursive
```
To run this example, we rely on a special branch of NNVM, `qt`:
```bash
cd <nnvm root>
git checkout qt
```
Launch the compilation; this takes about a minute on two threads.
```bash
cd <nnvm root>
make -j2
```
Finally, update your `~/.bashrc` file to include the NNVM Python libraries in your `PYTHONPATH`:
```bash
export PYTHONPATH=<nnvm root>/python:${PYTHONPATH}
export PYTHONPATH=<tvm root>/python:<tvm root>/topi/python:<tvm root>/nnvm/python:${PYTHONPATH}
```
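As a quick sanity check (a minimal sketch, assuming the exports above are active in your current shell), the TVM and NNVM Python packages should now be importable:
```python
# Quick import check: these only succeed if PYTHONPATH includes the TVM, TOPI,
# and NNVM Python packages exported above.
import tvm
import topi
import nnvm

print("tvm, topi and nnvm imported successfully")
```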
#### MXNet Installation
......@@ -236,7 +212,7 @@ This time again, we will run the 2D convolution testbench. But beforehand, we'll
* Runtime building on the Pynq, which needs to be run every time the `config.json` configuration is modified. This ensures that the VTA software runtime that generates the accelerator's executable via just-in-time (JIT) compilation matches the specifications of the VTA design that is programmed on the FPGA. The build process takes about 30 seconds to complete.
```bash
python tests/python/pynq/test_program_rpc.py
```
> Tip: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq `ssh` session.
......@@ -244,7 +220,7 @@ python tests/python/pynq/test_program_rpc.py
We are now ready to run the 2D convolution testbench for the ResNet-18 workload in hardware.
```bash
python tests/python/pynq/test_benchmark_conv2d.py
```
The performance metrics measured on the Pynq board will be reported for each convolutional layer.
......@@ -280,7 +256,7 @@ You’ll need to install Xilinx’ FPGA compilation toolchain, [Vivado HL WebPAC
```bash
chmod u+x Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin
```
5. Now you can execute the binary:
```bash
./Xilinx_Vivado_SDK_2017.1_0415_1_Lin64.bin
```
......@@ -337,7 +313,7 @@ If you just want to generate the HLS-based VTA IP cores without launching the en
make ip
```
You'll be able to view the HLS synthesis reports under `<vta root>/build/hardware/xilinx/hls/<configuration>/<block>/solution0/syn/report/<block>_csynth.rpt`
> Note: The `<configuration>` name is a string that summarizes the VTA configuration parameters specified in the `config.json`. The `<block>` name refers to the specific module in the VTA pipeline.
Finally, to run the full hardware compilation and generate the bitstream, run:
......
......@@ -26,8 +26,8 @@ data_dir = "_data/"
url = "https://homes.cs.washington.edu/~moreau/media/vta/"
TEST_FILE = 'cat.jpg'
CATEG_FILE = 'synset.txt'
RESNET_GRAPH_FILE = 'quantize_graph.json'
RESNET_PARAMS_FILE = 'quantize_params.pkl'
RESNET_GRAPH_FILE = 'resnet18_qt8.json'
RESNET_PARAMS_FILE = 'resnet18_qt8_params.pkl'
# Create data dir
if not os.path.exists(data_dir):
os.makedirs(data_dir)
......@@ -70,7 +70,7 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
attrs = node["attrs"]
node_name = node["name"]
func_name = attrs["func_name"]
if func_name.find("quantized_conv2d") != -1:
if func_name.find("conv2d") != -1:
if conv_layer >= 0:
if counter != conv_layer:
attrs["func_name"] = "__nop"
......@@ -109,9 +109,9 @@ graph_attr.set_dtype_inputs(sym, dtype_dict)
graph = graph.apply("InferShape").apply("InferType")
dtype = "float32"
sym = vta.graph.remove_stochastic(sym)
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
if target.device_name == "vta":
sym = vta.graph.pack(sym, shape_dict, bfactor, cfactor)
......@@ -166,8 +166,10 @@ def run_e2e(graph):
# get outputs
tvm_output = m.get_output(
0,tvm.nd.empty((1000,), dtype, remote.cpu(0)))
top1 = np.argmax(tvm_output.asnumpy())
print('TVM prediction top-1:', top1, synset[top1])
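# np.argsort returns class indices ordered by ascending score; reversing that
# order puts the highest-scoring class first, so top[0..4] are the top-5 predictions.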
top = list(reversed(np.argsort(tvm_output.asnumpy())))
for i in range(5):
print('TVM prediction top-%d: %s' % (i, synset[top[i]]))
print("t-cost=%g" % tcost.mean)
......
......@@ -71,48 +71,6 @@ def _get_shape(sym, shape_dict):
return graph_util.infer_shape(
nnvm.graph.create(sym), **shape_dict)[1][0]
def remove_stochastic(graph):
"""
Replace stochastic rounding and shift with a deterministic version.
Parameters
----------
graph : Graph
The input graph
Returns
-------
replaced_graph : Graph
The final replaced graph.
"""
gidx = graph.index
node_map = {}
for nid, node in enumerate(gidx.nodes):
children = [node_map[e[0]] for e in node["inputs"]]
attrs = node.get("attrs", {})
node_name = node["name"]
op_name = node["op"]
get_clone = lambda c, o_n, n_n, a: getattr(nnvm.symbol, o_n)(
*c, name=n_n, **a)
if op_name == "null":
new_node = nnvm.symbol.Variable(node_name)
elif op_name == "stochastic_round":
new_node = children[0]
elif op_name == "noise_lshift":
new_node = nnvm.symbol.left_shift(
children[0], **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
node_map[nid] = new_node
assert len(graph.index.output_entries) == 1
ret = node_map[graph.index.output_entries[0][0]]
ret = nnvm.graph.create(ret)
return ret
def clean_conv_fuse(graph):
"""Cleanup the convolution's later fuse stages
......@@ -131,8 +89,8 @@ def clean_conv_fuse(graph):
if flag:
node = nnvm.symbol.clip(node, a_max=127, a_min=-127)
node = nnvm.symbol.cast(node, dtype="int8")
# Use identity as a hint to block conv2d schedules
node = nnvm.symbol.identity(node)
# Use copy as a hint to block conv2d schedules
node = nnvm.symbol.copy(node)
flag = False
return node, flag
......@@ -166,13 +124,13 @@ def clean_conv_fuse(graph):
new_entry = (
get_clone([children[0][0]], op_name, node_name, attrs),
False)
elif op_name == "quantized_conv2d":
elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
data, weight = children
data = _clean_entry(data)
new_node = nnvm.sym.quantized_conv2d(
new_node = nnvm.sym.conv2d(
data[0], weight[0], name=node_name, **attrs)
new_entry = (new_node, True)
elif op_name in ("left_shift", "right_shift", "relu"):
elif op_name in ("__lshift_scalar__", "__rshift_scalar__", "relu"):
new_entry = (
get_clone([children[0][0]], op_name, node_name, attrs),
children[0][1])
......@@ -199,7 +157,6 @@ def clean_conv_fuse(graph):
ret = nnvm.graph.create(ret)
return ret
def clean_cast(graph):
"""
Move the casts to early part of graph,
......@@ -232,11 +189,11 @@ def clean_cast(graph):
elif op_name == "cast":
dtype = attrs["dtype"]
new_node, _ = _clean_cast(children[0], dtype)
elif op_name == "quantized_conv2d":
elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
data, weight = children
data, _ = _clean_cast(data, "int8")
weight, _ = _clean_cast(weight, "int8")
new_node = nnvm.sym.quantized_conv2d(
new_node = nnvm.sym.conv2d(
data, weight, name=node_name, **attrs)
elif op_name == "elemwise_add":
lhs, rhs = children
......@@ -314,21 +271,21 @@ def pack(graph, shape_dict, bfactor, cfactor, start_name=None):
*children, name=node_name, **attrs)
else:
new_node = get_clone(children, op_name, node_name, attrs)
elif op_name == "quantized_conv2d":
elif op_name == "conv2d" and attrs["out_dtype"] == "int32":
if start_pack:
attrs["pack_batch"] = str(bfactor)
attrs["pack_channel"] = str(cfactor)
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "OIHW%do%di" % (cfactor, cfactor)
data, weight = children
weight = _pack_weight(weight, ishape[1], cfactor)
new_node = nnvm.sym.quantized_conv2d(
new_node = nnvm.sym.conv2d(
data, weight, name=node_name, **attrs)
elif counter == 1:
attrs["pack_batch"] = str(bfactor)
attrs["pack_channel"] = str(cfactor)
attrs["layout"] = "NCHW%dn%dc" % (bfactor, cfactor)
attrs["kernel_layout"] = "OIHW%do%di" % (cfactor, cfactor)
data, weight = children
data = _pack_batch_channel(data, ishape[0], bfactor, cfactor)
weight = _pack_weight(weight, ishape[1], cfactor)
new_node = nnvm.sym.quantized_conv2d(
new_node = nnvm.sym.conv2d(
data, weight, name=node_name, **attrs)
new_node = _unpack_batch_channel(new_node, oshape)
counter = counter + 1
......
......@@ -215,7 +215,7 @@ def _lower(sch, inputs, func_name, graph):
f, (tvm.container.Array, tuple, list)) else [f]
@reg.register_compute("clip", level=11)
@reg.register_compute("clip", level=15)
def compute_clip(attrs, inputs, _):
""" Clip operator.
"""
......@@ -231,11 +231,24 @@ def compute_clip(attrs, inputs, _):
x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x
# override to force partition at copy
reg.register_pattern("copy", OpPattern.INJECTIVE, level=15)
reg.register_pattern("identity", OpPattern.INJECTIVE, level=11)
def is_packed_layout(layout):
"""Check if layout is packed layout"""
if layout == "NCHW":
return False
assert "n" in layout
assert "c" in layout
return True
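# e.g. is_packed_layout("NCHW") returns False, while a layout produced by the
# graph packing pass, such as "NCHW1n16c" (the "NCHW%dn%dc" form), returns True.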
@reg.register_compute("quantized_conv2d", level=11)
def compute_quantized_conv2d(attrs, inputs, out):
@reg.register_alter_op_layout("conv2d", level=15)
def alter_conv2d_layout(*_):
return None
@reg.register_compute("conv2d", level=15)
def compute_conv2d(attrs, inputs, out):
""" 2D convolution algorithm.
"""
padding = attrs.get_int_tuple("padding")
......@@ -244,36 +257,30 @@ def compute_quantized_conv2d(attrs, inputs, out):
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs["layout"]
out_dtype = attrs['out_type']
cmp_dtype = 'int32' # compute data type
assert layout == "NCHW", "only support nchw for now"
out_dtype = attrs['out_dtype']
assert dilation == (1, 1), "not support dilate now"
assert attrs.get_bool("use_bias") is False
pack_channel = attrs.get_int("pack_channel")
if pack_channel != 0:
if is_packed_layout(layout):
assert groups == 1
return packed_conv2d(inputs[0], inputs[1],
padding, strides)
padding, strides, out_dtype=out_dtype)
if groups == 1:
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype)
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, out_dtype=out_dtype)
elif groups == get_const_int(inputs[0].shape[1]) and groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, out_dtype=cmp_dtype)
inputs[0], inputs[1], strides, padding, out_dtype=out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
assert out_dtype == cmp_dtype
return out
@reg.register_schedule("quantized_conv2d", level=11)
@reg.register_schedule("conv2d", level=15)
def schedule_quantized_conv2d(attrs, outs, target):
""" 2D convolution schedule.
"""
channels = attrs.get_int("channels")
pack_channel = attrs.get_int("pack_channel")
if channels != 0 and pack_channel:
layout = attrs["layout"]
if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d(outs)
......
......@@ -38,7 +38,7 @@ def test_vta_conv2d():
res_conv = vta.top.packed_conv2d(
data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
res = topi.right_shift(res_conv, 8)
res = topi.broadcast_add(res, bias)
res = topi.add(res, bias)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
......