tune_nnvm_arm.py 16.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
"""
18 19
Auto-tuning a convolutional network for ARM CPU (NNVM)
======================================================
20
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Zhao Wu <https://github.com/FrozenGene>`_
21 22 23 24 25 26

Auto-tuning for a specific ARM device is critical for getting the best
performance. This is a tutorial about how to tune a whole convolutional
network.

The operator implementation for ARM CPU in TVM is written in template form.
27 28 29
The template has many tunable knobs (tile factor, vectorization, unrolling, etc).
We will tune all convolution and depthwise convolution operators
in the neural network. After tuning, we produce a log file which stores
30 31 32 33 34 35 36 37 38 39
the best knob values for all required operators. When the tvm compiler compiles
these operators, it will query this log file to get the best knob values.

We also released pre-tuned parameters for some arm devices. You can go to
`ARM CPU Benchmark <https://github.com/dmlc/tvm/wiki/Benchmark#arm-cpu>`_
to see the results.
"""

######################################################################
# Install dependencies
40 41
# --------------------
# To use the autotvm package in tvm, we need to install some extra dependencies.
42 43 44 45
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
46
#   pip3 install --user psutil xgboost tornado
47
#
48
# To make tvm run faster during tuning, it is recommended to use cython
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
# as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
#   pip3 install --user cython
#   sudo make cython3
#
# Now return to python code. Import packages.

import os

import numpy as np

import nnvm.testing
import nnvm.compiler
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime

#################################################################
# Define network
# --------------
# First we need to define the network in nnvm symbol API.
# We can load some pre-defined network from :code:`nnvm.testing`.
# We can also load models from MXNet, ONNX and TensorFlow (see NNVM
# tutorials :ref:`tutorial-nnvm` for more details).

def get_network(name, batch_size):
    """Get the symbol definition and random weight of a network"""
81
    input_shape = (batch_size, 3, 224, 224)
82 83
    output_shape = (batch_size, 1000)

84 85 86 87 88 89 90
    if "resnet" in name:
        n_layer = int(name.split('-')[1])
        net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
    elif name == 'mobilenet':
91
        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
92
    elif name == 'squeezenet_v1.1':
93
        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
94 95 96 97
    elif name == 'inception_v3':
        input_shape = (1, 3, 299, 299)
        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
    elif name == 'custom':
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
        # an example for custom network
        from nnvm.testing import utils
        net = nnvm.sym.Variable('data')
        net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3,3), padding=(1,1))
        net = nnvm.sym.flatten(net)
        net = nnvm.sym.dense(net, units=1000)
        net, params = utils.create_workload(net, batch_size, (3, 224, 224))
    elif name == 'mxnet':
        # an example for mxnet model
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        net, params = nnvm.frontend.from_mxnet(block)
        net = nnvm.sym.softmax(net)
    else:
        raise ValueError("Unsupported network: " + name)

114
    return net, params, input_shape, output_shape
115

116

117 118 119 120 121 122 123 124 125 126
#################################################################
# Start RPC Tracker
# -----------------
# TVM uses RPC session to communicate with ARM boards.
# During tuning, the tuner will send the generated code to the board and
# measure the speed of code on the board.
#
# To scale up the tuning, TVM uses RPC Tracker to manage distributed devices.
# The RPC Tracker is a centralized master node. We can register all devices to
# the tracker. For example, if we have 10 phones, we can register all of them
127
# to the tracker, and run 10 measurements in parallel, accelerating the tuning process.
128
#
129
# To start an RPC tracker, run this command on the host machine. The tracker is
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
# required during the whole tuning process, so we need to open a new terminal for
# this command:
#
# .. code-block:: bash
#
#   python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
#
# The expected output is
#
# .. code-block:: bash
#
#   INFO:RPCTracker:bind to 0.0.0.0:9190

#################################################################
# Register devices to RPC Tracker
# -----------------------------------
# Now we can register our devices to the tracker. The first step is to
# build tvm runtime for the ARM devices.
#
# * For Linux:
#   Follow this section :ref:`build-tvm-runtime-on-device` to build
#   tvm runtime on the device. Then register the device to tracker by
#
#   .. code-block:: bash
#
#     python -m tvm.exec.rpc_server --tracker=[HOST_IP]:9190 --key=rk3399
#
#   (replace :code:`[HOST_IP]` with the IP address of your host machine)
#
# * For Android:
#   Follow this `readme page <https://github.com/dmlc/tvm/tree/master/apps/android_rpc>`_ to
#   install tvm rpc apk on the android device. Make sure you can pass the android rpc test.
162 163
#   Then you have already registred your device. During tuning, you have to go to developer option
#   and enable "Keep screen awake during changing" and charge your phone to make it stable.
164 165 166 167 168 169 170 171 172 173 174 175 176
#
# After registering devices, we can confirm it by querying rpc_tracker
#
# .. code-block:: bash
#
#   python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190
#
# For example, if we have 2 Huawei mate10 pro, 11 Raspberry Pi 3B and 2 rk3399,
# the output can be
#
# .. code-block:: bash
#
#    Queue Status
177 178 179 180 181 182 183
#    ----------------------------------
#    key          total  free  pending
#    ----------------------------------
#    mate10pro    2      2     0
#    rk3399       2      2     0
#    rpi3b        11     11    0
#    ----------------------------------
184 185
#
# You can register multiple devices to the tracker to accelerate the measurement in tuning.
186 187 188 189

###########################################
# Set Tuning Options
# ------------------
190
# Before tuning, we should apply some configurations. Here I use an RK3399 board
191 192 193 194
# as example. In your setting, you should modify the target and device_key accordingly.
# set :code:`use_android` to True if you use android phone.

#### DEVICE CONFIG ####
195 196 197 198 199 200 201 202

# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')

# Also replace this with the device key in your tracker
device_key = 'rk3399'

203 204 205 206
# Set this to True if you use android phone
use_android = False

#### TUNING OPTION ####
207 208 209 210 211
network = 'resnet-18'
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'

tuning_option = {
212 213 214
    'log_filename': log_file,

    'tuner': 'xgb',
215 216
    'n_trial': 2000,
    'early_stopping': 800,
217 218 219 220 221 222 223 224 225 226

    'measure_option': autotvm.measure_option(
        builder=autotvm.LocalBuilder(
            build_func='ndk' if use_android else 'default'),
        runner=autotvm.RPCRunner(
            device_key, host='localhost', port=9190,
            number=5,
            timeout=4,
        ),
    ),
227 228 229 230 231 232
}

####################################################################
#
# .. note:: How to set tuning options
#
233 234
#   In general, the default values provided here work well.
#   If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
235
#   which makes the tuning run longer.
236 237
#   If your device runs very slow or your conv2d operators have many GFLOPs, considering to
#   set timeout larger.
238
#
239 240 241 242
#   If your model has depthwise convolution, you could consider setting
#   :code:`try_spatial_pack_depthwise` be :code:`True`, which perform better than default
#   optimization in general. For example, on ARM CPU A53 2.0GHz, we find it could boost 1.6x
#   performance of depthwise convolution on Mobilenet V1 model.
243 244 245 246 247

###################################################################
# Begin Tuning
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
248
# Here, we provide a simple utility function to tune a list of tasks.
249
# This function is just an initial implementation which tunes them in sequential order.
250
# We will introduce a more sophisticated tuning scheduler in the future.
251 252 253 254 255

# You can skip the implementation of this function for this tutorial.
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
256 257
               n_trial=1000,
               early_stopping=None,
258 259
               log_filename='tuning.log',
               use_transfer_learning=True,
260 261
               try_winograd=True,
               try_spatial_pack_depthwise=False):
262 263 264 265 266
    if try_winograd:
        for i in range(len(tasks)):
            try:  # try winograd template
                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                          tasks[i].target, tasks[i].target_host, 'winograd')
267 268 269
                input_channel = tsk.workload[1][1]
                if input_channel >= 64:
                    tasks[i] = tsk
270 271 272
            except Exception:
                pass

273 274 275 276 277 278 279 280 281 282
    # if we want to use spatial pack for depthwise convolution
    if try_spatial_pack_depthwise:
        tuner = 'xgb_knob'
        for i in range(len(tasks)):
            if tasks[i].name == 'topi_nn_depthwise_conv2d_nchw':
                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                          tasks[i].target, tasks[i].target_host,
                                          'contrib_spatial_pack')
                tasks[i] = tsk

283 284 285 286 287
    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

288 289
    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
290 291 292 293

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
294 295
        elif tuner == 'xgb_knob':
            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
296 297 298 299 300 301 302 303 304 305 306 307 308 309
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
310 311
        n_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial=n_trial,
312 313 314 315 316 317 318 319 320 321 322 323
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)


########################################################################
324
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
325

326
def tune_and_evaluate(tuning_opt):
327
    # extract workloads from nnvm graph
328
    print("Extract tasks...")
329 330 331 332
    net, params, input_shape, out_shape = get_network(network, batch_size=1)
    tasks = autotvm.task.extract_from_graph(net, target=target,
                                            shape={'data': input_shape}, dtype=dtype,
                                            symbols=(nnvm.sym.conv2d,))
333 334

    # run tuning tasks
335
    print("Tuning...")
336
    tune_tasks(tasks, **tuning_opt)
337 338 339 340

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
341
        with nnvm.compiler.build_config(opt_level=3):
342
            graph, lib, params = nnvm.compiler.build(
343
                net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
344 345 346

        # export library
        tmp = tempdir()
347
        if use_android:
348 349 350 351 352 353 354 355 356
            from tvm.contrib import ndk
            filename = "net.so"
            lib.export_library(tmp.relpath(filename), ndk.create_shared)
        else:
            filename = "net.tar"
            lib.export_library(tmp.relpath(filename))

        # upload module to device
        print("Upload...")
357
        remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
358
                                                timeout=10000)
359 360 361 362 363 364
        remote.upload(tmp.relpath(filename))
        rlib = remote.load_module(filename)

        # upload parameters to device
        ctx = remote.context(str(target), 0)
        module = runtime.create(graph, rlib, ctx)
365
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
366
        module.set_input('data', data_tvm)
367
        module.set_input(**params)
368 369 370

        # evaluate
        print("Evaluate inference time cost...")
371
        ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
372
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
373
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
374
              (np.mean(prof_res), np.std(prof_res)))
375 376

# We do not run the tuning in our webpage server since it takes too long.
377
# Uncomment the following line to run it by yourself.
378 379

# tune_and_evaluate(tuning_option)
380 381 382 383

######################################################################
# Sample Output
# -------------
384
# The tuning needs to compile many programs and extract feature from them.
385
# So a high performance CPU is recommended.
386 387
# One sample output is listed below.
# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
388 389 390
#
# .. code-block:: bash
#
391 392
#    Extract tasks...
#    Tuning...
393 394 395 396 397 398 399 400 401 402 403 404
#    [Task  1/12]  Current/Best:   22.37/  52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
#    [Task  2/12]  Current/Best:    6.51/  18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
#    [Task  3/12]  Current/Best:    4.67/  24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
#    [Task  4/12]  Current/Best:   11.35/  46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
#    [Task  5/12]  Current/Best:    1.01/  19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
#    [Task  6/12]  Current/Best:    2.47/  23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
#    [Task  7/12]  Current/Best:   14.57/  33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
#    [Task  8/12]  Current/Best:    1.13/  17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
#    [Task  9/12]  Current/Best:   14.45/  22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
#    [Task 10/12]  Current/Best:    3.22/  15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
#    [Task 11/12]  Current/Best:   11.03/  32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
#    [Task 12/12]  Current/Best:    8.00/  21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
405 406 407
#    Compile...
#    Upload...
#    Evaluate inference time cost...
408
#    Mean inference time (std dev): 162.59 ms (0.06 ms)
409 410 411

######################################################################
#
412
# .. note:: **Experiencing Difficulties?**
413
#
414
#   The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
415 416 417 418 419 420 421 422 423 424 425 426 427
#   then there must be something wrong.
#
#   First, make sure you set the correct configuration of your device.
#   Then, you can print debug information by adding these lines in the beginning
#   of the script. It will print every measurement result, where you can find useful
#   error messages.
#
#   .. code-block:: python
#
#      import logging
#      logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
#   Finally, always feel free to ask our community for help on https://discuss.tvm.ai