Commit 129eb645 by eqy Committed by Tianqi Chen

[AUTOTVM][RELAY][DOCS] relay ports of `tune_nnvm_*` autotvm tutorials (#2594)

parent ebcad896
"""
Auto-tuning a convolutional network for ARM CPU
====================================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Zhao Wu <https://github.com/FrozenGene>`_, `Eddie Yan <https://github.com/eqy>`_
Auto-tuning for a specific ARM device is critical for getting the best
performance. This is a tutorial about how to tune a whole convolutional
network.
The operator implementation for ARM CPU in TVM is written in template form.
The template has many tunable knobs (tile factor, vectorization, unrolling, etc).
We will tune all convolution and depthwise convolution operators
in the neural network. After tuning, we produce a log file which stores
the best knob values for all required operators. When the tvm compiler compiles
these operators, it will query this log file to get the best knob values.
We also released pre-tuned parameters for some arm devices. You can go to
`ARM CPU Benchmark <https://github.com/dmlc/tvm/wiki/Benchmark#arm-cpu>`_
to see the results.
"""
######################################################################
# Install dependencies
# --------------------
# To use the autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user psutil xgboost tornado
#
# To make tvm run faster during tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user cython
# sudo make cython3
#
# Now return to python code. Import packages.
import os
import numpy as np
import tvm
from tvm import autotvm
from tvm import relay
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
#################################################################
# Define network
# --------------
# First we need to define the network in relay frontend API.
# We can load some pre-defined network from :code:`relay.testing`.
# We can also load models from MXNet, ONNX and TensorFlow.
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size)
elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet':
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape
#################################################################
# Start RPC Tracker
# -----------------
# TVM uses RPC session to communicate with ARM boards.
# During tuning, the tuner will send the generated code to the board and
# measure the speed of code on the board.
#
# To scale up the tuning, TVM uses RPC Tracker to manage distributed devices.
# The RPC Tracker is a centralized master node. We can register all devices to
# the tracker. For example, if we have 10 phones, we can register all of them
# to the tracker, and run 10 measurements in parallel, accelerating the tuning process.
#
# To start an RPC tracker, run this command on the host machine. The tracker is
# required during the whole tuning process, so we need to open a new terminal for
# this command:
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
#
# The expected output is
#
# .. code-block:: bash
#
# INFO:RPCTracker:bind to 0.0.0.0:9190
#################################################################
# Register devices to RPC Tracker
# -----------------------------------
# Now we can register our devices to the tracker. The first step is to
# build tvm runtime for the ARM devices.
#
# * For Linux:
# Follow this section :ref:`build-tvm-runtime-on-device` to build
# tvm runtime on the device. Then register the device to tracker by
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_server --tracker=[HOST_IP]:9190 --key=rk3399
#
# (replace :code:`[HOST_IP]` with the IP address of your host machine)
#
# * For Android:
# Follow this `readme page <https://github.com/dmlc/tvm/tree/master/apps/android_rpc>`_ to
# install tvm rpc apk on the android device. Make sure you can pass the android rpc test.
# Then you have already registred your device. During tuning, you have to go to developer option
# and enable "Keep screen awake during changing" and charge your phone to make it stable.
#
# After registering devices, we can confirm it by querying rpc_tracker
#
# .. code-block:: bash
#
# python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190
#
# For example, if we have 2 Huawei mate10 pro, 11 Raspberry Pi 3B and 2 rk3399,
# the output can be
#
# .. code-block:: bash
#
# Queue Status
# ----------------------------------
# key total free pending
# ----------------------------------
# mate10pro 2 2 0
# rk3399 2 2 0
# rpi3b 11 11 0
# ----------------------------------
#
# You can register multiple devices to the tracker to accelerate the measurement in tuning.
###########################################
# Set Tuning Options
# ------------------
# Before tuning, we should apply some configurations. Here I use an RK3399 board
# as example. In your setting, you should modify the target and device_key accordingly.
# set :code:`use_android` to True if you use android phone.
#### DEVICE CONFIG ####
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
# Also replace this with the device key in your tracker
device_key = 'rk3399'
# Set this to True if you use android phone
use_android = False
#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'
tuning_option = {
'log_filename': log_file,
'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 800,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(
build_func='ndk' if use_android else 'default'),
runner=autotvm.RPCRunner(
device_key, host='fleet', port=9190,
number=5,
timeout=10,
),
),
}
####################################################################
#
# .. note:: How to set tuning options
#
# In general, the default values provided here work well.
# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
# which makes the tuning run longer.
# If your device runs very slow or your conv2d operators have many GFLOPs, considering to
# set timeout larger.
#
# If your model has depthwise convolution, you could consider setting
# :code:`try_spatial_pack_depthwise` be :code:`True`, which perform better than default
# optimization in general. For example, on ARM CPU A53 2.0GHz, we find it could boost 1.6x
# performance of depthwise convolution on Mobilenet V1 model.
###################################################################
# Begin Tuning
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here, we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tunes them in sequential order.
# We will introduce a more sophisticated tuning scheduler in the future.
# You can skip the implementation of this function for this tutorial.
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True,
try_winograd=True,
try_spatial_pack_depthwise=False):
if try_winograd:
for i in range(len(tasks)):
try: # try winograd template
tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
tasks[i].target, tasks[i].target_host, 'winograd')
input_channel = tsk.workload[1][1]
if input_channel >= 64:
tasks[i] = tsk
except Exception:
pass
# if we want to use spatial pack for depthwise convolution
if try_spatial_pack_depthwise:
tuner = 'xgb_knob'
for i in range(len(tasks)):
if tasks[i].name == 'topi_nn_depthwise_conv2d_nchw':
tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
tasks[i].target, tasks[i].target_host,
'contrib_spatial_pack')
tasks[i] = tsk
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'xgb_knob':
tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate(tuning_opt):
# extract workloads from relay program
print("Extract tasks...")
net, params, input_shape, _ = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params,
ops=(relay.op.nn.conv2d,))
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build(
net, target=target, params=params)
# export library
tmp = tempdir()
if use_android:
from tvm.contrib import ndk
filename = "net.so"
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "net.tar"
lib.export_library(tmp.relpath(filename))
# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
# upload parameters to device
ctx = remote.context(str(target), 0)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.
# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended.
# One sample output is listed below.
# It takes about 2 hours on a 32T AMD Ryzen Threadripper.
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/12] Current/Best: 22.37/ 52.19 GFLOPS | Progress: (544/1000) | 406.59 s Done.
# [Task 2/12] Current/Best: 6.51/ 18.77 GFLOPS | Progress: (608/1000) | 325.05 s Done.
# [Task 3/12] Current/Best: 4.67/ 24.87 GFLOPS | Progress: (480/1000) | 372.31 s Done.
# [Task 4/12] Current/Best: 11.35/ 46.83 GFLOPS | Progress: (736/1000) | 602.39 s Done.
# [Task 5/12] Current/Best: 1.01/ 19.80 GFLOPS | Progress: (448/1000) | 262.16 s Done.
# [Task 6/12] Current/Best: 2.47/ 23.76 GFLOPS | Progress: (672/1000) | 563.85 s Done.
# [Task 7/12] Current/Best: 14.57/ 33.97 GFLOPS | Progress: (544/1000) | 465.15 s Done.
# [Task 8/12] Current/Best: 1.13/ 17.65 GFLOPS | Progress: (576/1000) | 365.08 s Done.
# [Task 9/12] Current/Best: 14.45/ 22.66 GFLOPS | Progress: (928/1000) | 724.25 s Done.
# [Task 10/12] Current/Best: 3.22/ 15.36 GFLOPS | Progress: (864/1000) | 564.27 s Done.
# [Task 11/12] Current/Best: 11.03/ 32.23 GFLOPS | Progress: (736/1000) | 635.15 s Done.
# [Task 12/12] Current/Best: 8.00/ 21.65 GFLOPS | Progress: (1000/1000) | 1111.81 s Done.
# Compile...
# Upload...
# Evaluate inference time cost...
# Mean inference time (std dev): 162.59 ms (0.06 ms)
######################################################################
#
# .. note:: **Experiencing Difficulties?**
#
# The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines in the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
"""
Auto-tuning a convolutional network for NVIDIA GPU
====================================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Eddie Yan <https://github.com/eqy/>`_
Auto-tuning for specific devices and workloads is critical for getting the
best performance. This is a tutorial on how to tune a whole convolutional
network for NVIDIA GPU.
The operator implementation for NVIDIA GPU in TVM is written in template form.
The template has many tunable knobs (tile factor, unrolling, etc).
We will tune all convolution and depthwise convolution operators
in the neural network. After tuning, we produce a log file which stores
the best knob values for all required operators. When the tvm compiler compiles
these operators, it will query this log file to get the best knob values.
We also released pre-tuned parameters for some NVIDIA GPUs. You can go to
`NVIDIA GPU Benchmark <https://github.com/dmlc/tvm/wiki/Benchmark#nvidia-gpu>`_
to see the results.
"""
######################################################################
# Install dependencies
# --------------------
# To use the autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user psutil xgboost tornado
#
# To make tvm run faster during tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute:
#
# .. code-block:: bash
#
# pip3 install --user cython
# sudo make cython3
#
# Now return to python code. Import packages.
import os
import numpy as np
import tvm
from tvm import autotvm
from tvm import relay
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
#################################################################
# Define Network
# --------------
# First we need to define the network in relay frontend API.
# We can load some pre-defined network from :code:`nnvm.testing`.
# We can also load models from MXNet, ONNX and TensorFlow.
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet':
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape
###########################################
# Set Tuning Options
# ------------------
# Before tuning, we apply some configurations.
#### DEVICE CONFIG ####
target = tvm.target.cuda()
#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.log" % network
dtype = 'float32'
tuning_option = {
'log_filename': log_file,
'tuner': 'xgb',
'n_trial': 2000,
'early_stopping': 600,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(timeout=10),
#runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150),
runner=autotvm.RPCRunner(
'1080ti', # change the device key to your key
'localhost', 9190,
number=20, repeat=3, timeout=4, min_repeat_ms=150)
),
}
####################################################################
#
# .. note:: How to set tuning options
#
# In general, the default value provided here works well.
#
# If you have large time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
# which makes the tuning runs longer.
#
# If you have multiple devices, you can use all of them for measurement to
# accelerate the tuning process. (see the 'Scale up measurement` section below).
#
###################################################################
# Begin Tuning
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here, we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tunes them in sequential order.
# We will introduce a more sophisticated tuning scheduler in the future.
# You can skip the implementation of this function for this tutorial.
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True,
try_winograd=True):
if try_winograd:
for i in range(len(tasks)):
try: # try winograd template
tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
tasks[i].target, tasks[i].target_host, 'winograd')
input_channel = tsk.workload[1][1]
if input_channel >= 64:
tasks[i] = tsk
except Exception:
pass
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=100)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate(tuning_opt):
# extract workloads from relay program
print("Extract tasks...")
net, params, input_shape, out_shape = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params, ops=(relay.op.nn.conv2d,))
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build(
net, target=target, params=params)
# export library
tmp = tempdir()
filename = "net.tar"
lib.export_library(tmp.relpath(filename))
# load parameters
ctx = tvm.context(str(target), 0)
module = runtime.create(graph, lib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.
# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended. One sample output is listed below.
# It takes about 4 hours to get the following output on a 32T AMD Ryzen Threadripper.
# The tuning target is NVIDIA 1080 Ti.
# (You can see some errors during compilation. If the tuning is not stuck, it is okay.)
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/12] Current/Best: 541.83/3570.66 GFLOPS | Progress: (960/2000) | 1001.31 s Done.
# [Task 2/12] Current/Best: 0.56/ 803.33 GFLOPS | Progress: (704/2000) | 608.08 s Done.
# [Task 3/12] Current/Best: 103.69/1141.25 GFLOPS | Progress: (768/2000) | 702.13 s Done.
# [Task 4/12] Current/Best: 2905.03/3925.15 GFLOPS | Progress: (864/2000) | 745.94 sterminate called without an active exception
# [Task 4/12] Current/Best: 2789.36/3925.15 GFLOPS | Progress: (1056/2000) | 929.40 s Done.
# [Task 5/12] Current/Best: 89.06/1076.24 GFLOPS | Progress: (704/2000) | 601.73 s Done.
# [Task 6/12] Current/Best: 40.39/2129.02 GFLOPS | Progress: (1088/2000) | 1125.76 s Done.
# [Task 7/12] Current/Best: 4090.53/5007.02 GFLOPS | Progress: (800/2000) | 903.90 s Done.
# [Task 8/12] Current/Best: 4.78/1272.28 GFLOPS | Progress: (768/2000) | 749.14 s Done.
# [Task 9/12] Current/Best: 1391.45/2325.08 GFLOPS | Progress: (992/2000) | 1084.87 s Done.
# [Task 10/12] Current/Best: 1995.44/2383.59 GFLOPS | Progress: (864/2000) | 862.60 s Done.
# [Task 11/12] Current/Best: 4093.94/4899.80 GFLOPS | Progress: (224/2000) | 240.92 sterminate called without an active exception
# [Task 11/12] Current/Best: 3487.98/4909.91 GFLOPS | Progress: (480/2000) | 534.96 sterminate called without an active exception
# [Task 11/12] Current/Best: 4636.84/4912.17 GFLOPS | Progress: (1184/2000) | 1381.16 sterminate called without an active exception
# [Task 11/12] Current/Best: 50.12/4912.17 GFLOPS | Progress: (1344/2000) | 1602.81 s Done.
# [Task 12/12] Current/Best: 3581.31/4286.30 GFLOPS | Progress: (736/2000) | 943.52 s Done.
# Compile...
# Evaluate inference time cost...
# Mean inference time (std dev): 1.07 ms (0.05 ms)
#
# As a reference baseline, the time cost of MXNet + TensorRT on resnet-18 is 1.30ms. So we are a little faster.
######################################################################
#
# .. note:: **Experiencing Difficulties?**
#
# The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines in the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
#################################################################
# Scale up measurement by using multiple devices
# ----------------------------------------------
#
# If you have multiple devices, you can use all of them for measurement.
# TVM uses the RPC Tracker to manage distributed devices.
# The RPC Tracker is a centralized master node. We can register all devices to
# the tracker. For example, if we have 10 GPU cards, we can register all of them
# to the tracker, and run 10 measurements in parallel, accelerating the tuning process.
#
# To start an RPC tracker, run this command on the host machine. The tracker is
# required during the whole tuning process, so we need to open a new terminal for
# this command:
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
#
# The expected output is
#
# .. code-block:: bash
#
# INFO:RPCTracker:bind to 0.0.0.0:9190
#
# Then open another new terminal for the RPC server. We need to start one server
# for each dedicated device. We use a string key to distinguish the types of devices.
# You can pick a name you like.
# (Note: For rocm backend, there are some internal errors with the compiler,
# we need to add `--no-fork` to the argument list.)
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_server --tracker=localhost:9190 --key=1080ti
#
# After registering devices, we can confirm it by querying rpc_tracker
#
# .. code-block:: bash
#
# python -m tvm.exec.query_rpc_tracker --host=localhost --port=9190
#
# For example, if we have four 1080ti, two titanx and one gfx900, the output can be
#
# .. code-block:: bash
#
# Queue Status
# ----------------------------------
# key total free pending
# ----------------------------------
# 1080ti 4 4 0
# titanx 2 2 0
# gfx900 1 1 0
# ----------------------------------
#
# Finally, we need to change the tuning option to use RPCRunner. Use the code below
# to replace the corresponding part above.
tuning_option = {
'log_filename': log_file,
'tuner': 'xgb',
'n_trial': 2000,
'early_stopping': 600,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(timeout=10),
runner=autotvm.RPCRunner(
'1080ti', # change the device key to your key
'localhost', 9190,
number=20, repeat=3, timeout=4, min_repeat_ms=150),
),
}
"""
Auto-tuning a convolutional network for Mobile GPU
====================================================
**Author**: `Lianmin Zheng <https://https://github.com/merrymercy>`_, `Eddie Yan <https://github.com/eqy>`_
Auto-tuning for a specific device is critical for getting the best
performance. This is a tutorial about how to tune a whole convolutional
network.
The operator implementation for Mobile GPU in TVM is written in template form.
The template has many tunable knobs (tile factor, vectorization, unrolling, etc).
We will tune all convolution, depthwise convolution and dense operators
in the neural network. After tuning, we produce a log file which stores
the best knob values for all required operators. When the tvm compiler compiles
these operators, it will query this log file to get the best knob values.
We also released pre-tuned parameters for some arm devices. You can go to
`Mobile GPU Benchmark <https://github.com/dmlc/tvm/wiki/Benchmark#mobile-gpu>`_
to see the results.
"""
######################################################################
# Install dependencies
# --------------------
# To use the autotvm package in tvm, we need to install some extra dependencies.
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user psutil xgboost tornado
#
# To make tvm run faster during tuning, it is recommended to use cython
# as FFI of tvm. In the root directory of tvm, execute
# (change "3" to "2" if you use python2):
#
# .. code-block:: bash
#
# pip3 install --user cython
# sudo make cython3
#
# Now return to python code. Import packages.
import os
import numpy as np
import tvm
from tvm import autotvm
from tvm import relay
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
#################################################################
# Define network
# --------------
# First we need to define the network in relay frontend API.
# We can load some pre-defined network from :code:`relay.testing`.
# We can also load models from MXNet, ONNX and TensorFlow.
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet':
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape
#################################################################
# Start RPC Tracker
# -----------------
# TVM uses RPC session to communicate with ARM boards.
# During tuning, the tuner will send the generated code to the board and
# measure the speed of code on the board.
#
# To scale up the tuning, TVM uses RPC Tracker to manage distributed devices.
# The RPC Tracker is a centralized master node. We can register all devices to
# the tracker. For example, if we have 10 phones, we can register all of them
# to the tracker, and run 10 measurements in parallel, accelerating the tuning process.
#
# To start an RPC tracker, run this command on the host machine. The tracker is
# required during the whole tuning process, so we need to open a new terminal for
# this command:
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_tracker --host=0.0.0.0 --port=9190
#
# The expected output is
#
# .. code-block:: bash
#
# INFO:RPCTracker:bind to 0.0.0.0:9190
#################################################################
# Register devices to RPC Tracker
# -----------------------------------
# Now we can register our devices to the tracker. The first step is to
# build tvm runtime for the ARM devices.
#
# * For Linux:
# Follow this section :ref:`build-tvm-runtime-on-device` to build
# tvm runtime on the device. Then register the device to tracker by
#
# .. code-block:: bash
#
# python -m tvm.exec.rpc_server --tracker=[HOST_IP]:9190 --key=rk3399
#
# (replace :code:`[HOST_IP]` with the IP address of your host machine)
#
# * For Android:
# Follow this `readme page <https://github.com/dmlc/tvm/tree/master/apps/android_rpc>`_ to
# install tvm rpc apk on the android device. Make sure you can pass the android rpc test.
# Then you have already registred your device. During tuning, you have to go to developer option
# and enable "Keep screen awake during changing" and charge your phone to make it stable.
#
# After registering devices, we can confirm it by querying rpc_tracker
#
# .. code-block:: bash
#
# python -m tvm.exec.query_rpc_tracker --host=0.0.0.0 --port=9190
#
# For example, if we have 2 Huawei mate10 pro, 11 Raspberry Pi 3B and 2 rk3399,
# the output can be
#
# .. code-block:: bash
#
# Queue Status
# ----------------------------------
# key total free pending
# ----------------------------------
# mate10pro 2 2 0
# rk3399 2 2 0
# rpi3b 11 11 0
# ----------------------------------
#
# You can register multiple devices to the tracker to accelerate the measurement in tuning.
###########################################
# Set Tuning Options
# ------------------
# Before tuning, we should apply some configurations. Here I use an RK3399 board
# as example. In your setting, you should modify the target and device_key accordingly.
# set :code:`use_android` to True if you use android phone.
#### DEVICE CONFIG ####
target = tvm.target.create('opencl -device=mali')
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target host is used for cross compilation. You can query it by :code:`gcc -v` on your device.
target_host = 'llvm -target=aarch64-linux-gnu'
# Also replace this with the device key in your tracker
device_key = 'rk3399'
# Set this to True if you use android phone
use_android = False
#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.%s.log" % (device_key, network)
dtype = 'float32'
tuning_option = {
'log_filename': log_file,
'tuner': 'xgb',
'n_trial': 1000,
'early_stopping': 450,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(
build_func='ndk' if use_android else 'default'),
runner=autotvm.RPCRunner(
device_key, host='localhost', port=9190,
number=10,
timeout=5,
),
),
}
####################################################################
#
# .. note:: How to set tuning options
#
# In general, the default values provided here work well.
# If you have enough time budget, you can set :code:`n_trial`, :code:`early_stopping` larger,
# which makes the tuning run longer.
# If your device runs very slow or your conv2d operators have many GFLOPs, considering to
# set timeout larger.
#
###################################################################
# Begin Tuning
# ------------
# Now we can extract tuning tasks from the network and begin tuning.
# Here, we provide a simple utility function to tune a list of tasks.
# This function is just an initial implementation which tunes them in sequential order.
# We will introduce a more sophisticated tuning scheduler in the future.
# You can skip the implementation of this function for this tutorial.
def tune_tasks(tasks,
measure_option,
tuner='xgb',
n_trial=1000,
early_stopping=None,
log_filename='tuning.log',
use_transfer_learning=True,
try_winograd=True):
if try_winograd:
for i in range(len(tasks)):
try: # try winograd template
tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
tasks[i].target, tasks[i].target_host, 'winograd')
tasks.append(tsk)
except Exception:
pass
# create tmp log file
tmp_log_file = log_filename + ".tmp"
if os.path.exists(tmp_log_file):
os.remove(tmp_log_file)
for i, tsk in enumerate(reversed(tasks)):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(tsk, loss_type='rank')
elif tuner == 'ga':
tuner_obj = GATuner(tsk, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(tsk)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(tsk)
else:
raise ValueError("Invalid tuner: " + tuner)
if use_transfer_learning:
if os.path.isfile(tmp_log_file):
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(tmp_log_file)])
# pick best records to a cache file
autotvm.record.pick_best(tmp_log_file, log_filename)
os.remove(tmp_log_file)
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate(tuning_opt):
# extract workloads from relay program
print("Extract tasks...")
net, params, input_shape, _ = get_network(network, batch_size=1)
tasks = autotvm.task.extract_from_program(net, target=target, target_host=target_host,
params=params, ops=(relay.op.nn.conv2d,))
# run tuning tasks
print("Tuning...")
tune_tasks(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build(
net, target=target, params=params, target_host=target_host)
# export library
tmp = tempdir()
if use_android:
from tvm.contrib import ndk
filename = "net.so"
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "net.tar"
lib.export_library(tmp.relpath(filename))
# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key, 'localhost', 9190,
timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
# upload parameters to device
ctx = remote.context(str(target), 0)
module = runtime.create(graph, rlib, ctx)
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=30)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.
# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended.
# One sample output is listed below. It takes about 3 hours on a 32T AMD Ryzen Threadripper.
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/17] Current/Best: 25.30/ 39.12 GFLOPS | Progress: (992/1000) | 751.22 s Done.
# [Task 2/17] Current/Best: 40.70/ 45.50 GFLOPS | Progress: (736/1000) | 545.46 s Done.
# [Task 3/17] Current/Best: 38.83/ 42.35 GFLOPS | Progress: (992/1000) | 1549.85 s Done.
# [Task 4/17] Current/Best: 23.31/ 31.02 GFLOPS | Progress: (640/1000) | 1059.31 s Done.
# [Task 5/17] Current/Best: 0.06/ 2.34 GFLOPS | Progress: (544/1000) | 305.45 s Done.
# [Task 6/17] Current/Best: 10.97/ 17.20 GFLOPS | Progress: (992/1000) | 1050.00 s Done.
# [Task 7/17] Current/Best: 8.98/ 10.94 GFLOPS | Progress: (928/1000) | 421.36 s Done.
# [Task 8/17] Current/Best: 4.48/ 14.86 GFLOPS | Progress: (704/1000) | 582.60 s Done.
# [Task 9/17] Current/Best: 10.30/ 25.99 GFLOPS | Progress: (864/1000) | 899.85 s Done.
# [Task 10/17] Current/Best: 11.73/ 12.52 GFLOPS | Progress: (608/1000) | 304.85 s Done.
# [Task 11/17] Current/Best: 15.26/ 18.68 GFLOPS | Progress: (800/1000) | 747.52 s Done.
# [Task 12/17] Current/Best: 17.48/ 26.71 GFLOPS | Progress: (1000/1000) | 1166.40 s Done.
# [Task 13/17] Current/Best: 0.96/ 11.43 GFLOPS | Progress: (960/1000) | 611.65 s Done.
# [Task 14/17] Current/Best: 17.88/ 20.22 GFLOPS | Progress: (672/1000) | 670.29 s Done.
# [Task 15/17] Current/Best: 11.62/ 13.98 GFLOPS | Progress: (736/1000) | 449.25 s Done.
# [Task 16/17] Current/Best: 19.90/ 23.83 GFLOPS | Progress: (608/1000) | 708.64 s Done.
# [Task 17/17] Current/Best: 17.98/ 22.75 GFLOPS | Progress: (736/1000) | 1122.60 s Done.
# Compile...
# Upload...
# Evaluate inference time cost...
# Mean inference time (std dev): 128.05 ms (7.74 ms)
#
######################################################################
#
# .. note:: **Experiencing Difficulties?**
#
# The auto tuning module is error-prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines in the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
"""
kuto-tuning a convolutional network for x86 CPU
====================================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_, `Eddie Yan <https://github.com/eqy>`_
This is a tutorial about how to tune convolution neural network
for x86 cpu.
"""
import os
import numpy as np
import tvm
from tvm import autotvm
from tvm import relay
from tvm.relay import testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
import tvm.contrib.graph_runtime as runtime
#################################################################
# Define network
# --------------
# First we need to define the network in relay frontend API.
# We can load some pre-defined network from :code:`relay.testing`.
# We can also load models from MXNet, ONNX and TensorFlow.
#
# In this tutorial, we choose resnet-18 as tuning example.
def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)
if "resnet" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif "vgg" in name:
n_layer = int(name.split('-')[1])
net, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype)
elif name == 'mobilenet':
net, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'squeezenet_v1.1':
net, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype)
elif name == 'inception_v3':
input_shape = (1, 3, 299, 299)
net, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
elif name == 'mxnet':
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
return net, params, input_shape, output_shape
# Replace "llvm" with the correct target of your cpu.
# For example, for AWS EC2 c5 instance with Intel Xeon
# Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512".
# For AWS EC2 c4 instance with Intel Xeon E5-2666 v3, it should be
# "llvm -mcpu=core-avx2".
target = "llvm"
batch_size = 1
dtype = "float32"
model_name = "resnet-18"
log_file = "%s.log" % model_name
# Set number of threads used for tuning based on the number of
# physical cpu cores on your machine.
num_threads = 1
os.environ["TVM_NUM_THREADS"] = str(num_threads)
#################################################################
# Configure tensor tuning settings and create tasks
# -------------------------------------------------
# To get better kernel execution performance on x86 cpu,
# we need to change data layout of convolution kernel from
# "NCHW" to "NCHWc". To deal with this situation, we define
# conv2d_NCHWc operator in topi. We will tune this operator
# instead of plain conv2d.
#
# We will use local mode for tuning configuration. RPC tracker
# mode can be setup similarly to the approach in
# :ref:`tune_relay_arm` tutorial.
tuning_option = {
'log_filename': log_file,
'tuner': 'random',
'early_stopping': None,
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(),
runner=autotvm.LocalRunner(number=10, repeat=1,
min_repeat_ms=1000),
),
}
# You can skip the implementation of this function for this tutorial.
def tune_kernels(tasks,
measure_option,
tuner='gridsearch',
early_stopping=None,
log_filename='tuning.log'):
for i, tsk in enumerate(tasks):
prefix = "[Task %2d/%2d] " % (i+1, len(tasks))
# converting conv2d tasks to conv2d_NCHWc tasks
op_name = tsk.workload[0]
if op_name == 'conv2d':
func_create = 'topi_x86_conv2d_NCHWc'
elif op_name == 'depthwise_conv2d_nchw':
func_create = 'topi_x86_depthwise_conv2d_NCHWc_from_nchw'
else:
raise ValueError("Tuning {} is not supported on x86".format(op_name))
task = autotvm.task.create(func_create, args=tsk.args,
target=target, template_key='direct')
task.workload = tsk.workload
# create tuner
if tuner == 'xgb' or tuner == 'xgb-rank':
tuner_obj = XGBTuner(task, loss_type='rank')
elif tuner == 'ga':
tuner_obj = GATuner(task, pop_size=50)
elif tuner == 'random':
tuner_obj = RandomTuner(task)
elif tuner == 'gridsearch':
tuner_obj = GridSearchTuner(task)
else:
raise ValueError("Invalid tuner: " + tuner)
# do tuning
n_trial=len(task.config_space)
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(log_filename)])
########################################################################
# Finally, we launch tuning jobs and evaluate the end-to-end performance.
def tune_and_evaluate(tuning_opt):
# extract workloads from relay program
print("Extract tasks...")
net, params, data_shape, out_shape = get_network(model_name, batch_size)
tasks = autotvm.task.extract_from_program(net, target=target,
params=params, ops=(relay.op.nn.conv2d,))
# run tuning tasks
print("Tuning...")
tune_kernels(tasks, **tuning_opt)
# compile kernels with history best records
with autotvm.apply_history_best(log_file):
print("Compile...")
with relay.build_config(opt_level=3):
graph, lib, params = relay.build_module.build(
net, target=target, params=params)
# upload parameters to device
ctx = tvm.cpu()
data_tvm = tvm.nd.array((np.random.uniform(size=data_shape)).astype(dtype))
module = runtime.create(graph, lib, ctx)
module.set_input('data', data_tvm)
module.set_input(**params)
# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=100, repeat=3)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
(np.mean(prof_res), np.std(prof_res)))
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.
# tune_and_evaluate(tuning_option)
######################################################################
# Sample Output
# -------------
# The tuning needs to compile many programs and extract feature from them.
# So a high performance CPU is recommended.
# One sample output is listed below.
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/12] Current/Best: 598.05/2497.63 GFLOPS | Progress: (252/252) | 1357.95 s Done.
# [Task 2/12] Current/Best: 522.63/2279.24 GFLOPS | Progress: (784/784) | 3989.60 s Done.
# [Task 3/12] Current/Best: 447.33/1927.69 GFLOPS | Progress: (784/784) | 3869.14 s Done.
# [Task 4/12] Current/Best: 481.11/1912.34 GFLOPS | Progress: (672/672) | 3274.25 s Done.
# [Task 5/12] Current/Best: 414.09/1598.45 GFLOPS | Progress: (672/672) | 2720.78 s Done.
# [Task 6/12] Current/Best: 508.96/2273.20 GFLOPS | Progress: (768/768) | 3718.75 s Done.
# [Task 7/12] Current/Best: 469.14/1955.79 GFLOPS | Progress: (576/576) | 2665.67 s Done.
# [Task 8/12] Current/Best: 230.91/1658.97 GFLOPS | Progress: (576/576) | 2435.01 s Done.
# [Task 9/12] Current/Best: 487.75/2295.19 GFLOPS | Progress: (648/648) | 3009.95 s Done.
# [Task 10/12] Current/Best: 182.33/1734.45 GFLOPS | Progress: (360/360) | 1755.06 s Done.
# [Task 11/12] Current/Best: 372.18/1745.15 GFLOPS | Progress: (360/360) | 1684.50 s Done.
# [Task 12/12] Current/Best: 215.34/2271.11 GFLOPS | Progress: (400/400) | 2128.74 s Done.
# Compile...
# Evaluate inference time cost...
# Mean inference time (std dev): 3.16 ms (0.03 ms)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment