Commit 8f18cc44 by Neo Chien Committed by Tianqi Chen

[AUTOTVM][DOCS] Add a link to the defining network description of auto-tuning tutorial (#4023)

* [AUTOTVM][DOCS] Add a link to autoTVM tutorial to direct the details of building NN with relay

* [AUTOTVM][DOCS] Add a link to autoTVM tutorial to direct the details of building NN with relay
parent f98035b0
...@@ -37,11 +37,13 @@ import tvm.contrib.graph_runtime as runtime ...@@ -37,11 +37,13 @@ import tvm.contrib.graph_runtime as runtime
# Define network # Define network
# -------------- # --------------
# First we need to define the network in relay frontend API. # First we need to define the network in relay frontend API.
# We can load some pre-defined network from :code:`relay.testing`. # We can either load some pre-defined network from :code:`relay.testing`
# or building :any:`relay.testing.resnet` with relay.
# We can also load models from MXNet, ONNX and TensorFlow. # We can also load models from MXNet, ONNX and TensorFlow.
# #
# In this tutorial, we choose resnet-18 as tuning example. # In this tutorial, we choose resnet-18 as tuning example.
def get_network(name, batch_size): def get_network(name, batch_size):
"""Get the symbol definition and random weight of a network""" """Get the symbol definition and random weight of a network"""
input_shape = (batch_size, 3, 224, 224) input_shape = (batch_size, 3, 224, 224)
...@@ -73,6 +75,7 @@ def get_network(name, batch_size): ...@@ -73,6 +75,7 @@ def get_network(name, batch_size):
return mod, params, input_shape, output_shape return mod, params, input_shape, output_shape
# Replace "llvm" with the correct target of your CPU. # Replace "llvm" with the correct target of your CPU.
# For example, for AWS EC2 c5 instance with Intel Xeon # For example, for AWS EC2 c5 instance with Intel Xeon
# Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512". # Platinum 8000 series, the target should be "llvm -mcpu=skylake-avx512".
...@@ -121,6 +124,7 @@ tuning_option = { ...@@ -121,6 +124,7 @@ tuning_option = {
), ),
} }
# You can skip the implementation of this function for this tutorial. # You can skip the implementation of this function for this tutorial.
def tune_kernels(tasks, def tune_kernels(tasks,
measure_option, measure_option,
...@@ -165,6 +169,7 @@ def tune_kernels(tasks, ...@@ -165,6 +169,7 @@ def tune_kernels(tasks,
autotvm.callback.progress_bar(n_trial, prefix=prefix), autotvm.callback.progress_bar(n_trial, prefix=prefix),
autotvm.callback.log_to_file(log_filename)]) autotvm.callback.log_to_file(log_filename)])
# Use graph tuner to achieve graph level optimal schedules # Use graph tuner to achieve graph level optimal schedules
# Set use_DP=False if it takes too long to finish. # Set use_DP=False if it takes too long to finish.
def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True): def tune_graph(graph, dshape, records, opt_sch_file, use_DP=True):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment