Commit a57b5493 by Lianmin Zheng Committed by Tianqi Chen

[AUTOTVM][TOPI] Use tunable templates for GPU (CUDA/OpenCL/ROCm/Mali) (#1638)

parent d87c94d4
...@@ -6,8 +6,35 @@ See results on wiki page https://github.com/dmlc/tvm/wiki/Benchmark ...@@ -6,8 +6,35 @@ See results on wiki page https://github.com/dmlc/tvm/wiki/Benchmark
## How to Reproduce ## How to Reproduce
### ARM CPU To obtain the best performance, we always do auto-tuning for the specific devices and get
We use RPC infrastructure in TVM to make device management easy. So you need to use it for reproducing benchmark results. the parameters for used kernels. To enable easy reproduction of our results, we release
pre-tuned parameters for popular networks on some common devices.
TVM will download related tuning cache files during compilation.
If you don't have the following listed devices, you can still run these scripts.
You can pick the one that is most similar to your device as argument.
In general, the performance should also be good.
It is recommended that you run tuning by yourself if you have your customized network or devices.
Please follow the tutorial for
[NVIDIA GPU](https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_cuda.html),
[ARM CPU](https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_arm.html),
[Mobile GPU](https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_mobile_gpu.html).
### NVIDIA GPU
Build TVM with LLVM and CUDA enabled. [Help](https://docs.tvm.ai/install/from_source.html)
```bash
python3 gpu_imagenet_bench.py --model 1080ti
python3 gpu_imagenet_bench.py --model titanx
```
### ARM CPU & Mali GPU
For embedded deivces, we use RPC infrastructure in TVM to make the management easy.
So you need to use it for reproducing benchmark results.
0. Build TVM with LLVM enabled. [Help](https://docs.tvm.ai/install/from_source.html)
1. Start an RPC Tracker on the host machine 1. Start an RPC Tracker on the host machine
```bash ```bash
...@@ -50,24 +77,22 @@ python3 -m tvm.exec.rpc_tracker ...@@ -50,24 +77,22 @@ python3 -m tvm.exec.rpc_tracker
rasp3b 8 8 0 rasp3b 8 8 0
``` ```
4. Run benchmark 4. Run benchmark
We did auto-tuning for Huawei P20/Mate10 Pro, Google Pixel2, Raspberry Pi3 and Firefly-RK3399,
and release pre-tuned parameters in [this repo](https://github.com/uwsaml/tvm-distro).
During compilation, TVM will download these operator parameters automatically.
```bash ```bash
python3 arm_cpu_imagenet_bench.py --device rasp3b --rpc-key rasp3b # ARM CPU
python3 arm_cpu_imagenet_bench.py --device rk3399 --rpc-key rk3399 python3 arm_cpu_imagenet_bench.py --model rasp3b --rpc-key rasp3b
python3 arm_cpu_imagenet_bench.py --device pixel2 --rpc-key pixel2 python3 arm_cpu_imagenet_bench.py --model rk3399 --rpc-key rk3399
python3 arm_cpu_imagenet_bench.py --device p20pro --rpc-key p20pro python3 arm_cpu_imagenet_bench.py --model pixel2 --rpc-key pixel2
python3 arm_cpu_imagenet_bench.py --device mate10pro --rpc-key mate10pro python3 arm_cpu_imagenet_bench.py --model p20pro --rpc-key p20pro
``` python3 arm_cpu_imagenet_bench.py --model mate10pro --rpc-key mate10pro
If your device has a same or similar SoC of the above devices, you can reuse these parameters. # Mali GPU
For example, if your SoC is similar to rasp3b, use python3 mobile_gpu_imagenet_bench.py --model rk3399 --rpc-key rk3399
```bash
python3 arm_cpu_imagenet_bench.py --device rasp3b --rpc-key your_custom_key
``` ```
For other devices, to get the best performance, it is recommended that you tune your network by yourself.
Please follow this [tutorial](https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_arm.html).
### AMD GPU
Build TVM with LLVM and ROCm enabled. [Help](https://docs.tvm.ai/install/from_source.html)
```bash
python3 gpu_imagenet_bench.py --model gfx900 --target rocm
```
"""Benchmark script for ARM CPU. """Benchmark script for ImageNet models on ARM CPU.
see README.md for the usage and results of this script. see README.md for the usage and results of this script.
""" """
import argparse import argparse
...@@ -14,13 +14,60 @@ import nnvm.testing ...@@ -14,13 +14,60 @@ import nnvm.testing
from util import get_network, print_progress from util import get_network, print_progress
def evaluate_network(network, target, target_host, number):
# connect to remote device
tracker = tvm.rpc.connect_tracker(args.host, args.port)
remote = tracker.request(args.rpc_key)
print_progress(network)
net, params, input_shape, output_shape = get_network(network, batch_size=1)
print_progress("%-20s building..." % network)
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(
net, target=target, target_host=target_host,
shape={'data': input_shape}, params=params, dtype=dtype)
tmp = tempdir()
if 'android' in str(target):
from tvm.contrib import ndk
filename = "%s.so" % network
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "%s.tar" % network
lib.export_library(tmp.relpath(filename))
# upload library and params
print_progress("%-20s uploading..." % network)
ctx = remote.context(str(target), 0)
remote.upload(tmp.relpath(filename))
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
rlib = remote.load_module(filename)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
del rparams
# evaluate
print_progress("%-20s evaluating..." % network)
ftimer = module.module.time_evaluator("run", ctx, number=args.number, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices= parser.add_argument("--network", type=str, choices=
['resnet-18', 'resnet-34', 'vgg-16', 'mobilenet', 'squeezenet v1.1', ]) ['resnet-18', 'resnet-34', 'vgg-16',
parser.add_argument("--device", type=str, required=True, choices= 'mobilenet', 'mobilenet_v2', 'squeezenet v1.0', 'squeezenet v1.1'])
parser.add_argument("--model", type=str, choices=
['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro', ['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
'pixel2', 'rasp3b', 'pynq']) 'pixel2', 'rasp3b', 'pynq'], default='rk3399',
help="The model of the test device. If your device is not listed in "
"the choices list, pick the most similar one as argument.")
parser.add_argument("--host", type=str, default='localhost') parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=9190) parser.add_argument("--port", type=int, default=9190)
parser.add_argument("--rpc-key", type=str, required=True) parser.add_argument("--rpc-key", type=str, required=True)
...@@ -34,47 +81,12 @@ if __name__ == "__main__": ...@@ -34,47 +81,12 @@ if __name__ == "__main__":
else: else:
networks = [args.network] networks = [args.network]
target = tvm.target.arm_cpu(model=args.device) target = tvm.target.arm_cpu(model=args.model)
target_host = None
# connect to remote device
tracker = tvm.rpc.connect_tracker(args.host, args.port)
remote = tracker.request(args.rpc_key)
print("--------------------------------------------------") print("--------------------------------------------------")
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)")) print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
print("--------------------------------------------------") print("--------------------------------------------------")
for network in networks: for network in networks:
print_progress(network) evaluate_network(network, target, target_host, args.number)
net, params, input_shape, output_shape = get_network(network, batch_size=1)
print_progress("%-20s building..." % network)
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']):
graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
tmp = tempdir()
if 'android' in str(target):
from tvm.contrib import ndk
filename = "%s.so" % network
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "%s.tar" % network
lib.export_library(tmp.relpath(filename))
# upload library and params
print_progress("%-20s uploading..." % network)
ctx = remote.context(str(target), 0)
remote.upload(tmp.relpath(filename))
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
rlib = remote.load_module(filename)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
# evaluate
print_progress("%-20s evaluating..." % network)
ftimer = module.module.time_evaluator("run", ctx, number=args.number, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
""" Benchmark script for performance on GPUs. """Benchmark script for ImageNet models on GPU.
see README.md for the usage and results of this script.
For example, run the file with:
`python gpu_imagenet_bench.py --model=mobilenet --target=cuda`.
For more details about how to set up the inference environment on GPUs,
please refer to NNVM Tutorial: ImageNet Inference on the GPU
""" """
import time
import argparse import argparse
import numpy as np import numpy as np
import tvm import tvm
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
import nnvm.compiler import nnvm.compiler
import nnvm.testing import nnvm.testing
from tvm.contrib import util, nvcc
from tvm.contrib import graph_runtime as runtime
@tvm.register_func from util import get_network
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
def main(): if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True, parser.add_argument("--network", type=str, choices=
choices=['resnet', 'mobilenet'], ['resnet-18', 'resnet-34', 'resnet-50', 'vgg-16', 'vgg-19',
help="The model type.") 'inception_v3', 'mobilenet', 'mobilenet_v2', 'densenet-121'])
parser.add_argument('--target', type=str, required=True, parser.add_argument("--model", type=str,
choices=['cuda', 'rocm', 'opencl', 'metal', 'nvptx'], choices=['1080ti', 'titanx', 'gfx900'], default='1080ti',
help="Compilation target.") help="The model of the test device. If your device is not listed in "
parser.add_argument('--opt-level', type=int, default=1, help="Level of optimization.") "the choices list, pick the most similar one as argument.")
parser.add_argument('--num-iter', type=int, default=1000, help="Number of iteration during benchmark.") parser.add_argument("--number", type=int, default=500)
parser.add_argument('--repeat', type=int, default=1, help="Number of repeative times.") parser.add_argument("--target", type=str,
choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal'], default='cuda',
help="The tvm compilation target")
args = parser.parse_args() args = parser.parse_args()
opt_level = args.opt_level
num_iter = args.num_iter
ctx = tvm.context(args.target, 0)
batch_size = 1
num_classes = 1000
image_shape = (3, 224, 224)
data_shape = (batch_size,) + image_shape dtype = 'float32'
out_shape = (batch_size, num_classes)
if args.model == 'resnet':
net, params = nnvm.testing.resnet.get_workload(
batch_size=1, image_shape=image_shape)
elif args.model == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(
batch_size=1, image_shape=image_shape)
else:
raise ValueError('no benchmark prepared for {}.'.format(args.model))
if args.target == "cuda": if args.network is None:
unroll = 1400 networks = ['resnet-50', 'mobilenet', 'vgg-19', 'inception_v3']
else: else:
unroll = 128 networks = [args.network]
with nnvm.compiler.build_config(opt_level=opt_level):
with tvm.build_config(auto_unroll_max_step=unroll,
unroll_explicit=(args.target != "cuda")):
graph, lib, params = nnvm.compiler.build(
net, args.target, shape={"data": data_shape}, params=params)
data = np.random.uniform(-1, 1, size=data_shape).astype("float32") target = tvm.target.create('%s -model=%s' % (args.target, args.model))
module = runtime.create(graph, lib, ctx)
module.set_input(**params)
module.set_input("data", data)
module.run()
out = module.get_output(0, tvm.nd.empty(out_shape))
out.asnumpy()
print('benchmark args: {}'.format(args)) print("--------------------------------------------------")
ftimer = module.module.time_evaluator("run", ctx, num_iter) print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
for i in range(args.repeat): print("--------------------------------------------------")
prof_res = ftimer() for network in networks:
print(prof_res) net, params, input_shape, output_shape = get_network(network, batch_size=1)
# sleep for avoiding device overheat
if i + 1 != args.repeat:
time.sleep(45)
if __name__ == '__main__': with nnvm.compiler.build_config(opt_level=3):
main() graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
# create runtime
ctx = tvm.context(str(target), 0)
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
ftimer = module.module.time_evaluator("run", ctx, number=args.number, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
"""Benchmark script for ImageNet models on mobile GPU.
see README.md for the usage and results of this script.
"""
import argparse
import numpy as np
import tvm
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
import nnvm.compiler
import nnvm.testing
from util import get_network, print_progress
def evaluate_network(network, target, target_host, number):
# connect to remote device
tracker = tvm.rpc.connect_tracker(args.host, args.port)
remote = tracker.request(args.rpc_key)
print_progress(network)
net, params, input_shape, output_shape = get_network(network, batch_size=1)
print_progress("%-20s building..." % network)
with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(
net, target=target, target_host=target_host,
shape={'data': input_shape}, params=params, dtype=dtype)
tmp = tempdir()
if 'android' in str(target) or 'android' in str(target_host):
from tvm.contrib import ndk
filename = "%s.so" % network
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "%s.tar" % network
lib.export_library(tmp.relpath(filename))
# upload library and params
print_progress("%-20s uploading..." % network)
ctx = remote.context(str(target), 0)
remote.upload(tmp.relpath(filename))
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()}
rlib = remote.load_module(filename)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**rparams)
del rparams
# evaluate
print_progress("%-20s evaluating..." % network)
ftimer = module.module.time_evaluator("run", ctx, number=number, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--network", type=str, choices=
['resnet-18', 'resnet-34', 'vgg-16',
'mobilenet', 'mobilenet_v2', 'squeezenet v1.1'])
parser.add_argument("--model", type=str, choices=
['rk3399'], default='rk3399',
help="The model of the test device. If your device is not listed in "
"the choices list, pick the most similar one as argument.")
parser.add_argument("--host", type=str, default='localhost')
parser.add_argument("--port", type=int, default=9190)
parser.add_argument("--rpc-key", type=str, required=True)
parser.add_argument("--number", type=int, default=10)
args = parser.parse_args()
dtype = 'float32'
if args.network is None:
networks = ['squeezenet_v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
else:
networks = [args.network]
target = tvm.target.mali(model=args.model)
target_host = tvm.target.arm_cpu(model=args.model)
print("--------------------------------------------------")
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
print("--------------------------------------------------")
for network in networks:
evaluate_network(network, target, target_host, args.number)
\ No newline at end of file
...@@ -27,20 +27,25 @@ def get_network(name, batch_size): ...@@ -27,20 +27,25 @@ def get_network(name, batch_size):
input_shape = (batch_size, 3, 224, 224) input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000) output_shape = (batch_size, 1000)
if "resnet" in name: if name == 'mobilenet':
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'mobilenet_v2':
net, params = nnvm.testing.mobilenet_v2.get_workload(batch_size=batch_size)
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
elif "resnet" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size) net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
elif "vgg" in name: elif "vgg" in name:
n_layer = int(name.split('-')[1]) n_layer = int(name.split('-')[1])
net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size) net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
elif name == 'mobilenet': elif "densenet" in name:
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size) n_layer = int(name.split('-')[1])
net, params = nnvm.testing.densenet.get_workload(num_layers=n_layer, batch_size=batch_size)
elif "squeezenet" in name: elif "squeezenet" in name:
version = name.split("_v")[1] version = name.split("_v")[1]
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version=version) net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version=version)
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
elif name == 'custom': elif name == 'custom':
# an example for custom network # an example for custom network
from nnvm.testing import utils from nnvm.testing import utils
......
...@@ -8,6 +8,7 @@ from . import mobilenet_v2 ...@@ -8,6 +8,7 @@ from . import mobilenet_v2
from . import mlp from . import mlp
from . import resnet from . import resnet
from . import vgg from . import vgg
from . import densenet
from . import squeezenet from . import squeezenet
from . import inception_v3 from . import inception_v3
from . import dcgan from . import dcgan
......
"""
DenseNet, load model from gluon model zoo
Reference:
Huang, Gao, et al. "Densely Connected Convolutional Networks." CVPR 2017
"""
from .utils import create_workload
from ..frontend.mxnet import _from_mxnet_impl
def get_workload(batch_size, num_classes=1000, num_layers=121, dtype="float32"):
"""Get benchmark workload for mobilenet
Parameters
----------
batch_size : int
The batch size used in the model
num_classes : int, optional
Number of classes
num_layers : int, optional
The number of layers
dtype : str, optional
The data type
Returns
-------
net : nnvm.Symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
image_shape = (1, 3, 224, 224)
block = get_model('densenet%d' % num_layers, classes=num_classes, pretrained=False)
data = mx.sym.Variable('data')
sym = block(data)
sym = mx.sym.SoftmaxOutput(sym)
net = _from_mxnet_impl(sym, {})
return create_workload(net, batch_size, image_shape[1:], dtype)
...@@ -46,18 +46,16 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True): ...@@ -46,18 +46,16 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True):
Base name of the operators Base name of the operators
""" """
if bottle_neck: if bottle_neck:
# the same as https://github.com/facebook/fb.resnet.torch#notes,
# a bit difference with origin paper
bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1') bn1 = sym.batch_norm(data=data, epsilon=2e-5, name=name + '_bn1')
act1 = sym.relu(data=bn1, name=name + '_relu1') act1 = sym.relu(data=bn1, name=name + '_relu1')
conv1 = sym.conv2d( conv1 = sym.conv2d(
data=act1, channels=int(num_filter*0.25), kernel_size=(1, 1), data=act1, channels=int(num_filter*0.25), kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), use_bias=False, name=name + '_conv1') strides=stride, padding=(0, 0), use_bias=False, name=name + '_conv1')
bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2') bn2 = sym.batch_norm(data=conv1, epsilon=2e-5, name=name + '_bn2')
act2 = sym.relu(data=bn2, name=name + '_relu2') act2 = sym.relu(data=bn2, name=name + '_relu2')
conv2 = sym.conv2d( conv2 = sym.conv2d(
data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3),
strides=stride, padding=(1, 1), use_bias=False, name=name + '_conv2') strides=(1, 1), padding=(1, 1), use_bias=False, name=name + '_conv2')
bn3 = sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3') bn3 = sym.batch_norm(data=conv2, epsilon=2e-5, name=name + '_bn3')
act3 = sym.relu(data=bn3, name=name + '_relu3') act3 = sym.relu(data=bn3, name=name + '_relu3')
conv3 = sym.conv2d( conv3 = sym.conv2d(
......
...@@ -46,14 +46,13 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, b ...@@ -46,14 +46,13 @@ def residual_unit(data, num_filter, stride, dim_match, name, bottle_neck=True, b
Workspace used in convolution operator Workspace used in convolution operator
""" """
if bottle_neck: if bottle_neck:
# the same as https://github.com/facebook/fb.resnet.torch#notes, a bit difference with origin paper
bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1') bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1') act1 = mx.sym.Activation(data=bn1, act_type='relu', name=name + '_relu1')
conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0), conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25), kernel=(1,1), stride=stride, pad=(0,0),
no_bias=True, workspace=workspace, name=name + '_conv1') no_bias=True, workspace=workspace, name=name + '_conv1')
bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2') bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2') act2 = mx.sym.Activation(data=bn2, act_type='relu', name=name + '_relu2')
conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=stride, pad=(1,1), conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2') no_bias=True, workspace=workspace, name=name + '_conv2')
bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3') bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')
act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3') act3 = mx.sym.Activation(data=bn3, act_type='relu', name=name + '_relu3')
......
...@@ -164,6 +164,31 @@ def measure_option(builder, runner): ...@@ -164,6 +164,31 @@ def measure_option(builder, runner):
Specify how to build programs Specify how to build programs
runner: Runner runner: Runner
Specify how to run programs Specify how to run programs
Examples
--------
# example setting for using local devices
>>> measure_option = autotvm.measure_option(
>>> builder=autotvm.LocalBuilder(), # use all local cpu cores for compilation
>>> runner=autotvm.LocalRunner( # measure them sequentially
>>> number=10,
>>> timeout=5)
>>> )
# example setting for using remote devices
>>> measure_option = autotvm.measure_option(
>>> builder=autotvm.LocalBuilder(), # use all local cpu cores for compilation
>>> runner=autotvm.RPCRunner(
>>> 'rasp3b', 'locahost', 9190, # device key, host and port of the rpc tracker
>>> number=4,
>>> timeout=4) # timeout of a run on the device. RPC request waiting time is excluded.
>>>)
Note
----
To make measurement results accurate, you should pick the correct value for the argument
`number` and `repeat` in Runner(). Using `min_repeat_ms` can dynamically adjusts `number`,
so it is recommended. The typical value for NVIDIA GPU is 100 ms.
""" """
from .measure_methods import LocalBuilder, LocalRunner from .measure_methods import LocalBuilder, LocalRunner
......
...@@ -72,12 +72,15 @@ class LocalBuilder(Builder): ...@@ -72,12 +72,15 @@ class LocalBuilder(Builder):
raise ValueError("Invalid build_func" + build_func) raise ValueError("Invalid build_func" + build_func)
self.build_func = build_func self.build_func = build_func
self.tmp_dir = tempfile.mkdtemp()
self.executor = LocalExecutor(timeout=timeout) self.executor = LocalExecutor(timeout=timeout)
self.tmp_dir = tempfile.mkdtemp()
def build(self, measure_inputs): def build(self, measure_inputs):
results = [] results = []
shutil.rmtree(self.tmp_dir)
self.tmp_dir = tempfile.mkdtemp()
for i in range(0, len(measure_inputs), self.n_parallel): for i in range(0, len(measure_inputs), self.n_parallel):
futures = [] futures = []
for inp in measure_inputs[i:i + self.n_parallel]: for inp in measure_inputs[i:i + self.n_parallel]:
...@@ -95,7 +98,7 @@ class LocalBuilder(Builder): ...@@ -95,7 +98,7 @@ class LocalBuilder(Builder):
results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT, results.append(MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
self.timeout, time.time())) self.timeout, time.time()))
elif res.error is not None: elif res.error is not None:
# instantiation errorD # instantiation error
if isinstance(res.error, InstantiationError): if isinstance(res.error, InstantiationError):
results.append(MeasureResult((res.error,), results.append(MeasureResult((res.error,),
MeasureErrorNo.INSTANTIATION_ERROR, MeasureErrorNo.INSTANTIATION_ERROR,
...@@ -120,9 +123,6 @@ class LocalBuilder(Builder): ...@@ -120,9 +123,6 @@ class LocalBuilder(Builder):
return results return results
def __del__(self):
shutil.rmtree(self.tmp_dir)
class RPCRunner(Runner): class RPCRunner(Runner):
"""Run generated code on remove devices. """Run generated code on remove devices.
...@@ -519,7 +519,7 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60): ...@@ -519,7 +519,7 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
return remote return remote
def check_remote(target, device_key, host=None, port=None, priority=2, timeout=10): def check_remote(target, device_key, host=None, port=None, priority=100, timeout=10):
""" """
Check the availability of a remote device Check the availability of a remote device
......
...@@ -271,7 +271,7 @@ if __name__ == '__main__': ...@@ -271,7 +271,7 @@ if __name__ == '__main__':
parser.add_argument("--code", action='store_true') parser.add_argument("--code", action='store_true')
args = parser.parse_args() args = parser.parse_args()
logger.basicConfig(level=logger.INFO) logging.basicConfig(level=logger.INFO)
if args.mode == 'pick': if args.mode == 'pick':
args.o = args.o or args.i + ".best.log" args.o = args.o or args.i + ".best.log"
......
...@@ -9,7 +9,8 @@ of typical tasks of interest. ...@@ -9,7 +9,8 @@ of typical tasks of interest.
from .task import Task, create, register, template, get_config, args_to_workload from .task import Task, create, register, template, get_config, args_to_workload
from .space import ConfigSpace, ConfigEntity from .space import ConfigSpace, ConfigEntity
from .code_hash import attach_code_hash, attach_code_hash_to_arg from .code_hash import attach_code_hash, attach_code_hash_to_arg
from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, FallbackContext, dispatcher from .dispatcher import dispatcher, DispatchContext, ApplyConfig, ApplyHistoryBest, \
FallbackContext, clear_fallback_cache
from .topi_integration import register_topi_compute, register_topi_schedule from .topi_integration import register_topi_compute, register_topi_schedule
from .nnvm_integration import extract_from_graph from .nnvm_integration import extract_from_graph, extract_from_multiple_graph
...@@ -289,15 +289,20 @@ class FallbackContext(DispatchContext): ...@@ -289,15 +289,20 @@ class FallbackContext(DispatchContext):
self.memory = {} self.memory = {}
self.silent = False self.silent = False
# a set to prevent print duplicated message
self.messages = set()
def _query_inside(self, target, workload): def _query_inside(self, target, workload):
key = (str(target), workload) key = (str(target), workload)
if key in self.memory: if key in self.memory:
return self.memory[key] return self.memory[key]
if not self.silent: if not self.silent:
logger.warning( msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
"Cannot find config for target=%s, workload=%s. A fallback configuration " "is used, which may bring great performance regression." % (target, workload)
"is used, which may bring great performance regression.", target, workload) if msg not in self.messages:
self.messages.add(msg)
logger.warning(msg)
cfg = FallbackConfigEntity() cfg = FallbackConfigEntity()
# cache this config # cache this config
...@@ -320,3 +325,23 @@ class FallbackContext(DispatchContext): ...@@ -320,3 +325,23 @@ class FallbackContext(DispatchContext):
del self.memory[key] del self.memory[key]
DispatchContext.current = FallbackContext() DispatchContext.current = FallbackContext()
def clear_fallback_cache(target, workload):
"""Clear fallback cache. Pass the same argument as _query_inside to this function
to clean the cache.
Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
Note
----
This is used in alter_op_layout to clear the bad cache created before call topi compute function
"""
context = DispatchContext.current
while not isinstance(context, FallbackContext):
context = context._old_ctx
context.clear_cache(target, workload)
...@@ -208,7 +208,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): ...@@ -208,7 +208,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
---------- ----------
graph : Graph graph : Graph
The graph to tune The graph to tune
shape : dict of str to tuple, optional shape : dict of str to tuple
The input shape to the graph The input shape to the graph
dtype : str or dict of str to str dtype : str or dict of str to str
The input types to the graph The input types to the graph
...@@ -249,6 +249,69 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): ...@@ -249,6 +249,69 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
logger.disabled = old_state logger.disabled = old_state
# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
return tasks
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
""" Extract tuning tasks from multiple nnvm graphs.
This function is the multiple graph version of extract_from_graph
Parameters
----------
graphs : List of Graph
The list of graphs to tune
shapes : List of dict of str to tuple
The input shape to the graph
dtypes : List of str or dict of str to str
The input types to the graph
target: tvm.target.Target
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols want to be tuned
target_host: tvm.target.Target
The host compilation target
Returns
-------
task: Array of autotvm.task.Task
collected tasks
"""
import nnvm.compiler
env = TaskExtractEnv.get()
topi_funcs = []
for sym_name in symbols:
if sym_name in env.symbol2topi:
topi_funcs.extend(env.symbol2topi[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)
# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
logger.disabled = old_state
# create tasks for target
tasks = [] tasks = []
for task_name, args in env.get_tasks(): for task_name, args in env.get_tasks():
tasks.append(create(task_name, args, tasks.append(create(task_name, args,
......
...@@ -900,6 +900,7 @@ class ConfigEntity(ConfigSpace): ...@@ -900,6 +900,7 @@ class ConfigEntity(ConfigSpace):
return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key, return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
self.code_hash, self.index) self.code_hash, self.index)
class FallbackConfigEntity(ConfigSpace): class FallbackConfigEntity(ConfigSpace):
"""The config entity created to support fallback""" """The config entity created to support fallback"""
...@@ -926,18 +927,74 @@ class FallbackConfigEntity(ConfigSpace): ...@@ -926,18 +927,74 @@ class FallbackConfigEntity(ConfigSpace):
Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1] Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1]
""" """
space = self.space_map[name] space = self.space_map[name]
assert isinstance(space, SplitSpace)
assert len(constraints) == space.num_outputs assert len(constraints) == space.num_outputs
indices = np.arange(space.num_outputs)
# '-1' means no constraint # '-1' means no constraint
constraints = [x if x != -1 else 1e10 for x in constraints] constraints = [x if x != -1 else 1e10 for x in constraints]
for entity in reversed(space.entities): entity = self._entity_map[name]
if all([entity.size[i] <= constraints[i] for i in indices]): now = space.product
self._entity_map[name] = entity
return for i in reversed(range(space.num_outputs)):
factors = get_factors(now)
find = len(factors) - 1
for j, f in enumerate(factors):
if f > constraints[i]:
find = j - 1
break
if find >= 0:
entity.size[i] = factors[find]
now //= factors[find]
else:
raise RuntimeError("Cannot find feasible fallback split entity for node: " + name)
def fallback_with_reference_log(self, ref_log):
"""A data driven fallback mechanism.
We use tuned parameters from TopHub as reference data.
For an unseen shape, we find the most similar tuned one from TopHub and
mimic its parameters.
Parameters
----------
ref_log: List of (MeasureInput, MeasureResult)
The reference log
"""
knob_names = [x for x in self.space_map.keys() if
isinstance(self.space_map[x], SplitSpace)]
# find best match config in reference data by matching tiling factors
factor_list = []
for knob_name in knob_names:
factor_list.append(get_factors(self.space_map[knob_name].product))
best_match_cfg = None
best_match_score = 0
for inp, _ in ref_log:
match_score = 0
for i, knob_name in enumerate(knob_names):
factors = get_factors(int(np.prod(inp.config[knob_name].size)))
match_score += (float(len(set(factor_list[i]).intersection(factors))) /
len(factor_list[i]))
if match_score > best_match_score:
best_match_score, best_match_cfg = match_score, inp.config
if best_match_cfg is None:
return
# mimic its tiling strategy
for knob_name in knob_names:
constraint = list(best_match_cfg[knob_name].size)
constraint[0] = -1
self.fallback_split(knob_name, constraint)
raise RuntimeError("Cannot find feasible fallback split entity for node: " + name) # copy other knobs
for knob_name in self.space_map.keys():
if not isinstance(self.space_map[knob_name], SplitSpace):
self._entity_map[knob_name] = best_match_cfg[knob_name]
def __repr__(self): def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash) return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
TopHub: Tensor Operator Hub TopHub: Tensor Operator Hub
To get the best performance, we typically need auto-tuning for the specific devices. To get the best performance, we typically need auto-tuning for the specific devices.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets. TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time. TVM will download these parameters for you when you call nnvm.compiler.build_module .
""" """
# pylint: disable=invalid-name # pylint: disable=invalid-name
...@@ -13,15 +13,21 @@ import sys ...@@ -13,15 +13,21 @@ import sys
from .task import ApplyHistoryBest from .task import ApplyHistoryBest
from .. import target as _target from .. import target as _target
from ..contrib.download import download from ..contrib.download import download
from .record import load_from_file
# root path to store TopHub files # root path to store TopHub files
AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub") AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub")
# the version of each package # the version of each package
PACKAGE_VERSION = { PACKAGE_VERSION = {
'vta': "v0.01",
'arm_cpu': "v0.01", 'arm_cpu': "v0.01",
'cuda': "v0.01",
'cuda': "v0.02",
'rocm': "v0.01",
'opencl': "v0.01",
'mali': "v0.01",
'vta': "v0.01",
} }
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
...@@ -30,6 +36,9 @@ def _alias(name): ...@@ -30,6 +36,9 @@ def _alias(name):
"""convert alias for some packages""" """convert alias for some packages"""
table = { table = {
'vtacpu': 'vta', 'vtacpu': 'vta',
'metal': 'opencl',
'nvptx': 'cuda'
} }
return table.get(name, name) return table.get(name, name)
...@@ -60,6 +69,7 @@ def context(target, extra_files=None): ...@@ -60,6 +69,7 @@ def context(target, extra_files=None):
all_packages = list(PACKAGE_VERSION.keys()) all_packages = list(PACKAGE_VERSION.keys())
for name in possible_names: for name in possible_names:
name = _alias(name)
if name in all_packages: if name in all_packages:
check_backend(name) check_backend(name)
...@@ -121,3 +131,51 @@ def download_package(package_name): ...@@ -121,3 +131,51 @@ def download_package(package_name):
logger.info("Download pre-tuned parameters package %s", package_name) logger.info("Download pre-tuned parameters package %s", package_name)
download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s" download("https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/%s"
% package_name, os.path.join(rootpath, package_name), True, verbose=0) % package_name, os.path.join(rootpath, package_name), True, verbose=0)
# global cache for load_reference_log
REFERENCE_LOG_CACHE = {}
def load_reference_log(backend, model, workload_name, template_key):
""" Load reference log from TopHub to support fallback in template.
Template will use these reference logs to choose fallback config.
Parameters
----------
backend: str
The backend name
model: str
The name of the model
workload_name: str
The name of the workload. (The first item in the workload tuple)
template_key: str
The template key
"""
backend = _alias(backend)
version = PACKAGE_VERSION[backend]
package_name = "%s_%s.log" % (backend, version)
filename = os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)
global REFERENCE_LOG_CACHE
key = (backend, model, workload_name, template_key)
if key not in REFERENCE_LOG_CACHE:
tmp = []
if os.path.isfile(os.path.join(AUTOTVM_TOPHUB_ROOT_PATH, package_name)):
find = False
inp = None
for inp, res in load_from_file(filename):
if model == inp.target.model:
find = True
break
if not find and inp:
model = inp.target.model
for inp, res in load_from_file(filename):
if (model == inp.target.model and inp.task.workload[0] == workload_name and
inp.config.template_key == template_key):
tmp.append((inp, res))
REFERENCE_LOG_CACHE[key] = tmp
return REFERENCE_LOG_CACHE[key]
...@@ -34,6 +34,7 @@ class Tuner(object): ...@@ -34,6 +34,7 @@ class Tuner(object):
# time to leave # time to leave
self.ttl = None self.ttl = None
self.n_trial = None self.n_trial = None
self.early_stopping = None
def has_next(self): def has_next(self):
"""Whether has next untried config in the space """Whether has next untried config in the space
...@@ -92,6 +93,7 @@ class Tuner(object): ...@@ -92,6 +93,7 @@ class Tuner(object):
n_parallel = getattr(measure_batch, 'n_parallel', 1) n_parallel = getattr(measure_batch, 'n_parallel', 1)
early_stopping = early_stopping or 1e9 early_stopping = early_stopping or 1e9
self.n_trial = n_trial self.n_trial = n_trial
self.early_stopping = early_stopping
old_level = logger.level old_level = logger.level
...@@ -127,18 +129,18 @@ class Tuner(object): ...@@ -127,18 +129,18 @@ class Tuner(object):
res, config) res, config)
i += len(results) i += len(results)
self.ttl = min(early_stopping + self.best_iter, n_trial) - i
self.update(inputs, results) self.update(inputs, results)
for callback in callbacks: for callback in callbacks:
callback(self, inputs, results) callback(self, inputs, results)
self.ttl = min(early_stopping + self.best_iter, n_trial) - i
if i >= self.best_iter + early_stopping: if i >= self.best_iter + early_stopping:
logger.debug("Early stopped. Best iter: %d.", self.best_iter) logger.debug("Early stopped. Best iter: %d.", self.best_iter)
break break
if error_ct > 150: if error_ct > 150:
logging.basicConfig()
logger.warning("Too many errors happen in the tuning. Now is in debug mode") logger.warning("Too many errors happen in the tuning. Now is in debug mode")
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
else: else:
......
...@@ -53,7 +53,7 @@ class XGBoostCostModel(CostModel): ...@@ -53,7 +53,7 @@ class XGBoostCostModel(CostModel):
upper_model: XGBoostCostModel, optional upper_model: XGBoostCostModel, optional
The upper model used in transfer learning The upper model used in transfer learning
""" """
def __init__(self, task, feature_type, loss_type, num_threads=4, log_interval=25, def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval=25,
upper_model=None): upper_model=None):
super(XGBoostCostModel, self).__init__() super(XGBoostCostModel, self).__init__()
......
...@@ -64,8 +64,8 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1): ...@@ -64,8 +64,8 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1):
progress_size = int(count * block_size) progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration)) speed = int(progress_size / (1024 * duration))
percent = min(int(count * block_size * 100 / total_size), 100) percent = min(int(count * block_size * 100 / total_size), 100)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % sys.stdout.write("\r...%d%%, %.2f MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration)) (percent, progress_size / (1024.0 * 1024), speed, duration))
sys.stdout.flush() sys.stdout.flush()
if sys.version_info >= (3,): if sys.version_info >= (3,):
......
...@@ -105,6 +105,13 @@ class Target(NodeBase): ...@@ -105,6 +105,13 @@ class Target(NodeBase):
self._libs = [l.value for l in self.libs_array] self._libs = [l.value for l in self.libs_array]
return self._libs return self._libs
@property
def model(self):
for opt in self.options_array:
if opt.value.startswith('-model='):
return opt.value[7:]
return 'unknown'
def __enter__(self): def __enter__(self):
_api_internal._EnterTargetScope(self) _api_internal._EnterTargetScope(self)
return self return self
...@@ -354,52 +361,60 @@ def generic_func(fdefault): ...@@ -354,52 +361,60 @@ def generic_func(fdefault):
return fdecorate return fdecorate
def cuda(options=None): def cuda(model='unknown', options=None):
"""Returns a cuda target. """Returns a cuda target.
Parameters Parameters
---------- ----------
model: str
The model of cuda device (e.g. 1080ti)
options : str or list of str options : str or list of str
Additional options Additional options
""" """
options = _merge_opts([], options) opts = _merge_opts(['-model=%s' % model], options)
return _api_internal._TargetCreate("cuda", *options) return _api_internal._TargetCreate("cuda", *opts)
def rocm(options=None): def rocm(model='unknown', options=None):
"""Returns a ROCM target. """Returns a ROCM target.
Parameters Parameters
---------- ----------
model: str
The model of this device
options : str or list of str options : str or list of str
Additional options Additional options
""" """
options = _merge_opts([], options) opts = _merge_opts(["-model=%s" % model], options)
return _api_internal._TargetCreate("rocm", *options) return _api_internal._TargetCreate("rocm", *opts)
def mali(options=None): def mali(model='unknown', options=None):
"""Returns a ARM Mali GPU target. """Returns a ARM Mali GPU target.
Parameters Parameters
---------- ----------
model: str
The model of this device
options : str or list of str options : str or list of str
Additional options Additional options
""" """
opts = ["-device=mali"] opts = ["-device=mali", '-model=%s' % model]
opts = _merge_opts(opts, options) opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("opencl", *opts) return _api_internal._TargetCreate("opencl", *opts)
def intel_graphics(options=None): def intel_graphics(model='unknown', options=None):
"""Returns an Intel Graphics target. """Returns an Intel Graphics target.
Parameters Parameters
---------- ----------
model: str
The model of this device
options : str or list of str options : str or list of str
Additional options Additional options
""" """
opts = ["-device=intel_graphics"] opts = ["-device=intel_graphics", '-model=%s' % model]
opts = _merge_opts(opts, options) opts = _merge_opts(opts, options)
return _api_internal._TargetCreate("opencl", *opts) return _api_internal._TargetCreate("opencl", *opts)
...@@ -436,6 +451,7 @@ def arm_cpu(model='unknown', options=None): ...@@ -436,6 +451,7 @@ def arm_cpu(model='unknown', options=None):
"rasp3b": ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"], "rasp3b": ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"],
"rk3399": ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"], "rk3399": ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"],
"pynq": ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"], "pynq": ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"],
"ultra96": ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"],
} }
pre_defined_opt = trans_table.get(model, ["-model=%s" % model]) pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
...@@ -494,5 +510,4 @@ def current_target(allow_none=True): ...@@ -494,5 +510,4 @@ def current_target(allow_none=True):
------ ------
ValueError if current target is not set. ValueError if current target is not set.
""" """
target_str = _api_internal._GetCurrentTarget(allow_none) return _api_internal._GetCurrentTarget(allow_none)
return create(target_str) if target_str is not None else None
...@@ -583,8 +583,7 @@ class Canonical::Internal : public IRMutator { ...@@ -583,8 +583,7 @@ class Canonical::Internal : public IRMutator {
while (i < suma->elem.size() && j < sumb->elem.size()) { while (i < suma->elem.size() && j < sumb->elem.size()) {
const auto& a = suma->elem[i]; const auto& a = suma->elem[i];
const auto& b = sumb->elem[j]; const auto& b = sumb->elem[j];
if (a.value.same_as(b.value)) { if (a.value.same_as(b.value) && a.level == b.level) {
CHECK_EQ(a.level, b.level);
ComExprEntry e = a; ComExprEntry e = a;
e.scale = a.scale + b.scale * bscale; e.scale = a.scale + b.scale * bscale;
if (e.scale != 0) { if (e.scale != 0) {
......
...@@ -252,11 +252,8 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic ...@@ -252,11 +252,8 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME); this->platform_name = cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME);
this->device_type = device_type; this->device_type = device_type;
this->devices = devices_matched; this->devices = devices_matched;
LOG(INFO) << "Initialize OpenCL platform \'" << this->platform_name << '\'';
break; break;
} }
LOG(INFO) << "\'" << cl::GetPlatformInfo(platform_id, CL_PLATFORM_NAME)
<< "\' platform has no OpenCL device: " << device_type << " mode";
} }
if (this->platform_id == nullptr) { if (this->platform_id == nullptr) {
LOG(WARNING) << "No OpenCL device"; LOG(WARNING) << "No OpenCL device";
...@@ -273,9 +270,6 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic ...@@ -273,9 +270,6 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
this->queues.push_back( this->queues.push_back(
clCreateCommandQueue(this->context, did, 0, &err_code)); clCreateCommandQueue(this->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code); OPENCL_CHECK_ERROR(err_code);
LOG(INFO) << type_key << "(" << i
<< ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
<< "\' cl_device_id=" << did;
} }
} }
......
...@@ -30,14 +30,21 @@ def test_split(): ...@@ -30,14 +30,21 @@ def test_split():
cfg = FallbackConfigEntity() cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(128), num_outputs=3) cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)
cfg.fallback_split('tile_n', [-1, 8, 4]) cfg.fallback_split('tile_n', [-1, 8, 4])
assert cfg['tile_n'].size == [4, 8, 4] assert cfg['tile_n'].size == [4, 8, 4]
cfg = FallbackConfigEntity() cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(49), num_outputs=3) cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
cfg.fallback_split('tile_n', [-1, 8, 4]) cfg.fallback_split('tile_n', [-1, 8, 4])
assert cfg['tile_n'].size == [7, 7, 1] assert cfg['tile_n'].size == [7, 7, 1]
cfg = FallbackConfigEntity()
cfg.define_split('tile_n', cfg.axis(49), num_outputs=3)
try:
cfg.fallback_split('tile_n', [-1, 1, 0])
assert False
except RuntimeError:
pass
if __name__ == '__main__': if __name__ == '__main__':
test_split() test_split()
...@@ -61,7 +61,6 @@ def test_make_attrs(): ...@@ -61,7 +61,6 @@ def test_make_attrs():
datrr = tvm.load_json(tvm.save_json(dattr)) datrr = tvm.load_json(tvm.save_json(dattr))
assert dattr.name.value == "xyz" assert dattr.name.value == "xyz"
def test_make_sum(): def test_make_sum():
A = tvm.placeholder((2, 10), name='A') A = tvm.placeholder((2, 10), name='A')
k = tvm.reduce_axis((0,10), "k") k = tvm.reduce_axis((0,10), "k")
......
...@@ -34,20 +34,21 @@ def test_target_dispatch(): ...@@ -34,20 +34,21 @@ def test_target_dispatch():
with tvm.target.create("metal"): with tvm.target.create("metal"):
assert mygeneric(1) == 3 assert mygeneric(1) == 3
assert tvm.target.current_target() == None assert tvm.target.current_target() is None
def test_target_string_parse(): def test_target_string_parse():
target = tvm.target.create("cuda -libs=cublas,cudnn") target = tvm.target.create("cuda -model=unknown -libs=cublas,cudnn")
assert target.target_name == "cuda" assert target.target_name == "cuda"
assert target.options == ['-libs=cublas,cudnn'] assert target.options == ['-model=unknown', '-libs=cublas,cudnn']
assert target.keys == ['cuda', 'gpu'] assert target.keys == ['cuda', 'gpu']
assert target.libs == ['cublas', 'cudnn'] assert target.libs == ['cublas', 'cudnn']
assert str(target) == str(tvm.target.cuda("-libs=cublas,cudnn")) assert str(target) == str(tvm.target.cuda(options="-libs=cublas,cudnn"))
assert tvm.target.intel_graphics().device_name == "intel_graphics" assert tvm.target.intel_graphics().device_name == "intel_graphics"
assert tvm.target.mali().device_name == "mali"
assert tvm.target.arm_cpu().device_name == "arm_cpu"
if __name__ == "__main__": if __name__ == "__main__":
test_target_dispatch() test_target_dispatch()
......
...@@ -42,9 +42,24 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype): ...@@ -42,9 +42,24 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
"""spatial packing template""" """spatial packing template"""
return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=2) return _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, num_tile=2)
@autotvm.task.register_topi_schedule(schedule_conv2d_nchw, 'arm_cpu', ['direct', 'winograd']) @autotvm.register_topi_schedule(schedule_conv2d_nchw, 'arm_cpu', ['direct', 'winograd'])
def schedule_conv2d_nchw_arm_cpu(cfg, outs): def schedule_conv2d_nchw_arm_cpu(cfg, outs):
"""TOPI schedule callback""" """TOPI schedule callback for conv2d
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d.
"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
def _callback(op): def _callback(op):
...@@ -120,19 +135,16 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n ...@@ -120,19 +135,16 @@ def _decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype, n
cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll') cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec') cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
# ====================================================================
# fallback support
if cfg.is_fallback: if cfg.is_fallback:
if num_tile == 2: if num_tile == 2: # arm cpu
cfg.fallback_split('tile_co', [-1, 8]) ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'conv2d', 'direct')
cfg.fallback_split('tile_oh', [-1, 2]) cfg.fallback_with_reference_log(ref_log)
cfg.fallback_split('tile_ow', [-1, 8]) elif num_tile == 3: # mali gpu
else: ref_log = autotvm.tophub.load_reference_log('mali', 'rk3399', 'conv2d', 'direct')
cfg.fallback_split('tile_co', [-1, 16, 4]) cfg.fallback_with_reference_log(ref_log)
cfg.fallback_split('tile_oh', [-1, 1, 1]) # ====================================================================
cfg.fallback_split('tile_ow', [-1, 1, 4])
cfg['ann_reduce'].anns = ['unroll', 'unroll']
cfg['ann_spatial'].anns = ['none', 'unroll', 'vec']
VC = cfg["tile_co"].size[-1] VC = cfg["tile_co"].size[-1]
VH = cfg["tile_oh"].size[-1] VH = cfg["tile_oh"].size[-1]
...@@ -478,8 +490,8 @@ def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, til ...@@ -478,8 +490,8 @@ def decl_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, til
tile_size) tile_size)
@autotvm.task.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform, @autotvm.register_topi_schedule(schedule_conv2d_winograd_without_weight_transform,
'arm_cpu', ['winograd']) 'arm_cpu', ['winograd'])
def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): def schedule_conv2d_winograd_without_weight_transform_(cfg, outs):
"""TOPI schedule callback""" """TOPI schedule callback"""
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
...@@ -517,11 +529,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -517,11 +529,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
layout, out_dtype) layout, out_dtype)
cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload) cfg = autotvm.DispatchContext.current.query(tvm.target.current_target(), workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None if cfg.is_fallback: # if is fallback, clear query cache and return None
context = autotvm.DispatchContext.current autotvm.task.clear_fallback_cache(tvm.target.current_target(), workload)
while not isinstance(context, autotvm.FallbackContext):
context = context._old_ctx
context.clear_cache(tvm.target.current_target(), workload)
return None return None
if cfg.template_key == 'direct': # packing weight tensor if cfg.template_key == 'direct': # packing weight tensor
......
...@@ -9,11 +9,11 @@ from ..nn import depthwise_conv2d_nchw ...@@ -9,11 +9,11 @@ from ..nn import depthwise_conv2d_nchw
from ..util import traverse_inline from ..util import traverse_inline
# register original implementation of depthwise_conv2d_nchw since we don't need to change this part # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.task.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct', autotvm.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
depthwise_conv2d_nchw.fdefault) depthwise_conv2d_nchw.fdefault)
# register customized schedule for arm cpu. # register customized schedule for arm cpu.
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct') @autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct')
def schedule_depthwise_conv2d_nchw_arm(cfg, outs): def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d """Schedule depthwise conv2d
...@@ -37,16 +37,19 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs): ...@@ -37,16 +37,19 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
A, B, C = data, kernel, output A, B, C = data, kernel, output
s[data_pad].compute_inline() s[data_pad].compute_inline()
# define tile ##### space definition begin #####
n, c, h, w = s[output].op.axis n, c, h, w = s[output].op.axis
cfg.define_split('tile_c', c, num_outputs=2) _, vc = cfg.define_split('tile_c', c, num_outputs=2)
cfg.define_split('tile_h', h, num_outputs=2) _, vh = cfg.define_split('tile_h', h, num_outputs=2)
cfg.define_split('tile_w', w, num_outputs=2) _, vw = cfg.define_split('tile_w', w, num_outputs=2)
cfg.define_annotate('ann', [vh, vw, vc], policy='try_unroll_vec')
# fallback support
if cfg.is_fallback: if cfg.is_fallback:
cfg.fallback_split('tile_c', [-1, 4]) ref_log = autotvm.tophub.load_reference_log(
cfg.fallback_split('tile_h', [-1, 2]) 'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw', 'direct')
cfg.fallback_split('tile_w', [-1, 4]) cfg.fallback_with_reference_log(ref_log)
##### space definition end #####
# park data to vector form [n, c, h, w] -> [n, C, h, w, VC] # park data to vector form [n, c, h, w] -> [n, C, h, w, VC]
A0 = s.cache_read(data_pad, "global", C) A0 = s.cache_read(data_pad, "global", C)
...@@ -78,7 +81,6 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs): ...@@ -78,7 +81,6 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
s[A1].compute_at(s[C0], oh) s[A1].compute_at(s[C0], oh)
# try unroll and vectorization # try unroll and vectorization
cfg.define_annotate('ann', [ih, iw, vc], policy='try_unroll_vec')
cfg['ann'].apply(s, C0, [ih, iw, vc], cfg['ann'].apply(s, C0, [ih, iw, vc],
axis_lens=[cfg['tile_h'].size[-1], axis_lens=[cfg['tile_h'].size[-1],
cfg['tile_w'].size[-1], cfg['tile_w'].size[-1],
......
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
"""CUDA specific declaration and schedules.""" """CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .conv2d import conv2d_cuda from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw
from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_nchw, schedule_depthwise_conv2d_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
from .reduction import schedule_reduce from .reduction import schedule_reduce
...@@ -13,7 +11,6 @@ from .softmax import schedule_softmax ...@@ -13,7 +11,6 @@ from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import dense_cuda, schedule_dense from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize from .nn import schedule_lrn, schedule_l2_normalize
from .vision import * from .vision import *
......
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long # pylint: disable=invalid-name
"""Compute definition for conv2d with cuda backend""" """Compute definition for conv2d with cuda backend"""
import tvm import tvm
from tvm import autotvm
from tvm.contrib import cudnn from tvm.contrib import cudnn
import topi
from ..nn.conv2d import conv2d
from ..util import get_const_int
@conv2d.register("cuda") from .. import nn, generic
def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): from ..util import get_const_int, get_const_tuple, traverse_inline
from .conv2d_direct import schedule_direct_cuda
from .conv2d_winograd import winograd_cuda, schedule_winograd_cuda
@autotvm.register_topi_compute(nn.conv2d, ['cuda', 'gpu'], ['direct', 'winograd'])
def conv2d_cuda(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
"""Conv2D operator for cuda backend. """Conv2D operator for cuda backend.
Parameters Parameters
---------- ----------
input : tvm.Tensor cfg: ConfigEntity
The config for this template
data : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width] 4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor kernel : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] 4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width] stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints padding : int or a list/tuple of two ints
...@@ -27,45 +35,56 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -27,45 +35,56 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
layout : str layout : str
layout of data layout of data
out_dtype: str
The output type. This is used for mixed precision.
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] 4-D with shape [batch, out_channel, out_height, out_width]
""" """
assert isinstance(stride, int) or len(stride) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding
else:
pad_h, pad_w = padding
# handle dilation
dilation_h = dilation_w = 1
kernel_tvm = kernel
kernel_cudnn = kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
kernel_before_dilation = kernel.op.input_tensors[0]
kernel_cudnn = kernel_before_dilation
if layout == 'NCHW':
dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
elif layout == 'NHWC':
dilation_h = (get_const_int(kernel.shape[1]) + get_const_int(kernel_before_dilation.shape[1]) - 1) \
// get_const_int(kernel_before_dilation.shape[1])
dilation_w = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
target = tvm.target.current_target() target = tvm.target.current_target()
if "cudnn" in target.libs: if "cudnn" in target.libs:
assert layout != 'HWCN', "HWCN layout not supported with CUDNN." if layout == 'NCHW':
tensor_format = 0 # CUDNN_TENSOR_NCHW tensor_format = 0 # CUDNN_TENSOR_NCHW
if layout == 'NHWC': N, _, H, W = get_const_tuple(data.shape)
elif layout == 'NHWC':
tensor_format = 1 # CUDNN_TENSOR_NHWC tensor_format = 1 # CUDNN_TENSOR_NHWC
N, H, W, _ = get_const_tuple(data.shape)
else:
raise ValueError("Unsupported layout %s in cudnn" % layout)
CO, CI, KH, KW = get_const_tuple(kernel.shape)
# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
dilation_h = dilation_w = 1
kernel_before_dilation = kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
kernel_before_dilation = kernel.op.input_tensors[0]
if layout == 'NCHW':
dilation_h = (get_const_int(kernel.shape[2]) +
get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
dilation_w = (get_const_int(kernel.shape[3]) +
get_const_int(kernel_before_dilation.shape[3]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
elif layout == 'NHWC':
dilation_h = (get_const_int(kernel.shape[1]) +
get_const_int(kernel_before_dilation.shape[1]) - 1) \
// get_const_int(kernel_before_dilation.shape[1])
dilation_w = (get_const_int(kernel.shape[2]) +
get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
return cudnn.conv2d_forward(data, return cudnn.conv2d_forward(data,
kernel_cudnn, kernel_before_dilation,
stride_h, stride_h,
stride_w, stride_w,
pad_h, pad_h,
...@@ -74,10 +93,51 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -74,10 +93,51 @@ def conv2d_cuda(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
dilation_w, dilation_w,
conv_mode=1, conv_mode=1,
tensor_format=tensor_format, tensor_format=tensor_format,
algo=-1) # let CUDNN choose the best algo algo=-1) # let CUDNN choose the best algo
elif layout == 'NCHW':
return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype) if cfg.template_key == 'winograd':
return winograd_cuda(cfg, data, kernel, strides, padding, layout, out_dtype,
pre_computed=False)
if layout == 'NCHW':
return nn.conv2d_nchw(data, kernel, strides, padding, out_dtype)
elif layout == 'HWCN': elif layout == 'HWCN':
return topi.nn.conv2d_hwcn(data, kernel_tvm, stride, padding, out_dtype) return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype)
else: else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, ["cuda", "gpu"],
["direct", 'winograd'])
def schedule_conv2d_nchw_cuda(cfg, outs):
"""TOPI schedule callback of conv2d for cuda gpu
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv2d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d.
"""
target = tvm.target.current_target()
if 'cudnn' in target.libs:
return generic.schedule_extern(outs)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'conv2d_nchw':
schedule_direct_cuda(cfg, s, op.output(0))
if op.tag == 'conv2d_nchw_winograd':
schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
traverse_inline(s, outs[0].op, _callback)
return s
# pylint: disable=invalid-name
"""The templates for cuda conv2d operators"""
import tvm
from tvm import autotvm
def schedule_direct_cuda(cfg, s, conv):
"""schedule optimized for batch size = 1"""
##### space definition begin #####
n, f, y, x = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis
cfg.define_split("tile_f", f, num_outputs=4)
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_split("tile_ry", ry, num_outputs=2)
cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target()
if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])
# fallback support
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
target.target_name, target.model, 'conv2d', 'direct')
cfg.fallback_with_reference_log(ref_log)
##### space definition end #####
pad_data, kernel = s[conv].op.input_tensors
s[pad_data].compute_inline()
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
if conv.op in s.outputs:
output = conv
OL = s.cache_write(conv, 'local')
else:
output = s.outputs[0].output(0)
s[conv].set_scope('local')
OL = conv
# create cache stage
AA = s.cache_read(pad_data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
kernel_scope, n = s[output].split(n, nparts=1)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
bf = s[output].fuse(n, bf)
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# tile reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
rco, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
# cooperative fetching
for load in [AA, WW]:
n, f, y, x = s[load].op.axis
fused = s[load].fuse(n, f, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# unroll
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Schedule for depthwise_conv2d with auto fusion""" """Schedule for depthwise_conv2d with auto fusion"""
import tvm import tvm
from ..util import get_const_tuple from tvm import autotvm
from ..util import traverse_inline
from .. import tag from .. import tag
from .. import generic from .. import generic, nn
@generic.schedule_depthwise_conv2d_nchw.register(["cuda", "gpu"]) # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
def schedule_depthwise_conv2d_nchw(outs): autotvm.register_topi_compute(nn.depthwise_conv2d_nchw, ['cuda', 'gpu'], 'direct',
nn.depthwise_conv2d_nchw.fdefault)
@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_nchw, ['cuda', 'gpu'], 'direct')
def schedule_depthwise_conv2d_nchw_cuda(cfg, outs):
"""Schedule for depthwise_conv2d nchw forward. """Schedule for depthwise_conv2d nchw forward.
Parameters Parameters
...@@ -22,108 +27,92 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -22,108 +27,92 @@ def schedule_depthwise_conv2d_nchw(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
def _schedule(PaddedInput, Filter, DepthwiseConv2d):
in_shape = get_const_tuple(PaddedInput.shape)
out_shape = get_const_tuple(DepthwiseConv2d.shape)
in_height = in_shape[2]
in_width = in_shape[3]
out_height = out_shape[2]
out_width = out_shape[3]
channel_multiplier = get_const_tuple(Filter.shape)[1]
s[PaddedInput].compute_inline()
IS = s.cache_read(PaddedInput, "shared", [DepthwiseConv2d])
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
IL = s.cache_read(IS, "local", [DepthwiseConv2d])
FL = s.cache_read(FS, "local", [DepthwiseConv2d])
if DepthwiseConv2d.op in s.outputs:
Output = DepthwiseConv2d
CL = s.cache_write(DepthwiseConv2d, "local")
else:
Output = outs[0].op.output(0)
s[DepthwiseConv2d].set_scope("local")
# schedule parameters
num_thread_y = 8
num_thread_x = 8
num_vthread_y = 1
num_vthread_x = 1
blocking_h = out_height
blocking_w = out_width
if out_height % 32 == 0 or in_height >= 108:
blocking_h = 32
if out_width % 32 == 0:
blocking_w = 32
num_thread_x = 16
num_vthread_x = 2
elif in_width >= 108:
blocking_w = 32
block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
# split and bind
by, byi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
s[Output].reorder(Output.op.axis[2], Output.op.axis[3], byi)
by = s[Output].fuse(Output.op.axis[0], by)
s[Output].bind(by, block_y)
bx1, x1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
tvy, vyi = s[Output].split(x1i, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread_y)
bx2, x2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
tvx, vxi = s[Output].split(x2i, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
s[Output].reorder(bx1, bx2, tvy, tvx, ty, tx, yi, xi)
bx = s[Output].fuse(bx1, bx2)
s[Output].bind(bx, block_x)
s[Output].bind(tvy, thread_vy)
s[Output].bind(tvx, thread_vx)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)
# local memory load
s[IL].compute_at(s[Output], tx)
s[FL].compute_at(s[Output], tx)
if DepthwiseConv2d.op in s.outputs:
s[CL].compute_at(s[Output], tx)
else:
s[DepthwiseConv2d].compute_at(s[Output], tx)
# input's shared memory load
s[IS].compute_at(s[Output], bx)
ty, yi = s[IS].split(IS.op.axis[2], nparts=num_thread_y)
tx, xi = s[IS].split(IS.op.axis[3], nparts=num_thread_x)
s[IS].bind(ty, thread_y)
s[IS].bind(tx, thread_x)
# filter's shared memory load
s[FS].compute_at(s[Output], bx)
s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
ty, yi = s[FS].split(FS.op.axis[2], nparts=num_thread_y)
tx, xi = s[FS].split(FS.op.axis[3], nparts=num_thread_x)
s[FS].bind(ty, thread_y)
s[FS].bind(tx, thread_x)
scheduled_ops = [] def _callback(op):
if op.tag == 'depthwise_conv2d_nchw':
def traverse(OP): pad_data = op.input_tensors[0]
"""Internal travserse function""" kernel = op.input_tensors[1]
# inline all one-to-one-mapping operators except the last stage (output) conv = op.output(0)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs: ##### space definition begin #####
s[OP].compute_inline() n, f, y, x = s[conv].op.axis
for tensor in OP.input_tensors: cfg.define_split("tile_f", f, num_outputs=4)
if tensor.op.input_tensors and tensor.op not in scheduled_ops: cfg.define_split("tile_y", y, num_outputs=4)
traverse(tensor.op) cfg.define_split("tile_x", x, num_outputs=4)
# schedule depthwise_conv2d cfg.define_knob("auto_unroll_max_step", [0, 256, 1500])
if OP.tag == 'depthwise_conv2d_nchw':
PaddedInput = OP.input_tensors[0] target = tvm.target.current_target()
Filter = OP.input_tensors[1] if target.target_name in ['nvptx', 'rocm']:
if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag: cfg.define_knob("unroll_explicit", [1])
s[Filter].compute_inline() else:
DepthwiseConv2d = OP.output(0) cfg.define_knob("unroll_explicit", [0, 1])
_schedule(PaddedInput, Filter, DepthwiseConv2d)
# fallback support
scheduled_ops.append(OP) if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
traverse(outs[0].op) target.target_name, target.model, 'depthwise_conv2d_nchw', 'direct')
cfg.fallback_with_reference_log(ref_log)
# TODO(lmzheng): A bug here, set unroll_explicit to False as workaround
cfg['unroll_explicit'].val = 0
##### space definition end #####
s[pad_data].compute_inline()
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
if conv.op in s.outputs:
output = conv
OL = s.cache_write(conv, 'local')
else:
output = s.outputs[0].output(0)
s[conv].set_scope('local')
OL = conv
# create cache stage
AA = s.cache_read(pad_data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
AL = s.cache_read(AA, 'local', [OL])
WL = s.cache_read(WW, 'local', [OL])
# tile and bind spatial axes
n, f, y, x = s[output].op.axis
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
kernel_scope, n = s[output].split(n, nparts=1)
bf = s[output].fuse(n, bf)
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx)
# cooperative fetching
s[AA].compute_at(s[output], bx)
s[WW].compute_at(s[output], bx)
s[AL].compute_at(s[output], tx)
s[WL].compute_at(s[output], tx)
for load in [AA, WW]:
fused = s[load].fuse(*list(s[load].op.axis))
fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
fused, tz = s[load].split(fused, cfg["tile_f"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
traverse_inline(s, outs[0].op, _callback)
return s return s
@generic.schedule_depthwise_conv2d_nhwc.register(["cuda", "gpu"]) @generic.schedule_depthwise_conv2d_nhwc.register(["cuda", "gpu"])
...@@ -143,8 +132,8 @@ def schedule_depthwise_conv2d_nhwc(outs): ...@@ -143,8 +132,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
def _schedule(temp, Filter, DepthwiseConv2d):
def _schedule(temp, Filter, DepthwiseConv2d):
s[temp].compute_inline() s[temp].compute_inline()
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
if DepthwiseConv2d.op in s.outputs: if DepthwiseConv2d.op in s.outputs:
......
...@@ -4,17 +4,21 @@ ...@@ -4,17 +4,21 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import autotvm
from .. import generic from .. import generic, nn
from .. import util from ..util import traverse_inline
from .. import tag
@generic.schedule_dense.register(["mali"]) autotvm.register_topi_compute(nn.dense, 'mali', 'direct', nn.dense.fdefault)
def schedule_dense(outs):
@autotvm.register_topi_schedule(generic.schedule_dense, 'mali', 'direct')
def schedule_dense(cfg, outs):
"""Schedule for dense operator. """Schedule for dense operator.
Parameters Parameters
---------- ----------
cfg: ConfigEntity
The config entity for this template
outs: Array of Tensor outs: Array of Tensor
The computation graph description of dense The computation graph description of dense
in the format of an array of tensors. in the format of an array of tensors.
...@@ -26,80 +30,65 @@ def schedule_dense(outs): ...@@ -26,80 +30,65 @@ def schedule_dense(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
def _schedule(dense):
data = s[dense].op.input_tensors[0] def _callback(op):
weight = s[dense].op.input_tensors[1] if op.tag == 'dense':
vec_size = [1, 2, 4, 8, 16]
hidden = util.get_const_int(weight.shape[1]) max_unroll = 32
out = util.get_const_int(weight.shape[0])
dense = op.output(0)
# set tunable parameter output = outs[0]
tune_config = getattr(tvm.target.current_target(), "tune_config", None)
if tune_config is None: y, x = s[output].op.axis
if hidden > 8192: c = s[dense].op.reduce_axis[0]
num_thread = 32
unroll_step = 32 ##### space definition begin #####
else: cfg.define_split('tile_y', y, num_outputs=3)
if out <= 1024: cfg.define_split('tile_x', x, num_outputs=3)
num_thread = 32 cfg.define_split('c_unroll', c, num_outputs=2, max_factor=64)
unroll_step = 16
else: # fallback support
num_thread = 256 if cfg.is_fallback:
unroll_step = 32 ref_log = autotvm.tophub.load_reference_log(
'mali', 'rk3399', 'dense', 'direct')
if data.dtype == 'float16': cfg.fallback_with_reference_log(ref_log)
if hidden > 8192: ##### space definition end #####
num_thread = 2
unroll_step = 32 if dense.op in s.outputs:
else: dense = s.cache_write(output, 'local')
num_thread = 8
unroll_step = 256 by, ty, yi = cfg['tile_y'].apply(s, output, y)
else: bx, tx, xi = cfg['tile_x'].apply(s, output, x)
num_thread = tune_config['num_thread']
unroll_step = tune_config['unroll_step'] s[output].bind(by, tvm.thread_axis('blockIdx.y'))
s[output].bind(bx, tvm.thread_axis('blockIdx.x'))
def fuse_and_bind(s, tensor, axis=None, num_thread=None): s[output].bind(ty, tvm.thread_axis('threadIdx.y'))
""" fuse all the axis and bind to GPU threads """ s[output].bind(tx, tvm.thread_axis('threadIdx.x'))
axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis) if cfg['tile_y'].size[-1] < max_unroll:
max_threads = tvm.target.current_target(allow_none=False).max_num_threads s[output].unroll(yi)
bx, tx = s[tensor].split(fused, num_thread or max_threads) if cfg['tile_x'].size[-1] in vec_size:
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x")) s[output].vectorize(xi)
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
return bx, tx
output = outs[0]
bx, tx = fuse_and_bind(s, output, num_thread=num_thread)
k = s[dense].op.reduce_axis[0]
k, k_unroll = s[dense].split(k, unroll_step)
s[dense].unroll(k_unroll)
if dense.op not in s.outputs:
s[dense].compute_at(s[output], tx) s[dense].compute_at(s[output], tx)
# bias = s[outs[0]].op.input_tensors[1] k = s[dense].op.reduce_axis[0]
# print(tvm.lower(s, [data, weight, bias, outs[0]], simple_mode=True)) y, x = s[dense].op.axis
k, k_unroll = cfg['c_unroll'].apply(s, dense, k)
scheduled_ops = [] s[dense].reorder(k, k_unroll, y, x)
s[dense].unroll(k_unroll)
def traverse(OP): if cfg['tile_y'].size[-1] < max_unroll:
"""Internal travserse function""" s[dense].unroll(y)
# inline all one-to-one-mapping operators except the last stage (output) if cfg['tile_x'].size[-1] in vec_size:
if tag.is_broadcast(OP.tag): s[dense].vectorize(x)
if OP not in s.outputs:
s[OP].compute_inline() traverse_inline(s, outs[0].op, _callback)
for tensor in OP.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
dense = OP.output(0)
_schedule(dense)
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
scheduled_ops.append(OP)
traverse(outs[0].op)
return s return s
def fuse_and_bind(s, tensor, axis=None, num_thread=None):
""" fuse all the axis and bind to GPU threads """
axis = axis or s[tensor].op.axis
fused = s[tensor].fuse(*axis)
bx, tx = s[tensor].split(fused, num_thread)
s[tensor].bind(bx, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(tx, tvm.thread_axis("threadIdx.x"))
return bx, tx
# pylint: disable=invalid-name,unused-variable,unused-argument # pylint: disable=invalid-name,unused-variable,unused-argument
"""depthwise_conv2d schedule on ARM Mali GPU""" """depthwise_conv2d schedule on ARM Mali GPU"""
from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import autotvm
from .. import generic from ..generic import schedule_depthwise_conv2d_nchw
from .. import util from ..nn import depthwise_conv2d_nchw
from .. import tag from ..util import traverse_inline
@generic.schedule_depthwise_conv2d_nchw.register(["mali"]) # register original implementation of depthwise_conv2d_nchw since we don't need to change this part
def schedule_depthwise_conv2d_nchw(outs): autotvm.register_topi_compute(depthwise_conv2d_nchw, 'mali', 'direct',
"""Schedule for depthwise_conv2d nchw forward. depthwise_conv2d_nchw.fdefault)
# register customized schedule for arm cpu.
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'mali', 'direct')
def schedule_depthwise_conv2d_nchw_mali(cfg, outs):
"""Schedule depthwise conv2d
Parameters Parameters
---------- ----------
cfg: ConfigEntity
The configuration of this template
outs: Array of Tensor outs: Array of Tensor
The computation graph description of depthwise_conv2d The computation graph description of depthwise convolution2d
in the format of an array of tensors. in the format of an array of tensors.
Returns Returns
...@@ -25,89 +32,95 @@ def schedule_depthwise_conv2d_nchw(outs): ...@@ -25,89 +32,95 @@ def schedule_depthwise_conv2d_nchw(outs):
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
def _schedule(pad_data, kernel, conv):
raw_data = s[pad_data].op.input_tensors[0]
if conv.op not in s.outputs: # has bias or relu def _schedule(pad_data, kernel, conv):
output = outs[0] """schedule depthwise_conv2d"""
else: # no bias or relu max_unroll = 16
output = conv vec_size = [1, 2, 4, 8, 16]
def tile_and_bind3d(tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): ##### space definition begin #####
""" tile and bind 3d """ n, c, y, x = s[conv].op.axis
y_factor = y_factor or z_factor bc, tc, ci = cfg.define_split("tile_c", c, num_outputs=3)
x_factor = x_factor or y_factor by, ty, yi = cfg.define_split('tile_y', y, num_outputs=3)
zo, zi = s[tensor].split(z, z_factor) bx, tx, xi = cfg.define_split("tile_x", x, num_outputs=3)
yo, yi = s[tensor].split(y, y_factor) cfg.define_annotate('ann_spatial', [ci, yi, xi], policy='try_unroll_vec')
xo, xi = s[tensor].split(x, x_factor)
s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
return zo, zi, yo, yi, xo, xi
# set tunable parameters
VH = 1
VW = 1
num_thread = 4
while util.get_const_int(conv.shape[3]) % (VW * 2) == 0 and VW * 2 <= 4:
VW = VW * 2
while util.get_const_int(conv.shape[2]) % (VH * 2) == 0 and VH * 2 <= 2:
VH = VH * 2
if raw_data.dtype == 'float16':
if util.get_const_int(conv.shape[3]) % (VW * 2) == 0:
VW *= 2
num_thread *= 2
else:
num_thread *= 2
# schedule padding # fallback support
_, c, y, x = s[pad_data].op.axis if cfg.is_fallback:
tile_and_bind3d(pad_data, c, y, x, num_thread, 1, 1) ref_log = autotvm.tophub.load_reference_log(
'mali', 'rk3399', 'depthwise_conv2d_nchw', 'direct')
cfg.fallback_with_reference_log(ref_log)
###### space definition end ######
# schedule conv
di, dj = s[conv].op.reduce_axis
s[conv].unroll(di)
s[conv].unroll(dj)
_, c, y, x = s[output].op.axis # schedule padding
y, x, yi, xi = s[output].tile(y, x, VH, VW) n, c, y, x = s[pad_data].op.axis
s[output].unroll(yi) tile_and_bind3d(s, pad_data, c, y, x, cfg["tile_c"].size[1], 1, 1)
s[output].vectorize(xi)
_, _, _, _, _, ji = tile_and_bind3d(output, c, y, x, num_thread, 1, 1) # schedule dilation
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
# schedule conv
if conv.op not in s.outputs: if conv.op not in s.outputs:
_, c, y, x = s[conv].op.axis s[conv].set_scope('local')
y, x, yi, xi = s[conv].tile(y, x, VH, VW) OL = conv
s[conv].unroll(yi) output = s.outputs[0].output(0)
s[conv].vectorize(xi) else:
s[conv].compute_at(s[output], ji) OL = s.cache_write(conv, 'local')
output = conv
scheduled_ops = []
def traverse(op):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
n, c, y, x = s[output].op.axis
bc, tc, ci = cfg['tile_c'].apply(s, output, c)
by, ty, yi = cfg['tile_y'].apply(s, output, y)
bx, tx, xi = cfg['tile_x'].apply(s, output, x)
bc = s[output].fuse(n, bc)
s[output].bind(bc, tvm.thread_axis("blockIdx.z"))
s[output].bind(tc, tvm.thread_axis("threadIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
di, dj = s[OL].op.reduce_axis
s[OL].unroll(di)
s[OL].unroll(dj)
s[OL].compute_at(s[output], tx)
n, ci, yi, xi = s[OL].op.axis
cfg["ann_spatial"].apply(s, OL, [ci, yi, xi],
axis_lens=[cfg['tile_c'].size[2], cfg['tile_y'].size[2],
cfg['tile_x'].size[2]],
max_unroll=max_unroll,
vec_size=vec_size,
cfg=cfg)
def _callback(op):
"""traverse to find op to schedule"""
# schedule depthwise_conv2d # schedule depthwise_conv2d
if op.tag == 'depthwise_conv2d_nchw': if op.tag == 'depthwise_conv2d_nchw':
pad_data = op.input_tensors[0] pad_data = op.input_tensors[0]
kernel = op.input_tensors[1] kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
conv = op.output(0) conv = op.output(0)
_schedule(pad_data, kernel, conv) _schedule(pad_data, kernel, conv)
scheduled_ops.append(op) traverse_inline(s, outs[0].op, _callback)
traverse(outs[0].op)
return s return s
def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
""" tile and bind 3d """
y_factor = y_factor or z_factor
x_factor = x_factor or y_factor
zo, zi = s[tensor].split(z, z_factor)
yo, yi = s[tensor].split(y, y_factor)
xo, xi = s[tensor].split(x, x_factor)
s[tensor].bind(zo, tvm.thread_axis("blockIdx.z"))
s[tensor].bind(zi, tvm.thread_axis("threadIdx.z"))
s[tensor].bind(yo, tvm.thread_axis("blockIdx.y"))
s[tensor].bind(yi, tvm.thread_axis("threadIdx.y"))
s[tensor].bind(xo, tvm.thread_axis("blockIdx.x"))
s[tensor].bind(xi, tvm.thread_axis("threadIdx.x"))
return zo, zi, yo, yi, xo, xi
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long # pylint: disable=invalid-name
"""Compute and schedule for rocm conv2d_nchw with auto fusion""" """Compute definition for conv2d with rocm backend"""
import tvm import tvm
from tvm import autotvm
from tvm.contrib import miopen from tvm.contrib import miopen
import topi
from .. import generic
from ..nn.conv2d import conv2d
from ..util import get_const_int
from .. import nn, generic
from ..util import get_const_int, get_const_tuple
from ..cuda.conv2d import conv2d_cuda, schedule_conv2d_nchw_cuda
@conv2d.register("rocm") @autotvm.register_topi_compute(nn.conv2d, 'rocm', ['direct', 'winograd'])
def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32'): def conv2d_rocm(cfg, data, kernel, strides, padding, layout='NCHW', out_dtype='float32'):
"""Conv2D operator for rocm backend. """Conv2D operator for rocm backend.
Parameters Parameters
---------- ----------
cfg: ConfigEntity
The config for this template
input : tvm.Tensor input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width] 4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width] 4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width] stride size, or [stride_height, stride_width]
padding : int or a list/tuple of two ints padding : int or a list/tuple of two ints
...@@ -34,31 +37,42 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -34,31 +37,42 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
output : tvm.Tensor output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width] 4-D with shape [batch, out_channel, out_height, out_width]
""" """
assert layout == 'NCHW', "Only NCHW layout is supported."
assert isinstance(stride, int) or len(stride) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding
else:
pad_h, pad_w = padding
# handle dilation
dilation_h = dilation_w = 1
kernel_tvm = kernel
kernel_cudnn = kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
kernel_before_dilation = kernel.op.input_tensors[0]
kernel_cudnn = kernel_before_dilation
dilation_h = (get_const_int(kernel.shape[2]) + get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
dilation_w = (get_const_int(kernel.shape[3]) + get_const_int(kernel_before_dilation.shape[3]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
target = tvm.target.current_target() target = tvm.target.current_target()
if "miopen" in target.libs: if "miopen" in target.libs:
assert layout == 'NCHW', "Only NCHW layout is supported."
CO, CI, KH, KW = get_const_tuple(kernel.shape)
N, _, H, W = get_const_tuple(data.shape)
# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
dilation_h = dilation_w = 1
kernel_before_dilation = kernel
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
kernel_before_dilation = kernel.op.input_tensors[0]
if layout == 'NCHW':
dilation_h = (get_const_int(kernel.shape[2]) +
get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
dilation_w = (get_const_int(kernel.shape[3]) +
get_const_int(kernel_before_dilation.shape[3]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
elif layout == 'NHWC':
dilation_h = (get_const_int(kernel.shape[1]) +
get_const_int(kernel_before_dilation.shape[1]) - 1) \
// get_const_int(kernel_before_dilation.shape[1])
dilation_w = (get_const_int(kernel.shape[2]) +
get_const_int(kernel_before_dilation.shape[2]) - 1) \
// get_const_int(kernel_before_dilation.shape[2])
return miopen.conv2d_forward(data, return miopen.conv2d_forward(data,
kernel_cudnn, kernel_before_dilation,
stride_h, stride_h,
stride_w, stride_w,
pad_h, pad_h,
...@@ -66,25 +80,30 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32 ...@@ -66,25 +80,30 @@ def conv2d_rocm(data, kernel, stride, padding, layout='NCHW', out_dtype='float32
dilation_h, dilation_h,
dilation_w, dilation_w,
conv_mode=0) conv_mode=0)
return topi.nn.conv2d_nchw(data, kernel_tvm, stride, padding, out_dtype)
return conv2d_cuda(cfg, data, kernel, strides, padding, layout, out_dtype)
@generic.schedule_conv2d_nchw.register(["rocm"]) @autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'rocm', ["direct", 'winograd'])
def schedule_conv2d_nchw(outs): def schedule_conv2d_nchw_rocm(cfg, outs):
"""Schedule for conv2d_nchw with rocm backend. """TOPI schedule callback of conv2d for rocm
Parameters Parameters
---------- ----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor outs: Array of Tensor
The computation graph description of conv2d_nchw The computation graph description of conv2d
in the format of an array of tensors. in the format of an array of tensors.
Returns Returns
------- -------
s: Schedule s: Schedule
The computation schedule for conv2d_nchw. The computation schedule for conv2d.
""" """
target = tvm.target.current_target() target = tvm.target.current_target()
if target and "miopen" in target.libs: if target and "miopen" in target.libs:
return topi.generic.schedule_extern(outs) return generic.schedule_extern(outs)
return topi.cuda.schedule_conv2d_nchw(outs)
return schedule_conv2d_nchw_cuda(cfg, outs)
...@@ -9,4 +9,4 @@ def get_all_backend(): ...@@ -9,4 +9,4 @@ def get_all_backend():
A list of all supported targets A list of all supported targets
""" """
return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
'llvm -device=arm_cpu', 'aocl_sw_emu'] 'llvm -device=arm_cpu', 'opencl -device=mali', 'aocl_sw_emu']
...@@ -48,7 +48,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -48,7 +48,8 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
dW = topi.nn.dilate(W, (1, 1, dilation, dilation)) dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype) C = topi.nn.conv2d(A, dW, (stride, stride), (padding, padding),
layout='NCHW', out_dtype=dtype)
if add_bias: if add_bias:
C = topi.add(C, bias) C = topi.add(C, bias)
if add_relu: if add_relu:
...@@ -72,7 +73,11 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -72,7 +73,11 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
def test_conv2d_nchw(): def test_conv2d_nchw():
autotvm.DispatchContext.current.silent = True # load tophub
ctx = autotvm.apply_history_best([])
for device in get_all_backend():
context = autotvm.tophub.context(device)
context.__enter__()
# ResNet18 workloads # ResNet18 workloads
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
...@@ -96,9 +101,21 @@ def test_conv2d_nchw(): ...@@ -96,9 +101,21 @@ def test_conv2d_nchw():
# dilation = 2 # dilation = 2
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2) verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, dilation=2)
# batch size
verify_conv2d_nchw(4, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(9, 64, 56, 64, 3, 1, 1)
# weird workloads # weird workloads
verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1) verify_conv2d_nchw(2, 2, 2, 2, 2, 2, 2)
verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2) verify_conv2d_nchw(3, 3, 3, 3, 3, 3, 3)
verify_conv2d_nchw(4, 4, 4, 4, 4, 4, 4)
verify_conv2d_nchw(5, 5, 5, 5, 5, 5, 5)
verify_conv2d_nchw(6, 6, 6, 6, 6, 6, 6)
# disable these tests due to some bugs of llvm with nvptx
# verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=1)
# verify_conv2d_nchw(1, 1, 1, 1, 1, 1, 1, dilation=2)
# verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
# inception v3 workloads # inception v3 workloads
verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0) verify_conv2d_nchw(1, 3, 299, 32, 3, 2, 0)
...@@ -117,22 +134,22 @@ def test_conv2d_nchw(): ...@@ -117,22 +134,22 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0) verify_conv2d_nchw(1, 288, 35, 64, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0) verify_conv2d_nchw(1, 288, 35, 48, 1, 1, 0)
verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0) verify_conv2d_nchw(1, 288, 35, 384, 3, 2, 0)
# verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0) verify_conv2d_nchw(1, 96, 35, 96, 3, 2, 0)
# verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0) verify_conv2d_nchw(1, 768, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0) verify_conv2d_nchw(1, 768, 17, 128, 1, 1, 0)
# verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0) verify_conv2d_nchw(1, 128, 17, 128, 1, 1, 0)
# verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3) verify_conv2d_nchw(1, 128, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3) verify_conv2d_nchw(1, 128, 17, 128, 7, 1, 3)
# verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0) verify_conv2d_nchw(1, 128, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0) verify_conv2d_nchw(1, 768, 17, 160, 1, 1, 0)
# verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0) verify_conv2d_nchw(1, 160, 17, 160, 1, 1, 0)
# verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3) verify_conv2d_nchw(1, 160, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3) verify_conv2d_nchw(1, 160, 17, 160, 7, 1, 3)
# verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0) verify_conv2d_nchw(1, 160, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0) verify_conv2d_nchw(1, 192, 17, 192, 1, 1, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3) verify_conv2d_nchw(1, 192, 17, 192, 7, 1, 3)
# verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0) verify_conv2d_nchw(1, 192, 17, 320, 3, 2, 0)
# verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0) verify_conv2d_nchw(1, 192, 17, 192, 3, 2, 0)
verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0) verify_conv2d_nchw(1, 1280, 8, 320, 1, 1, 0)
verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0) verify_conv2d_nchw(1, 1280, 8, 384, 1, 1, 0)
verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0) verify_conv2d_nchw(1, 384, 8, 384, 1, 1, 0)
......
"""Example code to do convolution."""
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
dW = topi.nn.dilate(W, (1, 1, dilation, dilation))
C = topi.nn.conv2d(A, dW, stride, padding, layout='NCHW', out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']:
check_device(device)
class WinogradFallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = FallbackConfigEntity()
cfg.template_key = 'winograd'
self.memory[key] = cfg
return cfg
def test_conv2d_nchw():
autotvm.DispatchContext.current.silent = True
with WinogradFallback():
# resnet 18 workloads
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
# batch size = 2
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1)
# relu, bias
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_bias=True)
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_relu=True)
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, 1, add_relu=True, add_bias=True)
# werid workloads
verify_conv2d_nchw(1, 1, 1, 1, 3, 1, 1)
verify_conv2d_nchw(3, 3, 3, 3, 3, 1, 1)
verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_nchw()
import tvm import tvm
from tvm import autotvm
import topi import topi
import topi.testing import topi.testing
import numpy as np import numpy as np
from topi.util import get_const_tuple from topi.util import get_const_tuple
from topi.nn.util import get_pad_tuple
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from common import get_all_backend from common import get_all_backend
...@@ -11,6 +13,16 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -11,6 +13,16 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
in_width = in_height in_width = in_height
filter_channel = in_channel filter_channel = in_channel
filter_width = filter_height filter_width = filter_height
stride_h = stride_w = stride
if dilation == 1:
# here we transform the padding argument from 'str' to 'tuple' ,
# because we need this to match the "workload" tuple to the records in TopHub
pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
padding_args = (pad_h, pad_w)
else:
padding_args = padding
# placeholder # placeholder
Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input') Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter') Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
...@@ -18,6 +30,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -18,6 +30,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
dtype = 'float32'
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -26,7 +40,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -26,7 +40,8 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
# declare # declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter, stride=stride, padding=padding) DepthwiseConv2d = topi.nn.depthwise_conv2d_nchw(Input, DilatedFilter,
(stride_h, stride_w), padding_args, dtype)
ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift) ScaleShift = topi.nn.scale_shift_nchw(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift) Relu = topi.nn.relu(ScaleShift)
# schedule # schedule
...@@ -39,7 +54,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -39,7 +54,6 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# Prepare pod type for test data closure # Prepare pod type for test data closure
dtype = Input.dtype
input_shape = get_const_tuple(Input.shape) input_shape = get_const_tuple(Input.shape)
filter_shape = get_const_tuple(Filter.shape) filter_shape = get_const_tuple(Filter.shape)
scale_shape = get_const_tuple(Scale.shape) scale_shape = get_const_tuple(Scale.shape)
...@@ -56,7 +70,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu ...@@ -56,7 +70,7 @@ def depthwise_conv2d_with_workload_nchw(batch, in_channel, in_height, channel_mu
shift_np = np.random.uniform(size=shift_shape).astype(dtype) shift_np = np.random.uniform(size=shift_shape).astype(dtype)
# correctness with scipy # correctness with scipy
depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw( depthwise_conv2d_scipy = topi.testing.depthwise_conv2d_python_nchw(
input_np, dilated_filter_np, stride=stride, padding=padding) input_np, dilated_filter_np, stride, padding)
scale_shift_scipy = np.zeros(shape=scale_shift_shape) scale_shift_scipy = np.zeros(shape=scale_shift_shape)
for c in range(in_channel * channel_multiplier): for c in range(in_channel * channel_multiplier):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c] scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
...@@ -96,6 +110,15 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -96,6 +110,15 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
filter_channel = in_channel filter_channel = in_channel
filter_width = filter_height filter_width = filter_height
stride_w = stride_h stride_w = stride_h
if dilation == 1:
# here we transform the padding argument from 'str' to 'tuple' ,
# because we need this to match the "workload" tuple to the records in TopHub
pad_h, pad_w, _, _ = get_pad_tuple(padding, (filter_height, filter_width))
padding_args = (pad_h, pad_w)
else:
padding_args = padding
# placeholder # placeholder
Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input') Input = tvm.placeholder((batch, in_height, in_width, in_channel), name='Input')
Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter') Filter = tvm.placeholder((filter_height, filter_width,filter_channel, channel_multiplier), name='Filter')
...@@ -103,6 +126,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -103,6 +126,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale') Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift') Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
dtype = 'float32'
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
...@@ -112,7 +137,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -112,7 +137,8 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
with tvm.target.create(device): with tvm.target.create(device):
# declare # declare
DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter, stride=[stride_h, stride_w], padding=padding) DepthwiseConv2d = topi.nn.depthwise_conv2d_nhwc(Input, DilatedFilter,
(stride_h, stride_w), padding_args, dtype)
ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift) ScaleShift = topi.nn.scale_shift_nhwc(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift) Relu = topi.nn.relu(ScaleShift)
# schedule # schedule
...@@ -125,7 +151,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -125,7 +151,6 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device) f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
# Prepare pod type for test data closure # Prepare pod type for test data closure
dtype = Input.dtype
input_shape = get_const_tuple(Input.shape) input_shape = get_const_tuple(Input.shape)
filter_shape = get_const_tuple(Filter.shape) filter_shape = get_const_tuple(Filter.shape)
scale_shape = get_const_tuple(Scale.shape) scale_shape = get_const_tuple(Scale.shape)
...@@ -180,26 +205,36 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu ...@@ -180,26 +205,36 @@ def depthwise_conv2d_with_workload_nhwc(batch, in_channel, in_height, channel_mu
def test_depthwise_conv2d(): def test_depthwise_conv2d():
print("testing nchw") # load tophub
depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME") ctx = autotvm.apply_history_best([])
for device in get_all_backend():
context = autotvm.tophub.context(device)
context.__enter__()
# mobilenet workloads
depthwise_conv2d_with_workload_nchw(1, 32, 112, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(1, 64, 112, 1, 3, 2, "SAME")
depthwise_conv2d_with_workload_nchw(1, 128, 56, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(1, 128, 56, 1, 3, 2, "SAME")
depthwise_conv2d_with_workload_nchw(1, 256, 28, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(1, 256, 28, 1, 3, 2, "SAME")
depthwise_conv2d_with_workload_nchw(1, 512, 14, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(1, 512, 14, 1, 3, 2, "SAME")
depthwise_conv2d_with_workload_nchw(1, 1024, 7, 1, 3, 1, "SAME")
# NCHW
depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME") depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_with_workload_nchw(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID") depthwise_conv2d_with_workload_nchw(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload_nchw(4, 256, 32, 2, 5, 2, "VALID")
# dilation = 2 # dilation = 2
depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME", dilation=2) depthwise_conv2d_with_workload_nchw(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
print("testing nhwc")
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME") # NHWC
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME") depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "SAME")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME") depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "SAME")
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID") depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID")
depthwise_conv2d_with_workload_nhwc(4, 256, 32, 2, 5, 2, "VALID")
# dilation = 2 # dilation = 2
depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2)
......
...@@ -10,7 +10,7 @@ vendor provided library CuDNN in many cases. ...@@ -10,7 +10,7 @@ vendor provided library CuDNN in many cases.
###################################################################### ######################################################################
# Install dependencies # Install dependencies
# ---------------------------------------- # --------------------
# To use autotvm package in tvm, we need to install some extra dependencies. # To use autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2): # (change "3" to "2" if you use python2):
# #
...@@ -20,7 +20,6 @@ vendor provided library CuDNN in many cases. ...@@ -20,7 +20,6 @@ vendor provided library CuDNN in many cases.
# #
# To make tvm run faster in tuning, it is recommended to use cython # To make tvm run faster in tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute # as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2):
# #
# .. code-block:: bash # .. code-block:: bash
# #
...@@ -41,7 +40,7 @@ from tvm import autotvm ...@@ -41,7 +40,7 @@ from tvm import autotvm
###################################################################### ######################################################################
# Step 1: Define the search space # Step 1: Define the search space
# --------------------------------- # --------------------------------
# There are plenty of useful schedule primitives in tvm. You can also find # There are plenty of useful schedule primitives in tvm. You can also find
# some tutorials that describe them in more details, such as # some tutorials that describe them in more details, such as
# (1). :ref:`opt-conv-gpu` # (1). :ref:`opt-conv-gpu`
...@@ -72,6 +71,21 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding): ...@@ -72,6 +71,21 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, 'float32') conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, 'float32')
s = tvm.create_schedule([conv.op]) s = tvm.create_schedule([conv.op])
##### space definition begin #####
n, f, y, x = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis
cfg = autotvm.get_config()
cfg.define_split("tile_f", f, num_outputs=4)
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=3)
cfg.define_split("tile_ry", ry, num_outputs=3)
cfg.define_split("tile_rx", rx, num_outputs=3)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
cfg.define_knob("unroll_explicit", [0, 1])
##### space definition end #####
# inline padding # inline padding
pad_data = s[conv].op.input_tensors[0] pad_data = s[conv].op.input_tensors[0]
s[pad_data].compute_inline() s[pad_data].compute_inline()
...@@ -88,10 +102,6 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding): ...@@ -88,10 +102,6 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
# tile and bind spatial axes # tile and bind spatial axes
n, f, y, x = s[output].op.axis n, f, y, x = s[output].op.axis
cfg = autotvm.get_config()
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
...@@ -109,12 +119,9 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding): ...@@ -109,12 +119,9 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
s[OL].compute_at(s[output], tx) s[OL].compute_at(s[output], tx)
# tile and bind reduction axes # tile reduction axes
n, f, y, x = s[OL].op.axis n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis rc, ry, rx = s[OL].op.reduce_axis
cfg.define_split("tile_rc", cfg.axis(rc), num_outputs=3)
cfg.define_split("tile_ry", cfg.axis(ry), num_outputs=3)
cfg.define_split("tile_rx", cfg.axis(rx), num_outputs=3)
rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc) rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry) ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx) rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
...@@ -137,8 +144,6 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding): ...@@ -137,8 +144,6 @@ def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
s[load].bind(tx, tvm.thread_axis("threadIdx.x")) s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# tune unroll # tune unroll
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
cfg.define_knob("unroll_explicit", [0, 1])
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
......
...@@ -8,9 +8,9 @@ performance. This is a tutorial about how to tune a whole convolutional ...@@ -8,9 +8,9 @@ performance. This is a tutorial about how to tune a whole convolutional
network. network.
The operator implementation for ARM CPU in TVM is written in template form. The operator implementation for ARM CPU in TVM is written in template form.
It has many tunable knobs (tile factor, vectorization, unrolling, etc). The template has many tunable knobs (tile factor, vectorization, unrolling, etc).
We will do tuning for all convolution and depthwise convolution operators We will tune all convolution and depthwise convolution operators
in the neural network. After the tuning, we can get a log file which stores in the neural network. After tuning, we produce a log file which stores
the best knob values for all required operators. When the tvm compiler compiles the best knob values for all required operators. When the tvm compiler compiles
these operators, it will query this log file to get the best knob values. these operators, it will query this log file to get the best knob values.
...@@ -21,15 +21,15 @@ to see the results. ...@@ -21,15 +21,15 @@ to see the results.
###################################################################### ######################################################################
# Install dependencies # Install dependencies
# ---------------------------------------- # --------------------
# To use autotvm package in tvm, we need to install some extra dependencies. # To use the autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2): # (change "3" to "2" if you use python2):
# #
# .. code-block:: bash # .. code-block:: bash
# #
# pip3 install --user psutil xgboost tornado # pip3 install --user psutil xgboost tornado
# #
# To make tvm run faster in tuning, it is recommended to use cython # To make tvm run faster during tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute # as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2): # (change "3" to "2" if you use python2):
# #
...@@ -108,10 +108,9 @@ def get_network(name, batch_size): ...@@ -108,10 +108,9 @@ def get_network(name, batch_size):
# To scale up the tuning, TVM uses RPC Tracker to manage distributed devices. # To scale up the tuning, TVM uses RPC Tracker to manage distributed devices.
# The RPC Tracker is a centralized master node. We can register all devices to # The RPC Tracker is a centralized master node. We can register all devices to
# the tracker. For example, if we have 10 phones, we can register all of them # the tracker. For example, if we have 10 phones, we can register all of them
# to the tracker, then we can run 10 measurements in parallel, which accelerates # to the tracker, and run 10 measurements in parallel, accelerating the tuning process.
# the tuning process.
# #
# To start an RPC tracker, run this command in the host machine. The tracker is # To start an RPC tracker, run this command on the host machine. The tracker is
# required during the whole tuning process, so we need to open a new terminal for # required during the whole tuning process, so we need to open a new terminal for
# this command: # this command:
# #
...@@ -144,6 +143,8 @@ def get_network(name, batch_size): ...@@ -144,6 +143,8 @@ def get_network(name, batch_size):
# * For Android: # * For Android:
# Follow this `readme page <https://github.com/dmlc/tvm/tree/master/apps/android_rpc>`_ to # Follow this `readme page <https://github.com/dmlc/tvm/tree/master/apps/android_rpc>`_ to
# install tvm rpc apk on the android device. Make sure you can pass the android rpc test. # install tvm rpc apk on the android device. Make sure you can pass the android rpc test.
# Then you have already registred your device. During tuning, you have to go to developer option
# and enable "Keep screen awake during changing" and charge your phone to make it stable.
# #
# After registering devices, we can confirm it by querying rpc_tracker # After registering devices, we can confirm it by querying rpc_tracker
# #
...@@ -170,7 +171,7 @@ def get_network(name, batch_size): ...@@ -170,7 +171,7 @@ def get_network(name, batch_size):
########################################### ###########################################
# Set Tuning Options # Set Tuning Options
# ------------------ # ------------------
# Before tuning, we should do some configurations. Here I use an RK3399 board # Before tuning, we should apply some configurations. Here I use an RK3399 board
# as example. In your setting, you should modify the target and device_key accordingly. # as example. In your setting, you should modify the target and device_key accordingly.
# set :code:`use_android` to True if you use android phone. # set :code:`use_android` to True if you use android phone.
...@@ -213,18 +214,20 @@ tuning_option = { ...@@ -213,18 +214,20 @@ tuning_option = {
# #
# .. note:: How to set tuning options # .. note:: How to set tuning options
# #
# In general, the default value provided here works well. # In general, the default values provided here work well.
# If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger, # If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
# which makes the tuning run longer. # which makes the tuning run longer.
# If your device runs very slow or your conv2d operators have many GFLOPs, considering to
# set timeout larger.
# #
################################################################### ###################################################################
# Begin Tuning # Begin Tuning
# ------------ # ------------
# Now we can extract tuning tasks from the network and begin tuning. # Now we can extract tuning tasks from the network and begin tuning.
# Here we provide a simple utility function to tune a list of tasks. # Here, we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tunes them in sequential order. # This function is just an initial implementation which tunes them in sequential order.
# Later we will bring more sophisticated tuner scheduler. # We will introduce a more sophisticated tuning scheduler in the future.
# You can skip the implementation of this function for this tutorial. # You can skip the implementation of this function for this tutorial.
def tune_tasks(tasks, def tune_tasks(tasks,
...@@ -284,7 +287,7 @@ def tune_tasks(tasks, ...@@ -284,7 +287,7 @@ def tune_tasks(tasks,
######################################################################## ########################################################################
# Finally we launch tuning jobs and evaluate the end-to-end performance. # Finally, we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate(tuning_opt): def tune_and_evaluate(tuning_opt):
# extract workloads from nnvm graph # extract workloads from nnvm graph
...@@ -301,7 +304,7 @@ def tune_and_evaluate(tuning_opt): ...@@ -301,7 +304,7 @@ def tune_and_evaluate(tuning_opt):
# compile kernels with history best records # compile kernels with history best records
with autotvm.apply_history_best(log_file): with autotvm.apply_history_best(log_file):
print("Compile...") print("Compile...")
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']): with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build( graph, lib, params = nnvm.compiler.build(
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype) net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
...@@ -338,7 +341,7 @@ def tune_and_evaluate(tuning_opt): ...@@ -338,7 +341,7 @@ def tune_and_evaluate(tuning_opt):
(np.mean(prof_res), np.std(prof_res))) (np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long. # We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself. # Uncomment the following line to run it by yourself.
# tune_and_evaluate(tuning_option) # tune_and_evaluate(tuning_option)
...@@ -373,9 +376,9 @@ def tune_and_evaluate(tuning_opt): ...@@ -373,9 +376,9 @@ def tune_and_evaluate(tuning_opt):
###################################################################### ######################################################################
# #
# .. note:: **Meet some problems?** # .. note:: **Experiencing Difficulties?**
# #
# The auto tuning module is error prone. If you always see " 0.00/ 0.00 GFLOPS", # The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong. # then there must be something wrong.
# #
# First, make sure you set the correct configuration of your device. # First, make sure you set the correct configuration of your device.
......
...@@ -14,7 +14,7 @@ The whole workflow is illustrated by a matrix multiplication example. ...@@ -14,7 +14,7 @@ The whole workflow is illustrated by a matrix multiplication example.
###################################################################### ######################################################################
# Install dependencies # Install dependencies
# ---------------------------------------- # --------------------
# To use autotvm package in tvm, we need to install some extra dependencies. # To use autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2): # (change "3" to "2" if you use python2):
# #
...@@ -44,7 +44,7 @@ from tvm import autotvm ...@@ -44,7 +44,7 @@ from tvm import autotvm
###################################################################### ######################################################################
# Step 1: Define the search space # Step 1: Define the search space
# --------------------------------- # --------------------------------
# In this section, we will rewrite a deterministic tvm schedule code to a # In this section, we will rewrite a deterministic tvm schedule code to a
# tunable schedule template. You can regard the process of search space definition # tunable schedule template. You can regard the process of search space definition
# as the parametrization of our exiting schedule code. # as the parametrization of our exiting schedule code.
...@@ -73,7 +73,7 @@ def matmul_v0(N, L, M, dtype): ...@@ -73,7 +73,7 @@ def matmul_v0(N, L, M, dtype):
##################################################################### #####################################################################
# Parametrize the schedule # Parametrize the schedule
# ^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^
# In the previous schedule code, we use a constant "8" as tiling factor. # In the previous schedule code, we use a constant "8" as tiling factor.
# However, it might not be the best one because the best tiling factor depends # However, it might not be the best one because the best tiling factor depends
# on real hardware environment and input shape. # on real hardware environment and input shape.
......
...@@ -165,7 +165,7 @@ else: ...@@ -165,7 +165,7 @@ else:
# optimization for mali # optimization for mali
target = tvm.target.mali() target = tvm.target.mali()
with nnvm.compiler.build_config(opt_level=2): with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build(net, target=target, graph, lib, params = nnvm.compiler.build(net, target=target,
shape={"data": data_shape}, params=params, target_host=target_host) shape={"data": data_shape}, params=params, target_host=target_host)
......
...@@ -156,7 +156,7 @@ else: ...@@ -156,7 +156,7 @@ else:
# The above line is a simple form of # The above line is a simple form of
# target = tvm.target.create('llvm -devcie=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon') # target = tvm.target.create('llvm -devcie=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon')
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']): with nnvm.compiler.build_config(opt_level=3):
graph, lib, params = nnvm.compiler.build( graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params) net, target, shape={"data": data_shape}, params=params)
......
...@@ -103,13 +103,22 @@ with tvm.target.create("cuda"): ...@@ -103,13 +103,22 @@ with tvm.target.create("cuda"):
###################################################################### ######################################################################
# Fusing convolutions # Fusing convolutions
# ------------------- # -------------------
# We can fuse :code:`topi.nn.conv2d` and :code:`topi.nn.relu` together # We can fuse :code:`topi.nn.conv2d` and :code:`topi.nn.relu` together.
# #
# .. note::
#
# TOPI functions are all generic functions. They have different implementations
# for different backends to optimize for performance.
# For each backend, it is necessary to call them under a target scope for both
# compute declaration and schedule. TVM will choose the right function to call with
# the target information.
data = tvm.placeholder((1, 3, 224, 224)) data = tvm.placeholder((1, 3, 224, 224))
kernel = tvm.placeholder((10, 3, 5, 5)) kernel = tvm.placeholder((10, 3, 5, 5))
conv = topi.nn.conv2d(data, kernel, strides=1, padding=2)
out = topi.nn.relu(conv)
with tvm.target.create("cuda"): with tvm.target.create("cuda"):
conv = topi.nn.conv2d(data, kernel, strides=1, padding=2)
out = topi.nn.relu(conv)
sconv = topi.generic.nn.schedule_conv2d_nchw(out) sconv = topi.generic.nn.schedule_conv2d_nchw(out)
print(tvm.lower(sconv, [data, kernel], simple_mode=True)) print(tvm.lower(sconv, [data, kernel], simple_mode=True))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment