Commit 32076df8 by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM] TOPI integration for ARM CPU (#1487)

parent b625b992
......@@ -188,3 +188,6 @@ build*
# Jetbrain
.idea
# tmp file
.nfs*
# Performance Benchmark
## Results
See results on wiki page https://github.com/dmlc/tvm/wiki/Benchmark
## How to Reproduce
### ARM CPU
We use RPC infrastructure in TVM to make device management easy. So you need to use it for reproducing benchmark results.
1. Start an RPC Tracker on the host machine
```bash
python3 -m tvm.exec.rpc_tracker
```
2. Register devices to the tracker
* For Linux device
* Build tvm runtime on your device [Help](https://docs.tvm.ai/tutorials/nnvm/deploy_model_on_rasp.html#build-tvm-runtime-on-device)
* Register your device to tracker by
```bash
python3 -m tvm.exec.rpc_sever --tracker=[HOST_IP]:9190 --key=[DEVICE_KEY]
```
replace `[HOST_IP]` with the IP address of the host machine, `[DEVICE_KEY]` with the name of device.
E.g. Here is an example command for RK3399,
`python3 -m tvm.exec.rpc_sever --tracker=10.77.1.123:9190 --key=rk3399`, where 10.77.1.123 is the IP address of the tracker.
* For Android device
* Build and install tvm RPC apk on your device [Help](https://github.com/dmlc/tvm/tree/master/apps/android_rpc).
Make sure you can pass the android rpc test. Then you have alreadly known how to register.
3. Verify the device registration
We can query all registered devices by
```bash
python3 -m tvm.exec.query_rpc_tracker
```
You should be able to find your devices in `Queue Status`. Make sure the registration is correct before going ahead.
For our test environment, one sample output can be
```bash
Queue Status
------------------------------
key free pending
------------------------------
mate10pro 1 0
p20pro 2 0
pixel2 2 0
rk3399 2 0
rasp3b 8 0
```
4. Run benchmark
We did auto-tuning for Huawei P20/Mate10 Pro, Google Pixel2, Raspberry Pi3 and Firefly-RK3399,
and release pre-tuned parameters in [this repo](https://github.com/uwsaml/tvm-distro).
During compilation, TVM will download these operator parameters automatically.
```bash
python3 arm_cpu_imagenet_bench.py --device rasp3b --rpc-key rasp3b
python3 arm_cpu_imagenet_bench.py --device rk3399 --rpc-key rk3399
python3 arm_cpu_imagenet_bench.py --device pixel2 --rpc-key pixel2
python3 arm_cpu_imagenet_bench.py --device p20pro --rpc-key p20pro
python3 arm_cpu_imagenet_bench.py --device mate10pro --rpc-key mate10pro
```
If your device has a same SoC of the above device, you can reuse these parameters
(e.g. use `llvm -device=arm_cpu -mode=rk3399 -target=aarch64-linux-gnu` as target).
Otherwise, you need to tune for your own device, please follow this
[tutorial](https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_arm.html).
"""Benchmark script for performance on ARM CPU.
see README.md for the usage and results of this script.
"""
import argparse
import time
import numpy as np
import nnvm.testing
import nnvm.compiler
import tvm
from tvm import autotvm
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if name == 'resnet-18':
net, params = nnvm.testing.resnet.get_workload(num_layers=18,
batch_size=batch_size, image_shape=(3, 224, 224))
elif name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'squeezenet v1.1':
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size,
version='1.1')
elif name == 'vgg-16':
net, params = nnvm.testing.vgg.get_workload(batch_size=batch_size, num_layers=16)
else:
raise RuntimeError("Unsupported network: " + name)
return net, params, input_shape, output_shape
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices=['resnet-18', 'mobilenet', 'squeezenet v1.1', 'vgg-16'])
parser.add_argument("--device", type=str, required=True, choices=['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
'pixel2', 'rasp3b', 'pynq'])
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=9190)
parser.add_argument("--rpc-key", type=str, required=True)
parser.add_argument("--number", type=int, default=6)
args = parser.parse_args()
dtype = 'float32'
if args.network is None:
networks = ['squeezenet v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
else:
networks = [args.network]
target = tvm.target.arm_cpu(model=args.device)
# connect to remote device
tracker = tvm.rpc.connect_tracker(args.host, args.port)
remote = tracker.request(args.rpc_key)
print("--------------------------------------------------")
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
print("--------------------------------------------------")
for network in networks:
net, params, input_shape, output_shape = get_network(network, batch_size=1)
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
tmp = tempdir()
if 'android' in str(target):
from tvm.contrib import ndk
filename = "%s.so" % network
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "%s.tar" % network
lib.export_library(tmp.relpath(filename))
# upload library and params
ctx = remote.context(str(target), 0)
remote.upload(tmp.relpath(filename))
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
rlib = remote.load_module(filename)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
# evaluate
ftimer = module.module.time_evaluator("run", ctx, number=args.number, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
""" Benchmark script for performance on Raspberry Pi. For example, run the file with:
`python rasp_imagenet_bench.py --model='modbilenet' --host='rasp0' --port=9090`. For
more details about how to set up the inference environment on Raspberry Pi, Please
refer to NNVM Tutorial: Deploy the Pretrained Model on Raspberry Pi """
import time
import argparse
import numpy as np
import tvm
import nnvm.compiler
import nnvm.testing
from tvm.contrib import util, rpc
from tvm.contrib import graph_runtime as runtime
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, choices=['resnet', 'mobilenet'],
help="The model type.")
parser.add_argument('--host', type=str, required=True, help="The host address of your Raspberry Pi.")
parser.add_argument('--port', type=int, required=True, help="The port number of your Raspberry Pi.")
parser.add_argument('--opt-level', type=int, default=1, help="Level of optimization.")
parser.add_argument('--num-iter', type=int, default=50, help="Number of iteration during benchmark.")
args = parser.parse_args()
opt_level = args.opt_level
num_iter = args.num_iter
batch_size = 1
num_classes = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
if args.model == 'resnet':
net, params = nnvm.testing.resnet.get_workload(
batch_size=1, image_shape=image_shape)
elif args.model == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(
batch_size=1, image_shape=image_shape)
else:
raise ValueError('no benchmark prepared for {}.'.format(args.model))
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, tvm.target.rasp(), shape={"data": data_shape}, params=params)
tmp = util.tempdir()
lib_fname = tmp.relpath('net.o')
lib.save(lib_fname)
remote = rpc.connect(args.host, args.port)
remote.upload(lib_fname)
ctx = remote.cpu(0)
rlib = remote.load_module('net.o')
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
module = runtime.create(graph, rlib, ctx)
module.set_input('data', tvm.nd.array(np.random.uniform(size=(data_shape)).astype("float32")))
module.set_input(**rparams)
module.run()
out = module.get_output(0, tvm.nd.empty(out_shape, ctx=ctx))
out.asnumpy()
print('benchmark args: {}'.format(args))
ftimer = module.module.time_evaluator("run", ctx, num_iter)
for i in range(3):
prof_res = ftimer()
print(prof_res)
# sleep for avoiding cpu overheat
time.sleep(45)
if __name__ == '__main__':
main()
......@@ -44,6 +44,9 @@ tvm.autotvm.tuner
.. automodule:: tvm.autotvm.tuner.callback
:members:
.. automodule:: tvm.autotvm.tuner.graph_tuning
:members:
tvm.autotvm.task
~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.task
......@@ -55,6 +58,15 @@ tvm.autotvm.task
.. automodule:: tvm.autotvm.task.space
:members:
.. automodule:: tvm.autotvm.task.dispatcher
:members:
.. automodule:: tvm.autotvm.task.topi_integration
:members:
.. automodule:: tvm.autotvm.task.nnvm_integration
:members:
tvm.autotvm.record
~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.record
......
......@@ -60,6 +60,8 @@ The configuration of tvm can be modified by `config.cmake`.
- Edit ``build/config.cmake`` to customize the compilation options
- On macOS, for some versions of XCode, you need to add ``-lc++abi`` in the LDFLAGS or you'll get link errors.
- Change ``set(USE_CUDA OFF)`` to ``set(USE_CUDA ON)`` to enable CUDA backend. So do other backends and libraries
(OpenCL, RCOM, METAL, VULKAN, ...).
- TVM optionally depends on LLVM. LLVM is required for CPU codegen that needs LLVM.
......@@ -84,7 +86,7 @@ The configuration of tvm can be modified by `config.cmake`.
cmake ..
make -j4
If everything goes well, we can go to :ref:`python-package-installation`_
If everything goes well, we can go to :ref:`python-package-installation`
Building on Windows
~~~~~~~~~~~~~~~~~~~
......
......@@ -172,6 +172,77 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {
static const constexpr int kBias = 2;
};
struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam> {
int tile_size;
DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) {
DMLC_DECLARE_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
}
static const constexpr int kWeight = 0;
};
struct WinogradConv2DParam : public dmlc::Parameter<WinogradConv2DParam> {
int channels;
TShape kernel_size;
TShape strides;
TShape padding;
TShape dilation;
int groups;
std::string layout;
std::string kernel_layout;
std::string out_layout;
int out_dtype;
bool use_bias;
int tile_size;
DMLC_DECLARE_PARAMETER(WinogradConv2DParam) {
DMLC_DECLARE_FIELD(channels)
.describe("The dimensionality of the output space"
"i.e. the number of output channels in the convolution.");
DMLC_DECLARE_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.");
DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1}))
.describe("Specifies the strides of the convolution.");
DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
DMLC_DECLARE_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
DMLC_DECLARE_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
DMLC_DECLARE_FIELD(out_layout).set_default("__undef__")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW")
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
DMLC_DECLARE_DTYPE_FIELD(out_dtype)
.add_enum("same", -1)
.set_default(-1)
.describe("Output data type, set to explicit type under mixed precision setting");
DMLC_DECLARE_FIELD(use_bias).set_default(true)
.describe("Whether the layer uses a bias vector.");
DMLC_DECLARE_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
}
// constants
static const constexpr int kData = 0;
static const constexpr int kWeight = 1;
static const constexpr int kBias = 2;
};
struct Conv2DTransposeParam : public dmlc::Parameter<Conv2DTransposeParam> {
int channels;
......
......@@ -6,6 +6,7 @@ import logging
import tvm
from tvm.contrib import graph_runtime
from tvm import autotvm
from . import graph_attr, graph_util
from .. import graph as _graph
from .. import symbol as sym
......@@ -238,67 +239,74 @@ def build(graph, target=None, shape=None, dtype="float32",
raise ValueError("Target is not set in env or passed as argument.")
target = tvm.target.create(target)
shape = shape if shape else {}
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")
for value in shape.values():
if not all(isinstance(x, int) for x in value):
raise TypeError("shape value must be int iterator")
cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# correct layout if necessary
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply("CorrectLayout")
index = graph.index
layouts = graph.json_attr("layout")
layout = {x : layouts[index.entry_id(x)] for x in index.input_names}
# Initial pass do shape type inference
ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape))
if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype))
# Initialize all variables specified in _all_var_init
init_var = {}
if _all_var_init:
init_var = initialize_variables(shape, dtype)
# Apply optimization
with target:
graph = optimize(graph, shape, dtype, layout)
# Clear extra params without nodes.
_remove_noref_params(params, graph)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", str(target), "str")
if target_host is not None:
graph._set_json_attr("target_host", str(target_host), "str")
if cfg.pass_enabled("OpFusion"):
graph._set_json_attr("opt_level", 1, "int")
# if not inside an autotvm config dispatch context, load pre-tuned parameters from TopHub
if autotvm.task.DispatchContext.current is None:
tophub_context = autotvm.tophub.context(target)
else:
graph._set_json_attr("opt_level", 0, "int")
graph = graph.apply("InferShape").apply("InferType")
with target:
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
# Write variable initial values into params
if init_var:
if params is None:
params = {}
params.update(init_var)
return graph, libmod, params
tophub_context = autotvm.util.EmptyContext()
with tophub_context:
shape = shape if shape else {}
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")
for value in shape.values():
if not all(isinstance(x, int) for x in value):
raise TypeError("shape value must be int iterator")
cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# correct layout if necessary
layout = layout if layout else {}
graph = graph_attr.set_layout_inputs(graph, layout)
graph = graph.apply("CorrectLayout")
index = graph.index
layouts = graph.json_attr("layout")
layout = {x: layouts[index.entry_id(x)] for x in index.input_names}
# Initial pass do shape type inference
ishape, _ = graph_util.infer_shape(graph, **shape)
shape.update(zip(graph.index.input_names, ishape))
if not isinstance(dtype, str):
idtype, _ = graph_util.infer_dtype(graph, **dtype)
dtype.update(zip(graph.index.input_names, idtype))
# Initialize all variables specified in _all_var_init
init_var = {}
if _all_var_init:
init_var = initialize_variables(shape, dtype)
# Apply optimization
with target:
graph = optimize(graph, shape, dtype, layout)
# Clear extra params without nodes.
_remove_noref_params(params, graph)
# Precompute prune
if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply("InferShape")
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", str(target), "str")
if target_host is not None:
graph._set_json_attr("target_host", str(target_host), "str")
if cfg.pass_enabled("OpFusion"):
graph._set_json_attr("opt_level", 1, "int")
else:
graph._set_json_attr("opt_level", 0, "int")
graph = graph.apply("InferShape").apply("InferType")
with target:
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
# Write variable initial values into params
if init_var:
if params is None:
params = {}
params.update(init_var)
return graph, libmod, params
def _remove_noref_params(params, graph):
""" Helper to clear non referenced params
......
......@@ -89,7 +89,7 @@ def compute_conv2d(attrs, inputs, _):
layout = attrs["layout"]
kernel_layout = attrs["kernel_layout"]
out_dtype = attrs["out_dtype"]
out_dtype = None if out_dtype == "same" else out_dtype
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert layout == "NCHW" or layout == "NHWC"
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
......@@ -196,6 +196,53 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("_contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, _):
return topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
@reg.register_schedule("_contrib_conv2d_winograd_weight_transform")
def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
reg.register_pattern("_contrib_conv2d_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("_contrib_conv2d_winograd_without_weight_transform")
def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, _):
"""Compute definition of conv2d NCHWc"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_string("layout")
out_dtype = attrs.get_string("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, layout, out_dtype,
tile_size)
if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out
@reg.register_schedule("_contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
reg.register_pattern("_contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d_transpose
@reg.register_compute("conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, _):
......
......@@ -130,11 +130,110 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
inline bool WinogradConv2DInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const WinogradConv2DParam& param = nnvm::get<WinogradConv2DParam>(attrs.parsed);
const Layout in_layout(param.layout);
const Layout kernel_layout(param.kernel_layout);
CHECK(in_layout.convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW))
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param.out_layout);
if (!out_layout.defined()) out_layout = in_layout;
CHECK(out_layout.convertible(kNCHW))
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
if (param.use_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
}
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false;
dshape = ConvertLayout(dshape, in_layout, kNCHW);
CHECK_EQ(dshape.ndim(), 4U) << "Input data should be 4D";
CHECK_EQ(param.kernel_size.ndim(), 2U);
CHECK_EQ(param.strides.ndim(), 2U)
<< "incorrect stride size: " << param.strides;
CHECK_EQ(param.dilation.ndim(), 2U)
<< "incorrect dilate size: " << param.dilation;
CHECK_EQ(dshape[1] % param.groups, 0U)
<< "input channels must divide group size";
CHECK_EQ(param.channels % param.groups, 0U)
<< "output channels must divide group size";
// NOTE: Do not check weight shape here!
// Different backend requires different layout to compute
// the batch gemm stage in winograd efficiently, but we want to
// make this NNVM symbol work for all backends.
// So we accept all weight shapes, and assume the TOPI developers
// can handle this correctly in alter_op_layout.
if (param.use_bias) {
static const Layout default_bias_layout("C");
TShape bias_shape({param.channels});
auto oc_block = out_layout.subsizeof('C');
if (oc_block > 0) {
size_t split_axis = (out_layout.indexof('C') < out_layout.indexof('c')) ? 1 : 0;
bias_shape = ConvertLayout(bias_shape, default_bias_layout,
default_bias_layout.split('C', split_axis, oc_block));
}
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, WinogradConv2DParam::kBias, bias_shape);
}
// dilation
dim_t dilated_ksize_y = 1 + (param.kernel_size[0] - 1) * param.dilation[0];
dim_t dilated_ksize_x = 1 + (param.kernel_size[1] - 1) * param.dilation[1];
TShape oshape({dshape[0], param.channels, 0, 0});
if (dshape[2] != 0) {
oshape[2] = (dshape[2] + param.padding[0] * 2 - dilated_ksize_y) / param.strides[0] + 1;
}
if (dshape[3] != 0) {
oshape[3] = (dshape[3] + param.padding[1] * 2 - dilated_ksize_x) / param.strides[1] + 1;
}
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, ConvertLayout(oshape, kNCHW, out_layout));
// Perform incomplete shape inference. Fill in the missing values in data shape.
// 1) We can always fill in the batch_size.
// 2) We can back-calculate the input height/width if the corresponding stride is 1.
oshape = ConvertLayout((*out_shape)[0], out_layout, kNCHW);
dshape[0] = oshape[0];
if (oshape[2] && param.strides[0] == 1) {
dshape[2] = oshape[2] + dilated_ksize_y - 1 - 2 * param.padding[0];
}
if (oshape[3] && param.strides[1] == 1) {
dshape[3] = oshape[3] + dilated_ksize_x - 1 - 2 * param.padding[1];
}
NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, WinogradConv2DParam::kData,
ConvertLayout(dshape, kNCHW, in_layout));
// Check whether the kernel sizes are valid
if (dshape[2] != 0) {
CHECK_LE(dilated_ksize_y, dshape[2] + 2 * param.padding[0])
<< "kernel size exceed input";
}
if (dshape[3] != 0) {
CHECK_LE(dilated_ksize_x, dshape[3] + 2 * param.padding[1])
<< "kernel size exceed input";
}
return true;
}
template <typename PARAM>
inline bool Conv2DInferType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
const PARAM& param = nnvm::get<PARAM>(attrs.parsed);
if (param.use_bias) {
CHECK_EQ(in_type->size(), 3U) << "Input:[data, weight, bias]";
} else {
......@@ -154,11 +253,12 @@ inline bool Conv2DInferType(const nnvm::NodeAttrs& attrs,
}
template<typename PARAM>
inline bool Conv2DCorrectLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
const PARAM& param = nnvm::get<PARAM>(attrs.parsed);
const Layout in_layout(param.layout);
Layout out_layout(param.out_layout);
......@@ -213,8 +313,8 @@ a bias vector is created and added to the outputs.
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", Conv2DInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
.set_attr<FInferType>("FInferType", Conv2DInferType<Conv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<Conv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2)
......@@ -238,12 +338,81 @@ NNVM_REGISTER_OP(_contrib_conv2d_NCHWc)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", Conv2DInferShape)
.set_attr<FInferType>("FInferType", Conv2DInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout)
.set_attr<FInferType>("FInferType", Conv2DInferType<Conv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<Conv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(2);
NNVM_REGISTER_OP(_contrib_conv2d_winograd_weight_transform)
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "4D Tensor", "Weight tensor.")
.add_arguments(WinogradWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradWeightTransformParam>)
.set_attr<FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const auto& param = nnvm::get<WinogradWeightTransformParam>(attrs.parsed);
const TShape &wshape = (*in_shape)[0];
CHECK_EQ(wshape.ndim(), 4) << "Weight should be a 4 dimensional tensor";
TShape oshape({param.tile_size + wshape[2] - 1,
param.tile_size + wshape[3] - 1,
wshape[0],
wshape[1]});
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
})
.set_attr<FCorrectLayout>("FCorrectLayot", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
Layout layout("OIHW");
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, layout);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
return true;
})
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);
DMLC_REGISTER_PARAMETER(WinogradWeightTransformParam);
NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
.describe(R"code(Compute conv2d with winograd algorithm.
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
We do not check shape for this input tensor.
- **bias**: (channels,)
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_argument("weight", "Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(WinogradConv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<WinogradConv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<WinogradConv2DParam>)
.set_support_level(5);
DMLC_REGISTER_PARAMETER(WinogradConv2DParam);
NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.
......
......@@ -18,9 +18,12 @@ from . import record
from . import task
from . import tuner
from . import util
from . import env
from . import tophub
# some shortcuts
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, use_rpc
from .tuner import callback
from .task import template, get_config, create, ConfigSpace, ConfigEntity
from .record import ApplyHistoryBest as apply_history_best
from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
ApplyHistoryBest as apply_history_best
from .env import GLOBAL_SCOPE
......@@ -8,5 +8,6 @@ class AutotvmGlobalScope(object):
AutotvmGlobalScope.current = self
self.cuda_target_arch = None
self.in_tuning = False
GLOBAL_SCOPE = AutotvmGlobalScope()
"""Distributed executor infrastructure to scale up the tuning"""
from .measure import MeasureInput, MeasureResult, MeasureErrorNo
from .measure import create_measure_batch, measure_option
from .measure_methods import request_remote
from .measure import MeasureInput, MeasureResult, MeasureErrorNo, measure_option
from .measure_methods import request_remote, create_measure_batch, use_rpc
from .local_executor import LocalExecutor
from .executor import Future, Executor
......@@ -8,7 +8,10 @@ try:
except ImportError:
from Queue import Empty
import psutil
try:
import psutil
except ImportError:
psutil = None
from . import executor
......@@ -106,22 +109,28 @@ class LocalFutureNoFork(executor.Future):
class LocalExecutor(executor.Executor):
"""Local executor that runs workers on the same machine with multiprocessing."""
def __init__(self, timeout=None):
self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
"""Local executor that runs workers on the same machine with multiprocessing.
def submit(self, func, *args, **kwargs):
"""
Parameters
----------
timeout: float, optional
timeout of a job. If time is out. A TimeoutError will be returned (not raised)
do_fork: bool, optional
For some runtime systems that do not support fork after initialization
(e.g. cuda runtime, cudnn). Set this to False if you have used these runtime
before submitting jobs.
"""
def __init__(self, timeout=None, do_fork=True):
self.timeout = timeout or executor.Executor.DEFAULT_TIMEOUT
self.do_fork = do_fork
Note
----------
By default, the executor will fork a new process for a new job
But some runtime does not support fork (e.g. cuda runtime, cudnn).
In this circumstance, you should set 'fork_new_process' to False in kwargs
"""
fork_new_process = kwargs.pop('fork_new_process', True)
if self.do_fork:
if not psutil:
raise RuntimeError("Python package psutil is missing. "
"please try `pip install psutil`")
if not fork_new_process:
def submit(self, func, *args, **kwargs):
if not self.do_fork:
return LocalFutureNoFork(func(*args, **kwargs))
queue = Queue(1)
......
......@@ -9,15 +9,12 @@ import multiprocessing
import pickle
import json
import time
import os
from collections import OrderedDict
import numpy as np
from .. import build, lower, target as _target
from . import task
from .task import DispatchContext, ConfigEntity
from .task import ConfigEntity, ApplyHistoryBest
from .measure import MeasureInput, MeasureResult
AUTOTVM_LOG_VERSION = 0.1
......@@ -120,8 +117,8 @@ def decode(row, protocol='json'):
tgt = _target.create(str(tgt))
def clean_json_to_python(x):
"""1. convert all list in x to tuple (hashable)
2. convert unicode to str for python2
"""1. Convert all list in x to tuple (hashable)
2. Convert unicode to str for python2
"""
if isinstance(x, list):
return tuple([clean_json_to_python(a) for a in x])
......@@ -151,6 +148,7 @@ def decode(row, protocol='json'):
else:
raise RuntimeError("Invalid log protocol: " + protocol)
def load_from_file(filename):
"""Generator: load records from file.
This is a generator that yields the records.
......@@ -168,105 +166,6 @@ def load_from_file(filename):
yield decode(row)
class ApplyHistoryBest(DispatchContext):
"""
Apply the history best config
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
default: ConfigEntity, optional
The default config to return when no history records
"""
def __init__(self, records, default=None):
super(ApplyHistoryBest, self).__init__()
self.best_by_targetkey = {}
self.best_by_model = {}
self._default = default
self.load(records)
def load(self, records):
"""Load records to this dispatch context
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
"""
if isinstance(records, str):
records = load_from_file(records)
if not records:
return
best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model
counter = 0
for inp, res in records:
counter += 1
if res.error_no != 0:
continue
# use target keys in tvm target system as key to build best map
for k in inp.target.keys:
key = (k, inp.task.workload)
if key not in best_by_targetkey:
best_by_targetkey[key] = (inp, res)
else:
_, other_res = best_by_targetkey[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_by_targetkey[key] = (inp, res)
# use model as key to build best map
for opt in inp.target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, inp.task.workload)
if key not in best_by_model:
best_by_model[key] = (inp, res)
else:
_, other_res = best_by_model[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_by_model[key] = (inp, res)
break
logging.info("Finish loading %d records", counter)
def query(self, target, workload):
if target is None:
raise RuntimeError("Need a target context to find the history best. "
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
" above the dispatcher call. So does other target. ")
# first try matching by model
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self.best_by_model:
return self.best_by_model[key][0].config
# then try matching by target key
for k in target.keys:
key = (k, workload)
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config
if self._default:
return self._default
raise RuntimeError(
"Cannot find config for target=%s, workload=%s" % (target, workload))
def split_workload(in_file, clean=True):
"""Split a log file into separate files, each of which contains only a single workload
This function can also delete duplicated records in log file
......@@ -326,7 +225,7 @@ def pick_best(in_file, out_file):
----------
in_file: str
The filename of input
out_file:
out_file: str or file
The filename of output
"""
best_context = ApplyHistoryBest(load_from_file(in_file))
......@@ -338,31 +237,13 @@ def pick_best(in_file, out_file):
for v in best_context.best_by_targetkey.values():
best_set.add(measure_str_key(v[0]))
logging.info("Extract %d best records from the log file", len(best_set))
logging.info("Extract %d best records from the %s", len(best_set), in_file)
fout = open(out_file, 'w') if isinstance(out_file, str) else out_file
fout = open(out_file, 'w')
for inp, res in load_from_file(in_file):
if measure_str_key(inp) in best_set:
fout.write(encode(inp, res) + "\n")
def load_op_param(rootpath=os.path.join(os.path.expanduser('~'), ".tvm", "op_params")):
"""Load pre-tuned parameters of operators.
This function will load all "*.log" file under root path and select best configs.
Parameters
----------
rootpath: str
The root path of stored parameters
"""
best_context = ApplyHistoryBest([])
for dirpath, _, filenames in os.walk(rootpath):
for filename in filenames:
if os.path.splitext(filename)[1] == '.log':
best_context.load(os.path.join(dirpath, filename))
assert not DispatchContext.current, "Cannot load pre-tuned parameters inside a dispatch context"
DispatchContext.current = best_context
best_set.remove(measure_str_key(inp))
"""
Usage:
......
......@@ -9,4 +9,7 @@ of typical tasks of interest.
from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, dispatcher
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, dispatcher
from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph
......@@ -12,7 +12,10 @@ of the DispatchContext base class.
"""
from __future__ import absolute_import as _abs
import logging
from decorator import decorate
import numpy as np
from tvm import target as _target
......@@ -52,25 +55,6 @@ class DispatchContext(object):
DispatchContext.current = self._old_ctx
class ApplyConfig(DispatchContext):
"""Apply a specific config entity during query.
Parameters
----------
config : ConfigSpace or ConfigEntity
The specific configuration we care about.
"""
def __init__(self, config):
super(ApplyConfig, self).__init__()
self._config = config
self.workload = None
def query(self, target, workload):
"""Override query"""
self.workload = workload
return self._config
def dispatcher(fworkload):
"""Wrap a workload dispatcher function.
......@@ -137,3 +121,124 @@ def dispatcher(fworkload):
fdecorate = decorate(fworkload, dispatch_func)
fdecorate.register = register
return fdecorate
class ApplyConfig(DispatchContext):
"""Apply a specific config entity during query.
Parameters
----------
config : ConfigSpace or ConfigEntity
The specific configuration we care about.
"""
def __init__(self, config):
super(ApplyConfig, self).__init__()
self._config = config
self.workload = None
def query(self, target, workload):
"""Override query"""
self.workload = workload
return self._config
class ApplyHistoryBest(DispatchContext):
"""
Apply the history best config
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
default: ConfigEntity, optional
The default config to return when no history records
"""
def __init__(self, records, default=None):
super(ApplyHistoryBest, self).__init__()
self.best_by_targetkey = {}
self.best_by_model = {}
self._default = default
if records:
self.load(records)
def load(self, records):
"""Load records to this dispatch context
Parameters
----------
records : str or iterator of (MeasureInput, MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
"""
from ..record import load_from_file
if isinstance(records, str):
records = load_from_file(records)
if not records:
return
best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model
counter = 0
for inp, res in records:
counter += 1
if res.error_no != 0:
continue
# use target keys in tvm target system as key to build best map
for k in inp.target.keys:
key = (k, inp.task.workload)
if key not in best_by_targetkey:
best_by_targetkey[key] = (inp, res)
else:
_, other_res = best_by_targetkey[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_by_targetkey[key] = (inp, res)
# use model as key to build best map
for opt in inp.target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, inp.task.workload)
if key not in best_by_model:
best_by_model[key] = (inp, res)
else:
_, other_res = best_by_model[key]
if np.mean(other_res.costs) > np.mean(res.costs):
best_by_model[key] = (inp, res)
break
logging.debug("Finish loading %d records", counter)
def query(self, target, workload):
if target is None:
raise RuntimeError("Need a target context to find the history best. "
"Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
" above the dispatcher call. So does other target. ")
# first try matching by model
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self.best_by_model:
return self.best_by_model[key][0].config
# then try matching by target key
for k in target.keys:
key = (k, workload)
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config
if self._default:
return self._default
raise RuntimeError(
"Cannot find config for target=%s, workload=%s" % (target, workload))
# pylint: disable=unused-variable,invalid-name
"""
Decorator and utilities for the integration with TOPI and NNVM
"""
import warnings
from ... import tensor, placeholder, target as _target
from ..util import get_const_tuple
from .task import create, register
def serialize_args(args):
"""serialize arguments of a topi function to a hashable tuple.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tensor.Tensor):
ret.append(('TENSOR', get_const_tuple(t.shape), t.dtype))
else:
ret.append(t)
return tuple(ret)
def deserialize_args(args):
"""The inverse function of :code:`serialize_args`.
Parameters
----------
args: list of hashable or Tensor
"""
ret = []
for t in args:
if isinstance(t, tuple) and t[0] == 'TENSOR':
ret.append(placeholder(shape=t[1], dtype=t[2]))
else:
ret.append(t)
return ret
# Task extractor for nnvm graph
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
def __init__(self):
import topi
import nnvm
self.symbol2topi = {
nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw]
}
self.topi_to_task = {
topi.nn.conv2d: "topi_nn_conv2d",
topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
}
self._register_dummy()
self._register_topi_task()
self.task_collection = []
def _register_dummy(self):
"""Register dummy function to track the topi function call"""
for func in self.topi_to_task:
def _local_scope(local_func):
"""build a scope to holds the function"""
@local_func.register("dummy", )
def _dummy_func(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
"Please modify it to use only positional args."
if (self.topi_to_task[local_func], serialize_args(args)) \
not in self.task_collection:
self.task_collection.append((self.topi_to_task[local_func],
serialize_args(args)))
with _target.create("opencl"):
return local_func(*args)
_local_scope(func)
def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi
# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw")
def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, W = args[:2]
C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
s = topi.generic.schedule_depthwise_conv2d_nchw([C])
return s, [A, W, C]
def reset(self):
"""Reset task collections"""
self.task_collection = []
def get_tasks(self):
"""Get collected tasks"""
return self.task_collection
@staticmethod
def get():
"""Get the single instance of TaskExtractEnv"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
return TaskExtractEnv.current
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tunning tasks by building the graph
with a "dummy" target and tracing all the calls to topi.
Parameters
----------
graph : Graph
The graph to tune
shape : dict of str to tuple, optional
The input shape to the graph
dtype : str or dict of str to str
The input types to the graph
target: tvm.target.Target
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
import nnvm.compiler
env = TaskExtractEnv.get()
topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
# run compiler to collect all TOPI calls during compilation
env.reset()
dummy_target = _target.create("opencl -device=dummy")
nnvm.compiler.build(graph, target=dummy_target, shape=shape, dtype=dtype)
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
return tasks
......@@ -21,6 +21,11 @@ from tvm.autotvm.util import get_const_int
Axis = namedtuple('Axis', ['space', 'index'])
try:
_long = long
except NameError:
_long = int
class InstantiationError(ValueError):
"""Actively detected error in instantiating a template with a config,
......@@ -103,7 +108,7 @@ class VirtualAxis(TransformSpace):
VirtualAxis.name_ct += 1
self.name = name
if isinstance(var, int):
if isinstance(var, (int, _long)):
self.length = var
elif isinstance(var, schedule.IterVar):
self.name = var.var.name
......@@ -114,7 +119,7 @@ class VirtualAxis(TransformSpace):
elif isinstance(var, VirtualAxis):
self.length = var.length
else:
raise RuntimeError("Invalid type of axis")
raise RuntimeError("Invalid type of axis: " + str(type(var)))
@staticmethod
def get_num_output(var, name=None):
......
......@@ -362,7 +362,7 @@ def compute_flop(sch):
exp = body[0]
ret += num_element * _count_flop(exp)
ret += traverse([sch[t].op for t in op.input_tensors])
ret += traverse([t.op for t in op.input_tensors])
elif isinstance(op, tensor.PlaceholderOp):
pass
......@@ -382,5 +382,4 @@ def compute_flop(sch):
raise RuntimeError("Cannot find float number operation in this operator. "
"Please use `cfg.add_flop` to manually set "
"FLOP for this operator")
return ret
# pylint: disable=unused-variable,invalid-name
"""
Decorators for registering tunable templates to TOPI.
These decorators can make your simple implementation be able to use different configurations
for different workloads.
Here we directly use all arguments to the TOPI call as "workload", so make sure all the arguments
(except tvm.Tensor) in you calls are hashable. For tvm.Tensor, we will serialize it to a hashable
tuple.
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
from ... import _api_internal, tensor
from ..util import get_func_name
from .task import args_to_workload, dispatcher
# A table that records all registered dispatcher for all targets
_REGISTED_DISPATHCER = {
}
def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
"""Register a tunable template for a topi compute function.
After the registration. This topi compute will become a configuration dispatcher. It uses
all its argument as workload and dispatches configurations according to the input workload.
It also stores this "workload" to its final ComputeOp, which can be used to reconstruct
"workload" in the following topi_schedule call.
Parameters
----------
topi_compute: GenericFunc
The topi compute function that will be overloaded
target_keys: str or list of str
The compilation target. The same as the argument of GenericFunc.register.
template_keys: str or list of str
The template key.
We might have several strategies for a single operator (e.g. direct, im2col, winograd).
The template key is used to identity the algorithm strategy.
Every operator must have a "direct" template, which is used by default.
func: None or callable
If it is None, return a decorator.
If is callable, decorate this function.
Returns
-------
decorator: callable
A decorator
Examples
--------
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
fname = get_func_name(topi_compute)
def _decorator(f):
targets = [target_keys] if isinstance(target_keys, str) else target_keys
for target_key in targets:
if target_key not in _REGISTED_DISPATHCER:
_REGISTED_DISPATHCER[target_key] = {}
if topi_compute not in _REGISTED_DISPATHCER:
@topi_compute.register(target_key)
@dispatcher
def config_dispatcher(*args, **kwargs):
"""override topi call as a config dispatcher"""
assert not kwargs, "Do not support kwargs in template function call"
return (fname, ) + args_to_workload(args)
_REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute]
@config_dispatcher.register(template_keys)
def template_call(cfg, *args, **kwargs):
"""call the topi func and attach workload to compute node"""
assert not kwargs, "Do not support kwargs in template function call"
if f == topi_compute.fdefault:
node = f(*args, **kwargs)
else:
node = f(cfg, *args, **kwargs)
# attach workload to return op
op = node.op
attrs = {}
for k, v in node.op.attrs.items():
attrs[k] = v
attrs['workload'] = (fname, ) + args_to_workload(args)
if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
elif isinstance(op, tensor.ExternOp):
op = _api_internal._ExternOp(
op.name, op.tag, attrs,
op.inputs, op.input_placeholders,
op.output_placeholders, op.body)
else:
raise RuntimeError("Unsupported op type: " + str(type(op)))
if isinstance(node, tensor.Tensor):
return op.output(0)
return [op.output(i) for i in range(len(node))]
return f
if func:
_decorator(func)
return _decorator
def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None):
"""Register a tunable template for a topi schedule function.
After the registration. This topi schedule will become a configuration dispatcher. It dispatches
configurations according to the input workload.
Note that this function will try to find "workload" from all the ComputeOp in the input.
You can attach "workload" to your compute op by using :any:`register_topi_compute`.
Parameters
----------
topi_schedule: GenericFunc
The topi schedule function that will be overloaded
target_keys: str or list of str
The compilation target
template_keys: str or list of str
The template key.
We might have several strategies for a single operator (e.g. direct, im2col, winograd).
The template key is used to identity the algorithm strategy.
Every operator must have a "direct" template, which is used by default.
func: None or callable
If it is None, return a decorator.
If is callable, decorate this function.
Returns
-------
decorator: callable
A decorator
Examples
--------
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
def _decorator(f):
targets = [target_keys] if isinstance(target_keys, str) else target_keys
for target_key in targets:
if target_key not in _REGISTED_DISPATHCER:
_REGISTED_DISPATHCER[target_key] = {}
if topi_schedule not in _REGISTED_DISPATHCER[target_key]:
@topi_schedule.register(target_key)
@dispatcher
def config_dispatcher(outs):
"""override topi call as a workload dispatcher"""
def traverse(tensors):
"""traverse all ops to find attached workload"""
for t in tensors:
op = t.op
if 'workload' in op.attrs:
return op.attrs['workload']
wkl = traverse(op.input_tensors)
if wkl:
return wkl
return None
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
workload = traverse(outs)
if workload is None:
raise RuntimeError("Cannot find workload in attribute of this schedule")
return args_to_workload(workload)
_REGISTED_DISPATHCER[target_key][topi_schedule] = config_dispatcher
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule]
@config_dispatcher.register(template_keys)
def template_call(cfg, outs):
"""call the schedule func"""
if f == topi_schedule.fdefault:
return f(outs)
return f(cfg, outs)
return f
if func:
_decorator(func)
return _decorator
"""
TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time.
"""
import logging
import os
import json
from .task import ApplyHistoryBest
from .. import target as _target
from ..contrib.util import tempdir
from ..contrib.download import download
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
def _alias(name):
"""convert alias for some packages"""
table = {
'vtacpu': 'vta',
}
return table.get(name, name)
def context(target, extra_files=None):
"""Return the dispatch context with pre-tuned parameters.
The corresponding downloaded *.log files under tophub root path will be loaded.
Users can also add their own files in argument `extra_files`.
Parameters
----------
target: Target
The compilation target
extra_files: list of str, optional
Extra log files to load
"""
rootpath = AUTOTVM_TOPHUB_ROOT_PATH
best_context = ApplyHistoryBest([])
if isinstance(target, str):
target = _target.create(target)
big_target = str(target).split()[0]
if os.path.isfile(os.path.join(rootpath, big_target + ".log")):
best_context.load(os.path.join(rootpath, big_target + ".log"))
for opt in target.options:
if opt.startswith("-device"):
model = _alias(opt[8:])
if os.path.isfile(os.path.join(rootpath, model) + ".log"):
best_context.load(os.path.join(rootpath, model) + ".log")
if extra_files:
for filename in extra_files:
best_context.load(filename)
return best_context
def download_package(backend):
"""Download pre-tuned parameters of operators for a backend
Parameters
----------
backend: str
The name of package
"""
rootpath = AUTOTVM_TOPHUB_ROOT_PATH
if not os.path.isdir(rootpath):
# make directory
splits = os.path.split(rootpath)
for j in range(1, len(splits)+1):
path = os.path.join(*splits[:j])
if not os.path.isdir(path):
os.mkdir(path)
backend = _alias(backend)
logging.info("Download pre-tuned parameters for %s", backend)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s.log" % backend,
os.path.join(rootpath, backend + ".log"), True, verbose=0)
def check_package(backend):
"""Check whether have pre-tuned parameters of the certain target.
If not, will download it.
Parameters
----------
backend: str
The name of package
"""
backend = _alias(backend)
if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, backend + ".log")):
return
download_package(backend)
def list_packages():
"""List all available pre-tuned op parameters for targets
Returns
-------
ret: List
All available packets
"""
path = tempdir()
filename = path.relpath("info.json")
logging.info("Download meta info for pre-tuned parameters")
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json",
filename, True, verbose=0)
with open(filename, "r") as fin:
text = "".join(fin.readlines())
info = json.loads(text)
keys = list(info.keys())
keys.sort()
return [(k, info[k]) for k in keys]
# pylint: disable=consider-using-enumerate,invalid-name
"""Namespace of callback utilities of AutoTVM"""
import sys
import time
import numpy as np
from .. import record
def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file.
The rows of the log are stored in the format of autotvm.record.encode.
......@@ -21,7 +24,6 @@ def log_to_file(file_out, protocol='json'):
callback : callable
Callback function to do the logging.
"""
def _callback(_, inputs, results):
"""Callback implementation"""
if isinstance(file_out, str):
......@@ -34,55 +36,21 @@ def log_to_file(file_out, protocol='json'):
return _callback
def save_tuner_state(prefix, save_every_sample=100):
"""Save the state of tuner
def log_to_database(db):
"""Save the tuning records to a database object.
Parameters
----------
prefix : srt
prefix of the filename to store state
save_every_sample: int
save the state every x samples
Returns
-------
callback : function
Callback function to do the auto saving.
db: Database
The database
"""
def _callback(tuner, inputs, results):
for _, __ in zip(inputs, results):
try:
ct = len(tuner.visited)
except AttributeError:
ct = 0
if ct % save_every_sample == 0:
tuner.save_state(prefix + "_%d.state" % ct)
return _callback
def log_to_redis(host="localhost", port=6379, dbn=11):
"""Record the tuning record to a redis DB.
Parameters
----------
host: str, optional
Host address of redis db
port: int, optional
Port of redis db
dbn: int, optional
which redis db to use, default 11
"""
# import here so only depend on redis when necessary
import redis
red = redis.StrictRedis(host=host, port=port, db=dbn)
def _callback(_, inputs, results):
"""Callback implementation"""
for inp, result in zip(inputs, results):
red.set(inp, result)
db.save(inp, result)
return _callback
class Monitor(object):
"""A monitor to collect statistic during tuning"""
def __init__(self):
......@@ -110,3 +78,47 @@ class Monitor(object):
def trial_timestamps(self):
"""get wall clock time stamp of all trials"""
return np.array(self.timestamps)
def progress_bar(total, prefix=''):
"""Display progress bar for tuning
Parameters
----------
total: int
The total number of trials
prefix: str
The prefix of output message
"""
class _Context:
"""Context to store local variables"""
def __init__(self):
self.best_flops = 0
self.cur_flops = 0
self.ct = 0
self.total = total
def __del__(self):
sys.stdout.write(' Done.\n')
ctx = _Context()
tic = time.time()
def _callback(tuner, inputs, results):
ctx.ct += len(inputs)
flops = 0
for inp, res in zip(inputs, results):
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total,
time.time() - tic))
sys.stdout.flush()
return _callback
......@@ -117,3 +117,6 @@ class GATuner(Tuner):
def has_next(self):
return len(self.visited) - (len(self.genes) - self.trial_pt) < len(self.space)
def load_history(self, data_set):
pass
......@@ -25,6 +25,9 @@ class GridSearchTuner(Tuner):
def has_next(self):
return self.counter < len(self.task.config_space)
def load_history(self, data_set):
pass
def __getstate__(self):
return {"counter": self.counter}
......@@ -56,6 +59,9 @@ class RandomTuner(Tuner):
def has_next(self):
return len(self.visited) < len(self.task.config_space)
def load_history(self, data_set):
pass
def __getstate__(self):
return {"visited": self.counter}
......
......@@ -242,7 +242,7 @@ class ModelBasedTuner(Tuner):
self.ys.append(flops)
else:
self.xs.append(index)
self.ys.append(0)
self.ys.append(0.0)
# if we have enough new training samples
if len(self.xs) >= self.plan_size * (self.train_ct + 1) \
......
......@@ -26,11 +26,11 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
If is an Array, then perform linear cooling from temp[0] to temp[1]
early_stop: int, optional
Stop iteration if the optimal set do not change in `early_stop` rounds
verbose: int, optional
Print log every `verbose` iterations
log_interval: int, optional
Print log every `log_interval` iterations
"""
def __init__(self, task, n_iter=500, temp=(1, 0), persistent=True, parallel_size=128,
early_stop=50, verbose=50):
early_stop=50, log_interval=50):
super(SimulatedAnnealingOptimizer, self).__init__()
self.task = task
......@@ -41,12 +41,13 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
self.persistent = persistent
self.parallel_size = min(parallel_size, len(self.task.config_space))
self.early_stop = early_stop or 1e9
self.verbose = verbose
self.log_interval = log_interval
self.points = None
def find_maximums(self, model, num, exclusive):
tic = time.time()
temp, n_iter, early_stop, verbose = self.temp, self.n_iter, self.early_stop, self.verbose
temp, n_iter, early_stop, log_interval = \
self.temp, self.n_iter, self.early_stop, self.log_interval
if self.persistent and self.points is not None:
points = self.points
......@@ -100,19 +101,18 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
k += 1
t -= cool
if verbose >= 1 and k % verbose == 0:
if log_interval and k % log_interval == 0:
t_str = "%.2f" % t
logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\ttemp: %s\t"
"elapsed: %.2f",
k, k_last_modify, heap_items[0][0],
np.max([v for v, _ in heap_items]), t_str,
time.time() - tic)
heap_items.sort(key=lambda item: -item[0])
if verbose:
logging.info("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.info("SA Maximums: %s", heap_items)
logging.debug("SA iter: %d\tlast_update: %d\tmax-0: %.2f\tmax-1: %.2f\telapsed: %.2f",
k, k_last_modify, heap_items[-1][0], heap_items[0][0], time.time() - tic)
logging.debug("SA Maximums: %s", heap_items)
if self.persistent:
self.points = points
......
......@@ -7,6 +7,7 @@ import numpy as np
from ..measure import MeasureInput
from ..measure import create_measure_batch
from ..env import GLOBAL_SCOPE
class Tuner(object):
"""Base class for tuners
......@@ -64,7 +65,7 @@ class Tuner(object):
"""
pass
def tune(self, n_trial, measure_option, early_stop=None, verbose=1, callbacks=()):
def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
"""Begin tuning
Parameters
......@@ -74,11 +75,8 @@ class Tuner(object):
measure_option: dict
The options for how to measure generated code.
You should use the return value ot autotvm.measure_option for this argument.
early_stop: int
early_stopping: int
Early stop the tuning when not finding better configs in this number of trials
verbose: int
0: silent mode, no output
1: print every measurement result
callbacks: List of callable
A list of callback functions. The signature of callback function is
(Tuner, List of MeasureInput, List of MeasureResult)
......@@ -87,8 +85,9 @@ class Tuner(object):
"""
measure_batch = create_measure_batch(self.task, measure_option)
parallel_num = getattr(measure_batch, 'parallel_num', 1)
early_stop = early_stop or 1e9
early_stopping = early_stopping or 1e9
GLOBAL_SCOPE.in_tuning = True
i = 0
while i < n_trial:
if not self.has_next():
......@@ -99,23 +98,22 @@ class Tuner(object):
inputs = [MeasureInput(self.task.target, self.task, config) for config in configs]
results = measure_batch(inputs)
# print info
if verbose >= 1:
for k, (inp, res) in enumerate(zip(inputs, results)):
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
else:
flops = 0
if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.info("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
# keep best config
for k, (inp, res) in enumerate(zip(inputs, results)):
config = inp.config
if res.error_no == 0:
flops = inp.task.flop / np.mean(res.costs)
else:
flops = 0
if flops > self.best_flops:
self.best_flops = flops
self.best_config = config
self.best_measure_pair = (inp, res)
self.best_iter = i + k
logging.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9,
res, config)
i += len(results)
......@@ -124,10 +122,12 @@ class Tuner(object):
for callback in callbacks:
callback(self, inputs, results)
if i > self.best_iter + early_stop:
logging.info("Early stopped. Best iter: %d.", self.best_iter)
if i > self.best_iter + early_stopping:
logging.debug("Early stopped. Best iter: %d.", self.best_iter)
break
GLOBAL_SCOPE.in_tuning = False
del measure_batch
def reset(self):
......
......@@ -42,10 +42,10 @@ class XGBoostCostModel(CostModel):
The cost model predicts relative rank score.
num_threads: int, optional
The number of threads.
verbose: int, optional
If is not none, the cost model will print training log every `verbose` iterations.
log_interval: int, optional
If is not none, the cost model will print training log every `log_interval` iterations.
"""
def __init__(self, task, feature_type, loss_type, num_threads=None, verbose=20):
def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25):
super(XGBoostCostModel, self).__init__()
if xgb is None:
......@@ -60,7 +60,7 @@ class XGBoostCostModel(CostModel):
self.fea_type = feature_type
self.loss_type = loss_type
self.num_threads = num_threads
self.verbose = verbose
self.log_interval = log_interval
if loss_type == 'reg':
self.xgb_params = {
......@@ -139,7 +139,8 @@ class XGBoostCostModel(CostModel):
x_train = self._get_feature(xs)
y_train = np.array(ys)
y_train = y_train / np.max(y_train)
y_max = np.max(y_train)
y_train = y_train / max(y_max, 1e-8)
valid_index = y_train > 1e-6
index = np.random.permutation(len(x_train))
......@@ -160,19 +161,20 @@ class XGBoostCostModel(CostModel):
fevals=[
xgb_average_recalln_curve_score(plan_size),
],
verbose_eval=self.verbose)])
verbose_eval=self.log_interval)])
logging.info("train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
time.time() - tic, len(xs),
len(xs) - np.sum(valid_index),
self.feature_cache.size(self.fea_type))
logging.debug("XGB train: %.2f\tobs: %d\terror: %d\tn_cache: %d",
time.time() - tic, len(xs),
len(xs) - np.sum(valid_index),
self.feature_cache.size(self.fea_type))
def fit_log(self, records, plan_size):
tic = time.time()
self._reset_pool()
args = list(records)
logging.info("Load %d entries from history log file", len(args))
logging.debug("XGB load %d entries from history log file", len(args))
if self.fea_type == 'itervar':
feature_extract_func = _extract_itervar_feature_log
elif self.fea_type == 'knob':
......@@ -187,7 +189,8 @@ class XGBoostCostModel(CostModel):
x_train = xs
y_train = ys
y_train /= np.max(y_train)
y_max = np.max(y_train)
y_train = y_train / max(y_max, 1e-8)
index = np.random.permutation(len(x_train))
dtrain = xgb.DMatrix(x_train[index], y_train[index])
......@@ -203,9 +206,9 @@ class XGBoostCostModel(CostModel):
fevals=[
xgb_average_recalln_curve_score(plan_size),
],
verbose_eval=self.verbose)])
verbose_eval=self.log_interval)])
logging.info("train: %.2f\tobs: %d", time.time() - tic, len(xs))
logging.debug("XGB train: %.2f\tobs: %d", time.time() - tic, len(xs))
def predict(self, xs, output_margin=False):
feas = self._get_feature(xs)
......@@ -232,7 +235,7 @@ class XGBoostCostModel(CostModel):
def clone_new(self):
return XGBoostCostModel(self.task, self.fea_type, self.loss_type,
self.num_threads, self.verbose)
self.num_threads, self.log_interval)
def _get_feature(self, indexes):
"""get features for indexes, run extraction if we do not have cache for them"""
......@@ -282,7 +285,7 @@ def _extract_itervar_feature_log(arg):
if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0
y = 0.0
return x, y
def _extract_knob_feature_index(index):
......@@ -301,7 +304,7 @@ def _extract_knob_feature_log(arg):
inp.task.instantiate(config)
y = inp.task.flop / np.mean(res.costs)
else:
y = 0
y = 0.0
return x, y
def _extract_curve_feature_index(index):
......@@ -325,12 +328,11 @@ def _extract_curve_feature_log(arg):
if res.error_no == 0:
y = inp.task.flop / np.mean(res.costs)
else:
y = 0
y = 0.0
return x, y
def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
save_file="xgb_checkpoint", save_every=None,
maximize=False, verbose_eval=True):
"""callback function for xgboost to support multiple custom evaluation functions"""
from xgboost.core import EarlyStopException
......@@ -400,18 +402,12 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
continue
infos.append("%s: %.6f" % (item[0], item[1]))
if not isinstance(verbose_eval, bool) and i % verbose_eval == 0:
logging.info("\t".join(infos))
if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0:
logging.debug("\t".join(infos))
if log_file:
with open(log_file, "a") as fout:
fout.write("\t".join(infos) + '\n')
##### save model #####
if save_every and i % save_every == 0:
filename = save_file + ".%05d.bst" % i
logging.info("save model to %s ...", filename)
bst.save_model(filename)
##### choose score and do early stopping #####
score = None
for item in eval_res:
......@@ -439,7 +435,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
elif env.iteration - best_iteration >= stopping_rounds:
best_msg = state['best_msg']
if verbose_eval and env.rank == 0:
logging.info("Stopping. Best iteration: %s ", best_msg)
logging.debug("XGB stopped. Best iteration: %s ", best_msg)
raise EarlyStopException(best_iteration)
return callback
......
......@@ -40,16 +40,21 @@ class XGBTuner(ModelBasedTuner):
If is not None, the tuner will first select
top-(plan_size * diversity_filter_ratio) candidates according to the cost model
and then pick batch_size of them according to the diversity metric.
log_interval: int, optional
The verbose level.
If is 0, output nothing.
Otherwise, output debug information every `verbose` iterations.
"""
def __init__(self, task, plan_size=32,
feature_type='itervar', loss_type='rank', num_threads=None,
optimizer='sa', diversity_filter_ratio=None):
optimizer='sa', diversity_filter_ratio=None, log_interval=50):
cost_model = XGBoostCostModel(task,
feature_type=feature_type,
loss_type=loss_type,
num_threads=num_threads)
num_threads=num_threads,
log_interval=log_interval // 2)
if optimizer == 'sa':
optimizer = SimulatedAnnealingOptimizer(task)
optimizer = SimulatedAnnealingOptimizer(task, log_interval=log_interval)
else:
assert isinstance(optimizer, ModelOptimizer), "Optimizer must be " \
"a supported name string" \
......
......@@ -8,6 +8,16 @@ import numpy as np
from .. import expr, ir_pass
class EmptyContext(object):
"""An empty context"""
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def get_rank(values):
"""get rank of items
......
......@@ -6,7 +6,7 @@ import os
import sys
import time
def download(url, path, overwrite=False, size_compare=False):
def download(url, path, overwrite=False, size_compare=False, verbose=1):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
......@@ -23,9 +23,10 @@ def download(url, path, overwrite=False, size_compare=False):
size_compare : bool, optional
Whether to do size compare to check downloaded file.
"""
import requests
verbose: int, optional
Verbose level
"""
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
......@@ -33,6 +34,7 @@ def download(url, path, overwrite=False, size_compare=False):
if os.path.isfile(path) and not overwrite:
if size_compare:
import requests
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
......@@ -45,7 +47,9 @@ def download(url, path, overwrite=False, size_compare=False):
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
if verbose >= 1:
print('Downloading from url {} to {}'.format(url, path))
# Stateful start time
start_time = time.time()
......
......@@ -142,3 +142,35 @@ def which(exec_name):
if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
return full_path
return None
def get_lower_ir(s):
"""Get lower ir code of a schedule.
This is useful for debug, since you don't have to find all inputs/outputs
for a schedule in a fused subgraph.
Parameters
----------
s: Schedule
Returns
-------
ir: str
The lower ir
"""
from .. import tensor
from ..build_module import lower
outputs = s.outputs
inputs = []
def find_all(op):
if isinstance(op, tensor.PlaceholderOp):
inputs.append(op.output(0))
else:
for x in op.input_tensors:
find_all(x.op)
for out in outputs:
find_all(out)
return lower(s, inputs, simple_mode=True)
# pylint: disable=invalid-name
"""Pick best log entries from a large file and store them to a small file"""
import argparse
import os
import logging
import warnings
from .. import autotvm
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--act", type=str, choices=['pick-best'],
help="The action")
parser.add_argument("--i", type=str, help="The input file or directory")
parser.add_argument("--o", type=str, help="The output file")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.act == 'pick-best':
if os.path.isfile(args.i):
args.o = args.o or args.i + ".best.log"
autotvm.record.pick_best(args.i, args.o)
elif os.path.isdir(args.i):
args.o = args.o or "best.log"
tmp_filename = args.o + ".tmp"
with open(tmp_filename, 'w') as tmp_fout:
for filename in os.listdir(args.i):
if filename.endswith(".log"):
try:
autotvm.record.pick_best(filename, tmp_fout)
except Exception: # pylint: disable=broad-except
warnings.warn("Ignore invalid file %s" % filename)
logging.info("Run final filter...")
autotvm.record.pick_best(tmp_filename, args.o)
os.remove(tmp_filename)
logging.info("Output to %s ...", args.o)
else:
raise ValueError("Invalid input file: " + args.i)
else:
raise ValueError("Invalid action " + args.act)
......@@ -40,20 +40,21 @@ if __name__ == "__main__":
help='The port of the PRC')
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--tracker', type=str,
help="The address of RPC tracker in host:port format. "
"e.g. (10.77.1.234:9190)")
parser.add_argument('--key', type=str, default="",
help="RPC key used to identify the connection type.")
parser.add_argument('--load-library', type=str, default="",
help="The key used to identify the device type in tracker.")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.add_argument('--load-library', type=str,
help="Additional library to load")
parser.add_argument('--tracker', type=str, default="",
help="Report to RPC tracker")
parser.add_argument('--no-fork', dest='fork', action='store_false',
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.")
parser.add_argument('--custom-addr', type=str,
help="Custom IP Address to Report to RPC Tracker")
parser.add_argument('--silent', action='store_true',
help="Whether run in silent mode.")
parser.set_defaults(fork=True)
args = parser.parse_args()
......
......@@ -6,13 +6,12 @@ import logging
import argparse
import multiprocessing
import sys
from ..rpc.tracker import Tracker
from .. import rpc
def main(args):
"""Main funciton"""
tracker = Tracker(args.host, port=args.port, port_end=args.port_end,
silent=args.silent)
tracker = rpc.Tracker(args.host, port=args.port, port_end=args.port_end,
silent=args.silent)
tracker.proc.join()
......
# pylint: disable=invalid-name
"""Download pre-tuned parameters of ops"""
import argparse
import logging
from ..autotvm.tophub import list_packages, download_package
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--download", type=str, nargs='+',
help="Target to download. Use 'all' to download for all targets")
parser.add_argument("-l", "--list", action='store_true', help="List available packages")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.list:
info = list_packages()
print("\n%-20s %-20s" % ("Target", "Size"))
print("-" * 41)
for target, info in info:
print("%-20s %-20s" % (target, "%.2f MB" % (info['size']/1000000)))
if args.download:
info = list_packages()
all_targets = [x[0] for x in info]
if 'all' in args.download:
targets = all_targets
else:
targets = args.download
for t in targets:
if t not in all_targets:
print("Warning : cannot find tuned parameters of " + t + ". (ignored)")
download_package(t)
......@@ -10,4 +10,6 @@ upload and run remote RPC server, get the result back to verify correctness.
"""
from .server import Server
from .tracker import Tracker
from .proxy import Proxy
from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker
......@@ -225,18 +225,24 @@ class TrackerSession(object):
res += item["key"] + "\n"
res += "----------------------------\n"
res += "\n"
res += "Queue Status\n"
res += "----------------------------\n"
res += "key\tfree\tpending\n"
res += "----------------------------\n"
# compute max length of device key
queue_info = data['queue_info']
keys = list(queue_info.keys())
if keys:
keys.sort()
max_key_len = max([len(k) for k in keys])
for k in keys:
res += ("%%-%d" % max_key_len + "s\t%d\t%g\n") % \
(k, queue_info[k]["free"], queue_info[k]["pending"])
else:
max_key_len = 0
res += "Queue Status\n"
res += "----------------------------\n"
res += ("%%-%ds" % max_key_len + "\tfree\tpending\n") % 'key'
res += "----------------------------\n"
for k in keys:
res += ("%%-%ds" % max_key_len + "\t%d\t%g\n") % \
(k, queue_info[k]["free"], queue_info[k]["pending"])
res += "----------------------------\n"
return res
......
......@@ -460,6 +460,10 @@ class Proxy(object):
timeout_server : float, optional
Timeout of server until it sees a matching connection.
tracker_addr: Tuple (str, int) , optional
The address of RPC Tracker in tuple (host, ip) format.
If is not None, the server will register itself to the tracker.
index_page : str, optional
Path to an index page that can be used to display at proxy index.
......
......@@ -20,6 +20,7 @@ import multiprocessing
import subprocess
import time
import sys
import signal
from .._ffi.function import register_func
from .._ffi.base import py_str
......@@ -257,7 +258,7 @@ def _popen(cmd):
class Server(object):
"""Start RPC server on a seperate process.
"""Start RPC server on a separate process.
This is a simple python implementation based on multi-processing.
It is also possible to implement a similar C based sever with
......@@ -284,14 +285,21 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.
silent: bool, optional
Whether run this server in silent mode.
tracker_addr: Tuple (str, int) , optional
The address of RPC Tracker in tuple(host, ip) format.
If is not None, the server will register itself to the tracker.
key : str, optional
The key used to identify the server in Proxy connection.
The key used to identify the device type in tracker.
load_library : str, optional
List of additional libraries to be loaded during execution.
custom_addr: str, optional
Custom IP Address to Report to RPC Tracker
silent: bool, optional
Whether run this server in silent mode.
"""
def __init__(self,
host,
......@@ -299,11 +307,11 @@ class Server(object):
port_end=9199,
is_proxy=False,
use_popen=False,
silent=False,
tracker_addr=None,
key="",
load_library=None,
custom_addr=None):
custom_addr=None,
silent=False):
try:
if base._ServerLoop is None:
raise RuntimeError("Please compile with USE_RPC=1")
......@@ -313,6 +321,7 @@ class Server(object):
self.port = port
self.libs = []
self.custom_addr = custom_addr
self.use_popen = use_popen
self.logger = logging.getLogger("RPCServer")
if silent:
......@@ -334,10 +343,7 @@ class Server(object):
if silent:
cmd += ["--silent"]
self.proc = multiprocessing.Process(
target=subprocess.check_call, args=(cmd,))
self.proc.deamon = True
self.proc.start()
self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
time.sleep(0.5)
elif not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -371,9 +377,14 @@ class Server(object):
def terminate(self):
"""Terminate the server process"""
if self.proc:
self.proc.terminate()
self.proc = None
if self.use_popen:
if self.proc:
os.killpg(self.proc.pid, signal.SIGTERM)
self.proc = None
else:
if self.proc:
self.proc.terminate()
self.proc = None
def __del__(self):
self.terminate()
......@@ -40,6 +40,8 @@ We can also use other specific function in this module to create specific target
"""
from __future__ import absolute_import
import warnings
from ._ffi.base import _LIB_NAME
from ._ffi.node import NodeBase, register_node
from . import _api_internal
......@@ -51,7 +53,6 @@ except ImportError as err_msg:
if _LIB_NAME != "libtvm_runtime.so":
raise err_msg
def _merge_opts(opts, new_opts):
"""Helper function to merge options"""
if isinstance(new_opts, str):
......@@ -72,7 +73,7 @@ class Target(NodeBase):
Do not use class constructor, you can create target using the following functions
- :any:`tvm.target.create` create target from string
- :any:`tvm.target.rasp` create raspberry pi target
- :any:`tvm.target.arm_cpu` create arm_cpu target
- :any:`tvm.target.cuda` create CUDA target
- :any:`tvm.target.rocm` create ROCM target
- :any:`tvm.target.mali` create Mali target
......@@ -374,22 +375,6 @@ def rocm(options=None):
return _api_internal._TargetCreate("rocm", *options)
def rasp(options=None):
"""Returns a rasp target.
Parameters
----------
options : str or list of str
Additional options
"""
opts = ["-device=rasp",
"-mtriple=armv7l-none-linux-gnueabihf",
"-mcpu=cortex-a53",
"-mattr=+neon"]
opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("llvm", *opts)
def mali(options=None):
"""Returns a ARM Mali GPU target.
......@@ -428,6 +413,52 @@ def opengl(options=None):
return _api_internal._TargetCreate("opengl", *options)
def arm_cpu(model='unknown', options=None):
"""Returns a ARM CPU target.
This function will also download pre-tuned op parameters when there is none.
Parameters
----------
model: str
SoC name or phone name of the arm board.
options : str or list of str
Additional options
"""
from . import autotvm
trans_table = {
"pixel2": ["-model=snapdragon835", "-target=arm64-linux-android"],
"mate10": ["-model=kirin970", "-target=arm64-linux-android"],
"mate10pro": ["-model=kirin970", "-target=arm64-linux-android"],
"p20": ["-model=kirin970", "-target=arm64-linux-android"],
"p20pro": ["-model=kirin970", "-target=arm64-linux-android"],
"rasp3b": ["-model=bcm2837", "-target=armv7l-linux-gnueabihf"],
"rk3399": ["-model=rk3399", "-target=aarch64-linux-gnu"],
"pynq": ["-model=pynq", "-target=armv7a-linux-eabi"],
}
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
# download pre-tuned parameters for arm_cpu if there is not any.
autotvm.tophub.check_package('arm_cpu')
opts = ["-device=arm_cpu"] + pre_defined_opt
opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("llvm", *opts)
def rasp(options=None):
"""Return a Raspberry 3b target.
Parameters
----------
options : str or list of str
Additional options
"""
warnings.warn('tvm.target.rasp() is going to be deprecated. '
'Please use tvm.target.arm_cpu("rasp3b")')
return arm_cpu('rasp3b', options)
def create(target_str):
"""Get a target given target string.
......
......@@ -261,15 +261,6 @@ Target metal(const std::vector<std::string>& options) {
return CreateTarget("metal", options);
}
Target rasp(const std::vector<std::string>& options) {
return CreateTarget("llvm", MergeOptions(options, {
"-device=rasp",
"-mtriple=armv7l-none-linux-gnueabihf",
"-mcpu=cortex-a53",
"-mattr=+neon"
}));
}
Target mali(const std::vector<std::string>& options) {
return CreateTarget("opencl", MergeOptions(options, {
"-device=mali"
......@@ -731,11 +722,6 @@ TVM_REGISTER_API("_GetCurrentTarget")
TVM_REGISTER_API("_EnterTargetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Target target = args[0];
auto current = Target::current_target();
if (current.defined() && target->str() != current->str()) {
LOG(WARNING) << "Overriding target " << current->str()
<< " with new target scope " << target->str();
}
Target::EnterTargetScope(target);
});
......
......@@ -13,7 +13,6 @@ Module OpenCLModuleCreate(
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl");
}
......
......@@ -108,19 +108,19 @@ def test_task_tuner_without_measurement():
"""test task and tuner without measurement"""
task, target = get_sample_task()
def measure_batch(inputs):
def custom_measure(input_pack, build_func, build_args, number, repeat,
ref_input, ref_output):
from tvm.autotvm import MeasureResult
results = []
for inp in inputs:
for inp in input_pack:
tic = time.time()
# do nothing
time.sleep(0.001)
results.append(MeasureResult([time.time() - tic], 0,
time.time() - tic, time.time()))
return results
measure_option = autotvm.measure_option(mode='custom',
custom_measure_batch=measure_batch)
measure_option = autotvm.measure_option(custom_measure)
logging.info("%s", task.config_space)
......@@ -128,6 +128,7 @@ def test_task_tuner_without_measurement():
for tuner_class in [autotvm.tuner.RandomTuner, autotvm.tuner.GridSearchTuner]:
tuner = tuner_class(task)
tuner.tune(n_trial=10, measure_option=measure_option)
assert tuner.best_flops > 1
def test_tuning_with_measure():
def check(target, target_host):
......@@ -140,7 +141,7 @@ def test_tuning_with_measure():
task, target = get_sample_task(target, target_host)
logging.info("%s", task.config_space)
measure_option = autotvm.measure_option(mode='local',
measure_option = autotvm.measure_option('local',
timeout=4,
number=2)
......@@ -152,7 +153,8 @@ def test_tuning_with_measure():
if __name__ == "__main__":
# only print log when invoked from main
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
test_task_tuner_without_measurement()
test_tuning_with_measure()
......@@ -47,7 +47,7 @@ def test_db_filter():
batch_size = 2
measure_option = autotvm.measure_option(mode='local-nofork', timeout=2)
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
ct = 0
......@@ -72,7 +72,7 @@ def test_db_filter():
db.flush()
# First setting, memoize one input at a time, check that each is saved and replayed
measure_option = autotvm.measure_option(mode='local-nofork', timeout=2, replay_db=db)
measure_option = autotvm.measure_option('local', do_fork=False, timeout=2, replay_db=db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
for i in range(len(all_inputs)+1):
......@@ -160,9 +160,10 @@ def test_db_save_replay():
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option(mode='local-nofork',
measure_option = autotvm.measure_option('local',
do_fork=False,
timeout=2,
replay_db=_db, save_to_replay_db=True)
replay_db=_db)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
batch_size = 2
......@@ -182,6 +183,8 @@ def test_db_save_replay():
results = measure_batch(inputs)
all_results += results
ct += 1
callback = autotvm.callback.log_to_database(_db)
callback(None, all_inputs, all_results)
assert len(_db.db.keys()) == batch_size * TRIAL_LIMIT, \
"%d vs %d" % (len(_db.db.keys()), batch_size * TRIAL_LIMIT)
......@@ -207,7 +210,7 @@ def test_check_hashmismatch():
if not ctx.exist:
logging.warning("Skip this test because there is no supported device for test")
measure_option = autotvm.measure_option(mode='local-nofork')
measure_option = autotvm.measure_option('local', do_fork=False)
measure_batch = autotvm.measure.create_measure_batch(task, measure_option)
inputs = list()
......
......@@ -84,7 +84,7 @@ def test_feature_shape():
targets = [
tvm.target.cuda(),
tvm.target.mali(),
tvm.target.rasp(),
tvm.target.arm_cpu(),
]
for target in targets:
......
......@@ -28,7 +28,7 @@ def test_target_dispatch():
with tvm.target.create("cuda"):
assert mygeneric(1) == 3
with tvm.target.rasp():
with tvm.target.arm_cpu():
assert mygeneric(1) == 11
with tvm.target.create("metal"):
......
......@@ -2,6 +2,9 @@
export PYTHONPATH=python:nnvm/python:vta/python:topi/python
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc python/tvm/*/*/*/*.pyc
rm -rf ~/.tvm
echo "Running unittest..."
python -m nose -v vta/tests/python/unittest || exit -1
python3 -m nose -v vta/tests/python/unittest || exit -1
......
......@@ -24,7 +24,7 @@ from .broadcast import *
from . import nn
from . import x86
from . import cuda
from . import rasp
from . import arm_cpu
from . import mali
from . import intel_graphics
from . import opengl
......
"""Schedule for ARM CPU"""
from . import conv2d
from . import depthwise_conv2d
from . import bitserial_conv2d
......@@ -43,7 +43,7 @@ _QUANTIZED_SCHEDULES_NCHW = [
SpatialPackNCHW(1, 1, 8, 1, 16),
]
@_get_schedule.register("rasp")
@_get_schedule.register("arm_cpu")
def _get_schedule_bitserial_conv2d(wkl, layout):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
......@@ -55,7 +55,7 @@ def _get_schedule_bitserial_conv2d(wkl, layout):
return sch
@bitserial_conv2d.register("rasp")
@bitserial_conv2d.register("arm_cpu")
def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits, weight_bits,
layout='NCHW', pack_dtype=None, out_dtype=None, dorefa=False):
if out_dtype is None:
......@@ -323,7 +323,7 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
s = s.normalize()
return s
@generic.schedule_bitserial_conv2d_nhwc.register(["rasp"])
@generic.schedule_bitserial_conv2d_nhwc.register(["arm_cpu"])
def schedule_bitserial_conv2d_nhwc(outs):
"""Raspverry pi schedule for bitserial conv2d"""
s = tvm.create_schedule([x.op for x in outs])
......
# pylint: disable=invalid-name,unused-variable
"""Depthwise convolution schedule for ARM CPU"""
import tvm
from tvm import autotvm
from ..generic import schedule_depthwise_conv2d_nchw
from ..nn import depthwise_conv2d_nchw
from ..util import traverse_inline
# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
depthwise_conv2d_nchw.fdefault)
# register customized schedule for arm cpu.
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct')
def schedule_depthwise_conv2d_nchw_(cfg, outs):
"""Schedule depthwise conv2d"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _schedule(cfg, s, data, data_pad, kernel, output):
A, B, C = data, kernel, output
s[data_pad].compute_inline()
# define tile
n, c, h, w = s[output].op.axis
cfg.define_split('tile_c', c, num_outputs=2)
cfg.define_split('tile_h', h, num_outputs=2)
cfg.define_split('tile_w', w, num_outputs=2)
# park data to vector form [n, c, h, w] -> [n, C, h, w, VC]
A0 = s.cache_read(data_pad, "global", C)
_, c, h, w = s[A0].op.axis
c, vc = cfg['tile_c'].apply(s, A0, c)
s[A0].reorder(c, h, w, vc)
A1 = s.cache_write(A0, 'global')
s[A0].compute_inline()
# park kernel to vector form [co, ci, kh, kw] -> [CO, ci, kh, kw, VC]
B0 = s.cache_read(B, "global", C)
c, m, h, w = s[B0].op.axis
c, vc, = cfg['tile_c'].apply(s, B0, c)
s[B0].reorder(c, m, h, w, vc)
B1 = s.cache_write(B0, 'global')
s[B0].compute_inline()
_, c, h, w = s[C].op.axis
c, vc, = cfg['tile_c'].apply(s, C, c)
s[C].reorder(c, h, w, vc)
# depthwise conv
C0 = s.cache_write(C, 'global')
_, c, h, w, vc = s[C0].op.axis
dh, dw = s[C0].op.reduce_axis
oh, ih = cfg['tile_h'].apply(s, C0, h)
ow, iw = cfg['tile_w'].apply(s, C0, w)
s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
s[A1].compute_at(s[C0], oh)
# try unroll and vectorization
cfg.define_annotate('ann', [ih, iw, vc], policy='try_unroll_vec')
cfg['ann'].apply(s, C0, [ih, iw, vc],
axis_lens=[cfg['tile_h'].size[-1],
cfg['tile_w'].size[-1],
cfg['tile_c'].size[-1]],
max_unroll=16,
cfg=cfg)
# mark parallel
n, c, h, w = s[C].op.axis
s[C].parallel(c)
n, c, h, w, vc = s[C0].op.axis
s[C0].parallel(c)
c, m, h, w, vc = s[B1].op.axis
s[B1].parallel(c)
return s
def _callback(op):
if op.tag == 'depthwise_conv2d_nchw':
output = op.output(0)
kernel = op.input_tensors[1]
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule(cfg, s, data, data_pad, kernel, output)
traverse_inline(s, outs[0].op, _callback)
return s
......@@ -92,6 +92,54 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
@tvm.target.generic_func
def schedule_conv2d_winograd_weight_transform(outs):
"""Schedule for weight transformation of winograd
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
# Typically this is computed in nnvm PreCompute pass
# so we make a schedule here for cpu llvm
s = tvm.create_schedule([x.op for x in outs])
output = outs[0]
_, G = s[output].op.input_tensors
s[G].compute_inline()
eps, nu, co, ci = s[output].op.axis
r_kh, r_kw = s[output].op.reduce_axis
s[output].reorder(co, ci, r_kh, r_kw, eps, nu)
for axis in [r_kh, r_kw, eps, nu]:
s[output].unroll(axis)
s[output].parallel(co)
return s
@tvm.target.generic_func
def schedule_conv2d_winograd_without_weight_transform(outs):
"""Schedule for winograd without weight transformation
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
......
# pylint: disable=redefined-builtin, wildcard-import
"""Raspberry pi specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .conv2d import schedule_conv2d_nchw
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw
from .bitserial_conv2d import schedule_bitserial_conv2d_nhwc
# pylint: disable=invalid-name,unused-variable, unused-argument
"""Schedule for depthwise_conv2d with auto fusion"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple
from .. import generic
_Workload = namedtuple('Workload',
['in_dtype', 'out_dtype', 'height', 'width', 'channel', 'multiplier',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
_Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])
# workloads of depthwise conv mobile net on imagenet
_WORKLOADS = [
_Workload('float32', 'float32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload('float32', 'float32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload('float32', 'float32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload('int16', 'int32', 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload('int16', 'int32', 7, 7, 1024, 1, 3, 3, 1, 1, 1, 1),
]
_SCHEDULES = [
_Schedule(2, 1, 4, 1, True),
_Schedule(2, 4, 4, 2, True),
_Schedule(2, 1, 4, 2, False),
_Schedule(2, 4, 4, 1, True),
_Schedule(4, 1, 4, 8, True),
_Schedule(1, 1, 4, 2, True),
_Schedule(1, 1, 8, 8, True),
_Schedule(1, 1, 4, 1, False),
_Schedule(1, 1, 4, 4, False),
_Schedule(2, 4, 4, 2, False),
_Schedule(2, 7, 4, 1, True),
_Schedule(2, 4, 4, 4, False),
_Schedule(2, 2, 4, 4, False),
_Schedule(2, 2, 8, 4, False),
_Schedule(2, 2, 4, 4, True),
_Schedule(2, 2, 8, 4, False),
_Schedule(1, 2, 8, 4, True),
_Schedule(1, 1, 4, 8, True),
]
def _get_workload(data, kernel, stride, padding, out_dtype):
_, C, IH, IW = [x.value for x in data.shape]
_, MT, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
return _Workload(data.dtype, out_dtype, IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)
def _schedule(s, data, data_pad, kernel, output, last):
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding, output.dtype)
if wkl not in _WORKLOADS:
return s
# use specified schedule
sch = _SCHEDULES[_WORKLOADS.index(wkl)]
H, W = wkl.height, wkl.width
CN = wkl.channel
MT = wkl.multiplier
HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
VH, VW = sch.vh, sch.vw
BC = sch.bc
VC = sch.vc
TH = H + 2*HPAD
TW = W + 2*WPAD
OH = (H + 2*HPAD - HK) / HSTR + 1
OW = (W + 2*WPAD - WK) / WSTR + 1
A, B, C = data, kernel, output
A0 = data_pad
A1 = s.cache_read(A0, "global", C)
_, c, h, w = s[A1].op.axis
c, vc = s[A1].split(c, VC)
s[A1].reorder(c, h, w, vc)
A2 = s.cache_write(A1, 'global')
s[A0].compute_inline()
s[A1].compute_inline()
B0 = s.cache_read(B, "global", C)
c, m, h, w = s[B0].op.axis
c, vc = s[B0].split(c, VC)
s[B0].reorder(c, m, h, w, vc)
B1 = s.cache_write(B0, 'global')
s[B0].compute_inline()
_, c, h, w = s[C].op.axis
c, vc = s[C].split(c, VC)
s[C].reorder(c, h, w, vc)
C0 = s.cache_write(C, 'global')
_, c, h, w, vc = s[C0].op.axis
dh, dw = s[C0].op.reduce_axis
oh, ow, ih, iw = s[C0].tile(h, w, VH, VW)
s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
if sch.unroll:
s[C0].unroll(iw)
s[C0].vectorize(vc)
# # s[C0].compute_at(s[C0], ow)
launch, c, _, _ = s[C].op.axis
s[C].pragma(launch, "parallel_launch_point")
s[C].parallel(c)
s[C].pragma(c, "parallel_stride_pattern")
s[C].pragma(c, "parallel_barrier_when_finish")
s[C0].compute_at(s[C], launch)
_, c, h, w, vc = s[C0].op.axis
s[C0].parallel(c)
s[C0].pragma(c, "parallel_stride_pattern")
s[C0].pragma(c, "parallel_barrier_when_finish")
s[A2].compute_at(s[C0], oh)
# parallel(s[A2], s[A2].op.axis[1], BC)
# # s[B0].compute_at(s[C0], ow)
s[B1].compute_at(s[C], launch)
c, m, h, w, vc = s[B1].op.axis
s[B1].parallel(c)
s[B1].pragma(c, "parallel_stride_pattern")
s[B1].pragma(c, "parallel_barrier_when_finish")
return s
@generic.schedule_depthwise_conv2d_nchw.register(["cpu", "rasp"])
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d nchw forward.
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for depthwise_conv2d nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule depthwise_conv2d
if op.tag == 'depthwise_conv2d_nchw':
output = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
_schedule(s, data, data_pad, kernel, output, outs[0])
traverse(outs[0].op)
return s
# pylint: disable=invalid-name
"""Common topi utilities"""
from __future__ import absolute_import as _abs
import tvm
from . import tag
def traverse_inline(s, op, callback):
"""Traverse computation graph and do auto inline
Parameters
----------
s: schedule
The schedule
op: Operation
The final output operator.
callback: callable
The callback function on each op
"""
if tag.is_injective(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
traverse_inline(s, tensor.op, callback)
callback(op)
def prod(x):
"""Get the product of every items in the tuple.
......@@ -151,3 +174,33 @@ def unravel_index(idx, shape):
idx = idx // shape[i]
indices = indices[::-1]
return indices
def const_matrix(matrix, name="const_matrix"):
"""convert a const numpy 2-dimensional matrix to tvm tensor
Parameters
----------
matrix: numpy.ndarray
Const input array
name: str, optional
The name of output op
Returns
-------
tensor: Tensor
The created tensor
"""
row, col = matrix.shape
dtype = str(matrix.dtype)
def select_array(i, j):
now = tvm.const(0.0, dtype)
for ii in range(row):
for jj in range(col):
now = tvm.select(tvm.all(i % row == ii, j % col == jj),
tvm.const(matrix[ii][jj], dtype),
now)
return now
return tvm.compute(matrix.shape, select_array, name=name)
......@@ -22,7 +22,7 @@ def verify_bitserial_conv2d_nhwc(batch, in_size, in_channel, num_filter, kernel,
input_type='uint32'
out_dtype='int32'
with tvm.target.rasp():
with tvm.target.arm_cpu('rasp3b'):
A = tvm.placeholder((batch, in_height, in_width, in_channel), dtype=input_type, name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), dtype=input_type, name='W')
B = topi.nn.bitserial_conv2d(A, W, stride, padding, activation_bits, weight_bits, out_dtype=out_dtype,
......
......@@ -2,6 +2,7 @@
import os
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
......@@ -11,10 +12,10 @@ from topi.util import get_const_tuple
def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, padding):
in_height = in_width = in_size
with tvm.target.rasp():
with tvm.target.arm_cpu():
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d(A, W, stride, padding)
B = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), 'NCHW', 'float32')
s = topi.generic.schedule_conv2d_nchw([B])
a_shape = get_const_tuple(A.shape)
......@@ -39,7 +40,8 @@ def verify_conv2d(batch, in_size, in_channel, num_filter, kernel, stride, paddin
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_conv2d():
verify_conv2d(1, 56, 64, 64, 3, 1, 1)
with autotvm.tophub.context(tvm.target.arm_cpu('rasp3b')):
verify_conv2d(1, 56, 64, 64, 3, 1, 1)
if __name__ == "__main__":
test_conv2d()
......@@ -8,6 +8,27 @@ NVIDIA GPU. By running auto-tuner on this template, we can outperform the
vendor provided library CuDNN in many cases.
"""
######################################################################
# Install dependencies
# ----------------------------------------
# To use autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user psutil xgboost
#
# To make tvm run faster in tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user cython
# sudo make cython3
#
# Now return to python code. Import packages.
import logging
import sys
import numpy as np
......@@ -133,7 +154,7 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# for this template
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
# the last layer in resnet
N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
......@@ -144,12 +165,12 @@ print(task.config_space)
# use local gpu, measure 5 times for every config to reduce variance
# run 8 parallel threads for compilation
measure_option = autotvm.measure_option(mode='local',
number=10,
measure_option = autotvm.measure_option('local',
number=5,
parallel_num=8,
timeout=20)
# begin tuning, log records to file `conv2d.tsv`
# begin tuning, log records to file `conv2d.log`
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20,
measure_option=measure_option,
......@@ -186,6 +207,6 @@ np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
# Evaluate running time. Here we choose a large repeat number (200) to reduce the noise
# and the overhead of kernel launch. You can also use nvprof to validate the result.
evaluator = func.time_evaluator(func.entry_name, ctx, number=200)
print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
......@@ -12,6 +12,27 @@ In this tutorial, you can learn how to perform these two steps in tvm.
The whole workflow is illustrated by a matrix multiplication example.
"""
######################################################################
# Install dependencies
# ----------------------------------------
# To use autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user psutil xgboost
#
# To make tvm run faster in tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user cython
# sudo make cython3
#
# Now return to python code. Import packages.
import logging
import sys
......@@ -247,10 +268,10 @@ print(task.config_space)
# used to get the best config later.
# logging config (for printing tuning log to screen)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
# use local cpu, measure 5 times for every config to reduce variance
measure_option = autotvm.measure_option(mode='local',
measure_option = autotvm.measure_option('local',
number=5)
# begin tuning, log records to file `matmul.log`
......
.. _tutorial-nnvm:
Compile Deep Learning Models
----------------------------
......@@ -6,14 +6,10 @@ Deploy the Pretrained Model on ARM Mali GPU
**Author**: `Lianmin Zheng <https://lmzheng.net/>`_, `Ziheng Jiang <https://ziheng.org/>`_
This is an example of using NNVM to compile a ResNet model and
deploy it on Firefly-RK3399 with ARM Mali GPU. We will use the
deploy it on Firefly-RK3399 with ARM Mali GPU. We will use the
Mali-T860 MP4 GPU on this board to accelerate the inference.
This tutorial is based on the tutorial for deploying on Raspberry Pi by `Ziheng Jiang <https://ziheng.org/>`_.
Great thanks to the original author, I only do several lines of modification.
To begin with, we import nnvm (for compilation) and TVM (for deployment).
"""
import tvm
import nnvm.compiler
import nnvm.testing
......@@ -24,92 +20,65 @@ from tvm.contrib import util, graph_runtime as runtime
# Build TVM Runtime on Device
# ---------------------------
#
# There're some prerequisites: we need build tvm runtime and set up
# a RPC server on remote device.
#
# To get started, clone tvm repo from github. It is important to clone
# the submodules along, with --recursive option (Assuming you are in
# your home directory):
#
# .. code-block:: bash
#
# git clone --recursive https://github.com/dmlc/tvm
# The first step is to build tvm runtime on the remote device.
#
# .. note::
#
# Usually device has limited resources and we only need to build
# runtime. The idea is we will use TVM compiler on the local server
# to compile and upload the compiled program to the device and run
# the device function remotely.
# All instructions in both this section and next section should be
# executed on the target device, e.g. Rk3399. And we assume it
# has Linux running.
#
# Since we do compilation on local machine, the remote device is only used
# for running the generated code. We only need to build tvm runtime on
# the remote device. Make sure you have opencl driver in your board.
# You can refer to `tutorial <https://gist.github.com/mli/585aed2cec0b5178b1a510f9f236afa2>`_
# to setup OS and opencl driver for rk3399.
#
# .. code-block:: bash
#
# make runtime
#
# After success of buildind runtime, we need set environment varibles
# in :code:`~/.bashrc` file of yourself account or :code:`/etc/profile`
# of system enviroment variables. Assuming your TVM directory is in
# :code:`~/tvm` and set environment variables below your account.
# .. code-block:: bash
#
# .. code-block:: bash
#
# vi ~/.bashrc
# git clone --recursive https://github.com/dmlc/tvm
# cd tvm
# cp cmake/config.cmake .
# sed -i "s/USE_OPENCL OFF/USE_OPENCL ON/" config.cmake
# make runtime -j4
#
# We need edit :code:`~/.bashrc` using :code:`vi ~/.bashrc` and add
# lines below (Assuming your TVM directory is in :code:`~/tvm`):
# After building runtime successfully, we need to set environment varibles
# in :code:`~/.bashrc` file. We can edit :code:`~/.bashrc`
# using :code:`vi ~/.bashrc` and add the line below (Assuming your TVM
# directory is in :code:`~/tvm`):
#
# .. code-block:: bash
# .. code-block:: bash
#
# export TVM_HOME=~/tvm
# export PATH=$PATH:$TVM_HOME/lib
# export PYTHONPATH=$PYTHONPATH:$TVM_HOME/python
# export PYTHONPATH=$PYTHONPATH:~/tvm/python
#
# To enable updated :code:`~/.bashrc`, execute :code:`source ~/.bashrc`.
# To update the environment variables, execute :code:`source ~/.bashrc`.
######################################################################
# Set Up RPC Server on Device
# ---------------------------
# To set up a TVM RPC server on the your ARM device (our remote device),
# we have prepared a one-line script so you only need to run this
# command after following the installation guide to install TVM on
# your device:
# To start an RPC server, run the following command on your remote device
# (Which is RK3399 in our example).
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_server --host 0.0.0.0 --port=9090
#
# After executing command above, if you see these lines below, it's
# successful to start RPC server on your device.
# If you see the line below, it means the RPC server started
# successfully on your device.
#
# .. code-block:: bash
#
# Loading runtime library /home/YOURNAME/code/tvm/lib/libtvm_runtime.so... exec only
# INFO:root:RPCServer: bind to 0.0.0.0:9090
#
######################################################################
# For demonstration, we simply start an RPC server on the same machine,
# if :code:`use_mali` is False. If you have set up the remote
# environment, please change the three lines below: change the
# :code:`use_mali` to True, also change the :code:`host` and :code:`port`
# with your device's host address and port number.
use_mali = False
host = '10.42.0.96'
port = 9090
if not use_mali:
# run server locally
host = 'localhost'
port = 9095
server = rpc.Server(host=host, port=port, use_popen=True)
######################################################################
# Prepare the Pretrained Model
# ----------------------------
# Back to the host machine, firstly, we need to download a MXNet Gluon
# ResNet model from model zoo, which is pretrained on ImageNet. You
# can found more details about this part at `Compile MXNet Models`
# Prepare the Pre-trained Model
# -----------------------------
# Back to the host machine, which should have a full TVM installed (with LLVM).
#
# We will use pre-trained model from
# `MXNet Gluon model zoo <https://mxnet.incubator.apache.org/api/python/gluon/model_zoo.html>`_.
# You can found more details about this part at tutorial :ref:`tutorial-from-mxnet`.
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
......@@ -135,7 +104,6 @@ def transform_image(image):
x = transform_image(image)
######################################################################
# synset is used to transform the label from number of ImageNet class to
# the word human can understand.
......@@ -143,6 +111,7 @@ synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
'imagenet1000_clsid_to_human.txt'])
synset_name = 'synset.txt'
download(synset_url, synset_name)
with open(synset_name) as f:
......@@ -176,21 +145,29 @@ out_shape = (batch_size, num_classes)
# triplet for host ARM device by setting the parameter :code:`target_host`.
######################################################################
# If we run the example locally for demonstration, we can simply set
# it as :code:`llvm`. If to run it on the ARM device, you need to specify
# its instruction set. Here is the option I use for my Firefly-RK3399.
# If we run the example on our x86 server for demonstration, we can simply
# set it as :code:`llvm`. If running it on the RK3399, we need to
# specify its instruction set. Set :code:`local_demo` to False if you
# want to run this tutorial with a real device.
if use_mali:
target_host = "llvm -target=aarch64-linux-gnu -mattr=+neon"
target = tvm.target.mali()
else:
local_demo = True
if local_demo:
target_host = "llvm"
target = tvm.target.cuda()
target = "llvm"
else:
# Here is the setting for my rk3399 board
# If you don't use rk3399, you can query your target triple by
# execute `gcc -v` on your board.
target_host = "llvm -target=aarch64-linux-gnu"
# set target as `tvm.target.mali` instead of 'opencl' to enable
# optimization for mali
target = tvm.target.mali()
# set target as `tvm.target.mali` instead of 'opencl' to enable
# target-specified optimization
graph, lib, params = nnvm.compiler.build(net, target=target,
shape={"data": data_shape}, params=params, target_host=target_host)
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(net, target=target,
shape={"data": data_shape}, params=params, target_host=target_host)
# After `nnvm.compiler.build`, you will get three return values: graph,
# library and the new parameter, since we do some optimization that will
......@@ -207,14 +184,20 @@ lib.export_library(lib_fname)
# With RPC, you can deploy the model remotely from your host machine
# to the remote device.
# connect the server
remote = rpc.connect(host, port)
# obtain an RPC session from remote device.
if local_demo:
remote = rpc.LocalSession()
else:
# The following is my environment, change this to the IP address of your target device
host = '10.77.1.145'
port = 9090
remote = rpc.connect(host, port)
# upload the library to remote device and load it
remote.upload(lib_fname)
rlib = remote.load_module('net.tar')
ctx = remote.cl(0) if use_mali else remote.gpu(0)
ctx = remote.cpu(0) if local_demo else remote.cl(0)
# upload the parameter
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
......@@ -231,7 +214,3 @@ out = module.get_output(0, tvm.nd.empty(out_shape, ctx=ctx))
# get top1 result
top1 = np.argmax(out.asnumpy())
print('TVM prediction top-1: {}'.format(synset[top1]))
if not use_mali:
# terminate the local server
server.terminate()
......@@ -7,9 +7,8 @@ Deploy the Pretrained Model on Raspberry Pi
This is an example of using NNVM to compile a ResNet model and deploy
it on raspberry pi.
To begin with, we import nnvm(for compilation) and TVM(for deployment).
"""
import tvm
import nnvm.compiler
import nnvm.testing
......@@ -17,102 +16,73 @@ from tvm import rpc
from tvm.contrib import util, graph_runtime as runtime
######################################################################
# .. _build-tvm-runtime-on-device:
#
# Build TVM Runtime on Device
# ---------------------------
#
# There're some prerequisites: we need build tvm runtime and set up
# a RPC server on remote device.
#
# To get started, clone tvm repo from github. It is important to clone
# the submodules along, with --recursive option (Assuming you are in
# your home directory):
#
# .. code-block:: bash
#
# git clone --recursive https://github.com/dmlc/tvm
# The first step is to build tvm runtime on the remote device.
#
# .. note::
#
# Usually device has limited resources and we only need to build
# runtime. The idea is we will use TVM compiler on the local server
# to compile and upload the compiled program to the device and run
# the device function remotely.
# All instructions in both this section and next section should be
# executed on the target device, e.g. Raspberry Pi. And we assume it
# has Linux running.
#
# Since we do compilation on local machine, the remote device is only used
# for running the generated code. We only need to build tvm runtime on
# the remote device.
#
# .. code-block:: bash
#
# make runtime
#
# After success of buildind runtime, we need set environment varibles
# in :code:`~/.bashrc` file of yourself account or :code:`/etc/profile`
# of system enviroment variables. Assuming your TVM directory is in
# :code:`~/tvm` and set environment variables below your account.
# .. code-block:: bash
#
# .. code-block:: bash
#
# vi ~/.bashrc
# git clone --recursive https://github.com/dmlc/tvm
# cd tvm
# make runtime -j4
#
# We need edit :code:`~/.bashrc` using :code:`vi ~/.bashrc` and add
# lines below (Assuming your TVM directory is in :code:`~/tvm`):
# After building runtime successfully, we need to set environment varibles
# in :code:`~/.bashrc` file. We can edit :code:`~/.bashrc`
# using :code:`vi ~/.bashrc` and add the line below (Assuming your TVM
# directory is in :code:`~/tvm`):
#
# .. code-block:: bash
# .. code-block:: bash
#
# export TVM_HOME=~/tvm
# export PATH=$PATH:$TVM_HOME/lib
# export PYTHONPATH=$PYTHONPATH:$TVM_HOME/python
# export PYTHONPATH=$PYTHONPATH:~/tvm/python
#
# To enable updated :code:`~/.bashrc`, execute :code:`source ~/.bashrc`.
# To update the environment variables, execute :code:`source ~/.bashrc`.
######################################################################
# Set Up RPC Server on Device
# ---------------------------
# To set up a TVM RPC server on the Raspberry Pi (our remote device),
# we have prepared a one-line script so you only need to run this
# command after following the installation guide to install TVM on
# your device:
# To start an RPC server, run the following command on your remote device
# (Which is Raspberry Pi in our example).
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_server --host 0.0.0.0 --port=9090
#
# After executing command above, if you see these lines below, it's
# successful to start RPC server on your device.
# If you see the line below, it means the RPC server started
# successfully on your device.
#
# .. code-block:: bash
#
# Loading runtime library /home/YOURNAME/code/tvm/lib/libtvm_runtime.so... exec only
# INFO:root:RPCServer: bind to 0.0.0.0:9090
######################################################################
# For demonstration, we simply start an RPC server on the same machine,
# if :code:`use_rasp` is False. If you have set up the remote
# environment, please change the three lines below: change the
# :code:`use_rasp` to True, also change the :code:`host` and :code:`port`
# with your device's host address and port number.
use_rasp = False
host = 'rasp0'
port = 9090
if not use_rasp:
# run server locally
host = 'localhost'
port = 9091
server = rpc.Server(host=host, port=port, use_popen=True)
#
######################################################################
# Prepare the Pretrained Model
# ----------------------------
# Back to the host machine, firstly, we need to download a MXNet Gluon
# ResNet model from model zoo, which is pretrained on ImageNet. You
# can found more details about this part at `Compile MXNet Models`
# Prepare the Pre-trained Model
# -----------------------------
# Back to the host machine, which should have a full TVM installed (with LLVM).
#
# We will use pre-trained model from
# `MXNet Gluon model zoo <https://mxnet.incubator.apache.org/api/python/gluon/model_zoo.html>`_.
# You can found more details about this part at tutorial :ref:`tutorial-from-mxnet`.
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
from PIL import Image
import numpy as np
# only one line to get the model
# one line to get the model
block = get_model('resnet18_v1', pretrained=True)
######################################################################
......@@ -131,7 +101,6 @@ def transform_image(image):
x = transform_image(image)
######################################################################
# synset is used to transform the label from number of ImageNet class to
# the word human can understand.
......@@ -173,29 +142,32 @@ out_shape = (batch_size, num_classes)
# will lead to very different performance.
######################################################################
# If we run the example locally for demonstration, we can simply set
# it as :code:`llvm`. If to run it on the Raspberry Pi, you need to
# specify its instruction set. Here is the option I use for my Raspberry
# Pi, which has been proved as a good compilation configuration.
# If we run the example on our x86 server for demonstration, we can simply
# set it as :code:`llvm`. If running it on the Raspberry Pi, we need to
# specify its instruction set. Set :code:`local_demo` to False if you want
# to run this tutorial with a real device.
if use_rasp:
target = tvm.target.rasp()
else:
local_demo = True
if local_demo:
target = tvm.target.create('llvm')
else:
target = tvm.target.arm_cpu('rasp3b')
# The above line is a simple form of
# target = tvm.target.create('llvm -devcie=arm_cpu -target=armv7l-linux-gnueabihf')
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
# After `nnvm.compiler.build`, you will get three return values: graph,
# library and the new parameter, since we do some optimization that will
# change the parameters but keep the result of model as the same.
# Save the library at local temporary directory.
tmp = util.tempdir()
lib_fname = tmp.relpath('net.o')
lib.save(lib_fname)
lib_fname = tmp.relpath('net.tar')
lib.export_library(lib_fname)
######################################################################
# Deploy the Model Remotely by RPC
......@@ -203,15 +175,21 @@ lib.save(lib_fname)
# With RPC, you can deploy the model remotely from your host machine
# to the remote device.
# connect the server
remote = rpc.connect(host, port)
# obtain an RPC session from remote device.
if local_demo:
remote = rpc.LocalSession()
else:
# The following is my environment, change this to the IP address of your target device
host = '10.77.1.162'
port = 9090
remote = rpc.connect(host, port)
# upload the library to remote device and load it
remote.upload(lib_fname)
rlib = remote.load_module('net.o')
rlib = remote.load_module('net.tar')
# upload the parameter (this may take a while)
ctx = remote.cpu(0)
# upload the parameter
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
# create the remote runtime module
......@@ -227,7 +205,3 @@ out = module.get_output(0, tvm.nd.empty(out_shape, ctx=ctx))
# get top1 result
top1 = np.argmax(out.asnumpy())
print('TVM prediction top-1: {}'.format(synset[top1]))
if not use_rasp:
# terminate the local server
server.terminate()
"""
.. _tutorial-from-mxnet:
Compile MXNet Models
====================
**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
......
"""
Compile GPU Inference
=====================
**Author**: `Yuwei Hu <https://huyuwei.github.io/>`_
This is an example of using NNVM to compile MobileNet/ResNet model and deploy its inference on GPU.
To begin with, we import nnvm(for compilation) and TVM(for deployment).
"""
import tvm
import numpy as np
from tvm.contrib import nvcc, graph_runtime
import nnvm.compiler
import nnvm.testing
######################################################################
# Register the NVCC Compiler Option
# ---------------------------------
# NNVM optimizes the graph and relies on TVM to generate fast GPU code.
# To get the maximum performance, we need to enable nvcc's compiler hook.
# This usually gives better performance than nvrtc mode.
@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
######################################################################
# Prepare the Benchmark
# ---------------------
# We construct a standard imagenet inference benchmark.
# NNVM needs two things to compile a deep learning model:
#
# - net: the graph representation of the computation
# - params: a dictionary of str to parameters
#
# We use nnvm's testing utility to produce the model description and random parameters
# so that the example does not depend on a specific front-end framework.
#
# .. note::
#
# In a typical workflow, we can get this pair from :any:`nnvm.frontend`
#
target = "cuda"
ctx = tvm.gpu(0)
batch_size = 1
num_classes = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_classes)
# To use ResNet to do inference, run the following instead
#net, params = nnvm.testing.resnet.get_workload(
# batch_size=1, image_shape=image_shape)
net, params = nnvm.testing.mobilenet.get_workload(
batch_size=1, image_shape=image_shape)
######################################################################
# Compile the Graph
# -----------------
# To compile the graph, we call the build function with the graph
# configuration and parameters.
# When parameters are provided, NNVM will pre-compute certain part of the graph if possible (e.g. simplify batch normalization to scale shift),
# and return the updated parameters.
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
######################################################################
# Run the Compiled Module
# -----------------------
#
# To deploy the module, we call :any:`tvm.contrib.graph_runtime.create` passing in the graph, the lib, and context.
# Thanks to TVM, we can deploy the compiled module to many platforms and languages.
# The deployment module is designed to contain minimum dependencies.
# This example runs on the same machine.
#
# Note that the code below no longer depends on NNVM, and only relies TVM's runtime to run(deploy).
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
module = graph_runtime.create(graph, lib, ctx)
# set input
module.set_input(**params)
module.set_input("data", data)
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape))
# convert to numpy
out.asnumpy()
......@@ -6,9 +6,8 @@ Quick Start Tutorial for Compiling Deep Learning Models
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
This example shows how to build a neural network with NNVM python frontend and
generate runtime library for Nvidia GPU and Raspberry Pi with TVM.
To run this notebook, you need to install tvm and nnvm.
Notice that you need to build tvm with cuda and llvm.
generate runtime library for Nvidia GPU with TVM.
Notice that you need to build TVM with cuda and llvm enabled.
"""
######################################################################
......@@ -22,10 +21,13 @@ Notice that you need to build tvm with cuda and llvm.
#
# In this tutorial, we'll choose cuda and llvm as target backends.
# To begin with, let's import NNVM and TVM.
import tvm
import numpy as np
import nnvm.compiler
import nnvm.testing
import tvm
from tvm.contrib import graph_runtime
######################################################################
# Define Neural Network in NNVM
......@@ -33,7 +35,8 @@ import nnvm.testing
# First, let's define a neural network with nnvm python frontend.
# For simplicity, we'll use pre-defined resnet-18 network in NNVM.
# Parameters are initialized with Xavier initializer.
# NNVM also supports other model formats such as MXNet, CoreML and ONNX.
# NNVM also supports other model formats such as MXNet, CoreML, ONNX and
# Tensorflow.
#
# In this tutorial, we assume we will do inference on our device
# and the batch size is set to be 1. Input images are RGB color
......@@ -46,7 +49,8 @@ image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape
out_shape = (batch_size, num_class)
net, params = nnvm.testing.resnet.get_workload(batch_size=batch_size, image_shape=image_shape)
net, params = nnvm.testing.resnet.get_workload(layers=18,
batch_size=batch_size, image_shape=image_shape)
print(net.debug_str())
######################################################################
......@@ -54,10 +58,8 @@ print(net.debug_str())
# -----------
# Next step is to compile the model using the NNVM/TVM pipeline.
# Users can specify the optimization level of the compilation.
# Currently this value can be 0 to 2, which corresponds to
# "SimplifyInference", "OpFusion" and "PrecomputePrune" respectively.
# In this example we set optimization level to be 0
# and use Raspberry Pi as compile target.
# Currently this value can be 0 to 3. The optimization passes include
# operator fusion, pre-computation, layout transformation and so on.
#
# :any:`nnvm.compiler.build` returns three components: the execution graph in
# json format, the TVM module library of compiled functions specifically
......@@ -68,24 +70,50 @@ print(net.debug_str())
#
# We'll first compile for Nvidia GPU. Behind the scene, `nnvm.compiler.build`
# first does a number of graph-level optimizations, e.g. pruning, fusing, etc.,
# then registers the operators (i.e. the nodes of the optmized graphs) to
# then registers the operators (i.e. the nodes of the optimized graphs) to
# TVM implementations to generate a `tvm.module`.
# To generate the module library, TVM will first transfer the HLO IR into the lower
# intrinsic IR of the specified target backend, which is CUDA in this example.
# Then the machine code will be generated as the module library.
# To generate the module library, TVM will first transfer the High level IR
# into the lower intrinsic IR of the specified target backend, which is CUDA
# in this example. Then the machine code will be generated as the module library.
opt_level = 0
opt_level = 3
target = tvm.target.cuda()
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
#####################################################################
# Run the generate library
# ------------------------
# Now we can create graph runtime and run the module on Nvidia GPU.
# create random input
ctx = tvm.gpu()
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
# create module
module = graph_runtime.create(graph, lib, ctx)
# set input and parameters
module.set_input("data", data)
module.set_input(**params)
# run
module.run()
# get output
out = module.get_output(0, tvm.nd.empty(out_shape))
# convert to numpy
out.asnumpy()
# Print first 10 elements of output
print(out.asnumpy().flatten()[0:10])
######################################################################
# Save Compiled Module
# ----------------------------
# After compilation, we can save the graph, lib and params into separate files
# and deploy them to Nvidia GPU.
# Save and Load Compiled Module
# -----------------------------
# We can also save the graph, lib and parameters into files and load them
# back in development environment.
####################################################
# save the graph, lib and params into separate files
from tvm.contrib import util
temp = util.tempdir()
......@@ -97,95 +125,17 @@ with open(temp.relpath("deploy_param.params"), "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
print(temp.listdir())
######################################################################
# Deploy locally to Nvidia GPU
# ------------------------------
# Now we can load the module back.
####################################################
import numpy as np
from tvm.contrib import graph_runtime
loaded_lib = tvm.module.load(path_lib)
# load the module back.
loaded_json = open(temp.relpath("deploy_graph.json")).read()
loaded_lib = tvm.module.load(path_lib)
loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
module = graph_runtime.create(loaded_json, loaded_lib, tvm.gpu(0))
module.load_params(loaded_params)
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
module.run(data=input_data)
out = module.get_output(0, out=tvm.nd.empty(out_shape))
# Print first 10 elements of output
print(out.asnumpy()[0][0:10])
######################################################################
# Compile and Deploy the Model to Raspberry Pi Remotely with RPC
# --------------------------------------------------------------
# Following the steps above, we can also compile the model for Raspberry Pi.
# TVM provides rpc module to help with remote deploying.
#
# For demonstration, we simply start an RPC server on the same machine,
# if :code:`use_rasp` is False. If you have set up the remote
# environment, please change the three lines below: change the
# :code:`use_rasp` to True, also change the host and port with your
# device's host address and port number.
# If we run the example locally for demonstration, we can simply set the
# compilation target as `llvm`.
# To run it on the Raspberry Pi, you need to specify its instruction set.
# `llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon`
# is the recommended compilation configuration, thanks to Ziheng's work.
from tvm import rpc
use_rasp = False
host = 'rasp0'
port = 9090
if not use_rasp:
# run server locally
host = 'localhost'
port = 9099
server = rpc.Server(host=host, port=port, use_popen=True)
# compile and save model library
if use_rasp:
target = "llvm -target=armv7l-none-linux-gnueabihf -mcpu=cortex-a53 -mattr=+neon"
else:
target = "llvm"
# use `with tvm.target.rasp` for some target-specified optimization
with tvm.target.rasp():
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
temp = util.tempdir()
path_lib = temp.relpath("deploy_lib_rasp.o")
lib.save(path_lib)
# connect the server
remote = rpc.connect(host, port)
# upload the library to remote device and load it
remote.upload(path_lib)
rlib = remote.load_module('deploy_lib_rasp.o')
ctx = remote.cpu(0)
# upload the parameter
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
# create the remote runtime module
module = graph_runtime.create(graph, rlib, ctx)
# set parameter
module.set_input(**rparams)
# set input data
input_data = np.random.uniform(size=data_shape)
module.set_input('data', tvm.nd.array(input_data.astype('float32')))
# run
module.run()
out = module.get_output(0, out=tvm.nd.empty(out_shape, ctx=ctx))
# Print first 10 elements of output
print(out.asnumpy()[0][0:10])
out = module.get_output(0, out=tvm.nd.empty(out_shape))
if not use_rasp:
# terminate the local server
server.terminate()
# pylint: disable=invalid-name,unused-variable,invalid-name
"""Conv2D schedule ported from RASP
"""Reuse conv2d schedule from ARM CPU"""
Used for CPU conv2d
"""
from __future__ import absolute_import as _abs
import tvm
from topi.nn.conv2d import conv2d, _get_schedule
from topi.nn.conv2d import SpatialPack, Im2ColPack, Workload
from topi.rasp import conv2d as _rasp_conv2d
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic
_WORKLOADS = [
Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
]
_SCHEDULES = [
# float32 imagenet
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 8, 4, 1, 4, True),
SpatialPack(1, 7, 4, 2, 4, True),
SpatialPack(1, 4, 8, 4, 1, True),
SpatialPack(1, 4, 4, 1, 16, False),
SpatialPack(1, 4, 8, 4, 8, False),
SpatialPack(1, 7, 4, 3, 8, True),
SpatialPack(1, 2, 8, 1, 8, True),
SpatialPack(2, 1, 16, 1, 4, True),
SpatialPack(1, 7, 4, 1, 1, True),
Im2ColPack(7, 4, 1, 16, True),
Im2ColPack(7, 4, 1, 8, False),
Im2ColPack(7, 4, 1, 16, False),
]
@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
with tvm.target.arm_cpu("vtacpu"):
return conv2d(*args, **kwargs)
@_get_schedule.register(["vtacpu", "vta"])
def _schedule_conv2d(wkl):
if wkl not in _WORKLOADS:
raise ValueError("no schedule for such workload: {}".format(wkl))
idx = _WORKLOADS.index(wkl)
sch = _SCHEDULES[idx]
return sch
@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
def schedule(*args, **kwargs):
with tvm.target.arm_cpu("vtacpu"):
return generic.schedule_conv2d_nchw(*args, **kwargs)
conv2d.register(["vtacpu", "vta"], _rasp_conv2d._declaration_conv2d)
generic.schedule_conv2d_nchw.register(
["vtacpu", "vta"],
_rasp_conv2d.schedule_conv2d_nchw)
@conv2d_alter_layout.register(["vtacpu", "vta"])
def alter(*args, **kwargs):
with tvm.target.arm_cpu("vtacpu"):
return conv2d_alter_layout(*args, **kwargs)
......@@ -244,8 +244,11 @@ def is_packed_layout(layout):
return False
@reg.register_alter_op_layout("conv2d", level=15)
def alter_conv2d_layout(*_):
return None
def alter_conv2d_layout(attrs, inputs, out):
layout = attrs['layout']
if is_packed_layout(layout):
return None
return _nn.alter_conv2d_layout(attrs, inputs, out)
@reg.register_compute("conv2d", level=15)
......@@ -368,7 +371,6 @@ def schedule_packed_conv2d(outs):
oshape = topi.util.get_const_tuple(output.shape)
s = tvm.create_schedule(output.op)
# setup pad
if pad_data is not None:
cdata = pad_data
......@@ -394,7 +396,6 @@ def schedule_packed_conv2d(outs):
h_factor = (plan.h_factor if plan.h_factor else oshape[2])
w_factor = (plan.w_factor if plan.w_factor else oshape[3])
x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis
x_co0, x_co1 = s[output].split(x_co, factor=oc_factor)
x_i0, x_i1 = s[output].split(x_i, factor=h_factor)
......@@ -459,6 +460,7 @@ class Conv2DSchedule(object):
self.oc_nthread = oc_nthread
self.h_nthread = h_nthread
self.debug_sync = debug_sync
def __str__(self):
return "{}.{}.{}.{}.{}.{}.{}".format(
self.b_factor, self.oc_factor, self.ic_factor,
......@@ -483,7 +485,6 @@ RESNET = {
11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
}
_WL2PLAN = {}
for idx in RESNET:
scheds = find_schedules(RESNET[idx], vt_only=True, best_only=True)[0]
_WL2PLAN[RESNET[idx]] = scheds
"""Testing if we can generate code in topi style"""
import tvm
from tvm import autotvm
from tvm.contrib import util
from tvm.contrib.pickle_memoize import memoize
import topi
......@@ -62,8 +63,7 @@ def test_cpu_conv2d():
def verify(s, check_correctness):
mod = tvm.build(s, [data, kernel, res],
"llvm -device=vtacpu",
env.target_host,
target_host=env.target_host,
name="conv2d")
temp = util.tempdir()
mod.save(temp.relpath("conv2d.o"))
......@@ -126,7 +126,11 @@ def test_cpu_conv2d():
print(wl)
with tvm.target.create("llvm -device=vtacpu"):
run_cpu_conv2d(env, remote, key, batch_size, wl)
vta.testing.run(_run)
# load pre-tuned operator parameters for ARM CPU
autotvm.tophub.check_package('vta')
with autotvm.tophub.context('llvm -device=vtacpu'):
vta.testing.run(_run)
def test_vta_conv2d():
......@@ -172,7 +176,6 @@ def test_vta_conv2d():
a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
return a_np, w_np, b_np
def verify(s, check_correctness):
mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
env.target_host, name="conv2d")
......
......@@ -8,7 +8,6 @@ onto the VTA accelerator design to perform ImageNet classification tasks.
"""
######################################################################
# Import Libraries
# ----------------
......@@ -17,26 +16,21 @@ onto the VTA accelerator design to perform ImageNet classification tasks.
from __future__ import absolute_import, print_function
import os
import sys
import nnvm
import nnvm.compiler
import tvm
import vta
import vta.testing
import time
from io import BytesIO
import numpy as np
import json
import requests
import time
from matplotlib import pyplot as plt
from PIL import Image
from nnvm.compiler import graph_attr
from tvm import rpc
import tvm
from tvm import rpc, autotvm
from tvm.contrib import graph_runtime, util
from tvm.contrib.download import download
from vta.testing import simulator
from io import BytesIO
from matplotlib import pyplot as plt
from PIL import Image
import nnvm.compiler
import vta
import vta.testing
# Load VTA parameters from the vta/config/vta_config.json file
env = vta.get_env()
......@@ -76,7 +70,6 @@ def classify(m, image):
# Takes in a path to a graph file, params file, and device target
# Returns the NNVM graph object, a compiled library object, and the params dict
def generate_graph(graph_fn, params_fn, device="vta"):
# Measure build start time
build_start = time.time()
......@@ -100,12 +93,6 @@ def generate_graph(graph_fn, params_fn, device="vta"):
shape_dict.update({k: v.shape for k, v in params.items()})
dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
# Create NNVM graph
graph = nnvm.graph.create(sym)
graph_attr.set_shape_inputs(sym, shape_dict)
graph_attr.set_dtype_inputs(sym, dtype_dict)
graph = graph.apply("InferShape").apply("InferType")
# Apply NNVM graph optimization passes
sym = vta.graph.clean_cast(sym)
sym = vta.graph.clean_conv_fuse(sym)
......@@ -166,6 +153,9 @@ for file in [categ_fn, graph_fn, params_fn]:
# Read in ImageNet Categories
synset = eval(open(os.path.join(data_dir, categ_fn)).read())
# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_package('vta')
######################################################################
# Setup the Pynq Board's RPC Server
......@@ -182,7 +172,6 @@ port = int(os.environ.get("VTA_PYNQ_RPC_PORT", "9091"))
# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":
# Make sure that TVM was compiled with RPC=1
assert tvm.module.enabled("rpc")
remote = rpc.connect(host, port)
......@@ -209,8 +198,8 @@ elif env.TARGET == "sim":
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.
# Set ``device=cpu`` to run inference on the CPU,
# or ``device=vtacpu`` to run inference on the FPGA.
# Set ``device=vtacpu`` to run inference on the CPU
# or ``device=vta`` to run inference on the FPGA.
device = "vta"
# Device context
......@@ -225,7 +214,6 @@ m = graph_runtime.create(graph, lib, ctx)
# Set the parameters
m.set_input(**params)
######################################################################
# Run ResNet-18 inference on a sample image
# -----------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment