# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Use Tensor Expression Debug Display (TEDD) for Visualization
============================================================
**Author**: `Yongfeng Gu <https://github.com/yongfeng-nv>`_

This is an introduction to using TEDD to visualize tensor expressions.

Tensor Expressions are scheduled with primitives.  Although individual
primitives are usually easy to understand, they become complicated quickly
when you put them together. We have introduced an operational model of
schedule primitives in Tensor Expression to help explain:

* the interactions between different schedule primitives,
* the impact of the schedule primitives on the final code generation.

The operational model is based on a Dataflow Graph, a Schedule Tree and an
IterVar Relationship Graph. Schedule primitives perform operations on these
graphs.

TEDD renders these three graphs from a given schedule.  This tutorial demonstrates
how to use TEDD and how to interpret the rendered graphs.

"""
import tvm
from tvm import te
import topi
# tedd provides the viz_* renderers used throughout this tutorial.
from tvm.contrib import tedd

######################################################################
# Define and Schedule Convolution with Bias and ReLU
# --------------------------------------------------
# Let's build an example Tensor Expression for a convolution followed by Bias and ReLU.
# We first connect conv2d, add, and relu TOPIs.  Then, we create a TOPI generic schedule.
#

# Problem size and convolution hyper-parameters.
batch = 1
in_channel = 256
in_size = 32
num_filter = 256
kernel = 3
stride = 1
padding = "SAME"
dilation = 1

# The input is laid out HWCN (height, width, channel, batch) to match
# topi.nn.conv2d_hwcn; the weights are (kernel, kernel, in_channel, num_filter).
# The bias shape (1, num_filter, 1) is meant to broadcast over the filter axis
# of the conv output -- assumes topi.add's NumPy-style broadcasting.
A = te.placeholder((in_size, in_size, in_channel, batch), name='A')
W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = te.placeholder((1, num_filter, 1), name='bias')

# Chain conv2d -> bias add -> relu, then create the TOPI generic HWCN conv2d
# schedule inside an LLVM target context.
# NOTE(review): the walkthrough below describes a "W.shared" stage bound to
# threadIdx.y, which looks like a GPU schedule; confirm the llvm target still
# produces the stages shown in the rendered images.
with tvm.target.create("llvm"):
    t_conv = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
    t_bias = topi.add(t_conv, B)
    t_relu = topi.nn.relu(t_bias)
    s = topi.generic.schedule_conv2d_hwcn([t_relu])

######################################################################
# Render Graphs with TEDD
# -----------------------
# We render graphs to see the computation
# and how it is scheduled.
# If you run the tutorial in a Jupyter notebook, you can use the following commented lines
# to render SVG figures directly in the notebook.
#

# Write the dataflow graph of schedule `s` to a DOT file; in a notebook,
# uncomment the next line to display it inline as SVG instead.
tedd.viz_dataflow_graph(s, dot_file_path='/tmp/dfg.dot')
# tedd.viz_dataflow_graph(s, show_svg=True)

######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_dfg.png
#      :align: center
#
# The first one is a dataflow graph.  Every node represents a stage with name and memory
# scope shown in the middle and inputs/outputs information on the sides.
# Edges show dependencies between nodes.
#

# Render the schedule tree of the (not yet normalized) schedule to a DOT file;
# uncomment the next line to display it inline as SVG in a notebook.
tedd.viz_schedule_tree(s, dot_file_path='/tmp/scheduletree.dot')
# tedd.viz_schedule_tree(s, show_svg=True)

######################################################################
# We just rendered the schedule tree graph.  You may notice a warning about ranges not
# being available.
# The message also suggests calling normalize() to infer range information.  We will
# skip inspecting the first schedule tree and encourage you to compare the graphs before
# and after normalize() for its impact.
#

# Normalize the schedule so that range information can be inferred, then
# re-render the schedule tree; later renderings also use the normalized `s`.
s = s.normalize()
tedd.viz_schedule_tree(s, dot_file_path='/tmp/scheduletree2.dot')
# tedd.viz_schedule_tree(s, show_svg=True)

######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_st.png
#      :align: center
#
# Now, let us take a closer look at the second schedule tree.  Every block under ROOT
# represents a stage.  The stage name is shown in the top row and its compute in the
# bottom row.
# The middle rows list the IterVars: the higher the row, the outer the loop; the lower
# the row, the inner the loop.
# An IterVar row contains its index, name, type, and other optional information.
# Let's use the W.shared stage as an example.  The top row tells
# its name, "W.shared", and memory scope, "Shared".  Its compute is
# :code:`W(ax0, ax1, ax2, ax3)`.
# Its outermost loop IterVar is ax0.ax1.fused.ax2.fused.ax3.fused.outer,
# indexed with 0, of kDataPar, bound to threadIdx.y, and with range(min=0, ext=8).
# You can also tell an IterVar's type from the color of its index box, as shown
# in the legend.
#
# If a stage doesn't compute_at any other stage, it has an edge directly to the
# ROOT node.  Otherwise, it has an edge pointing to the IterVar it attaches to,
# as W.shared does to rx.outer in the middle compute stage.
#

######################################################################
# .. note::
#
#   By definition, IterVars are internal nodes and computes are leaf nodes in
#   a schedule tree.  The edges among IterVars and compute within one stage are
#   omitted, making every stage a block, for better readability.
#

# Render the IterVar relationship graph of the normalized schedule to a DOT
# file; uncomment the next line to display it inline as SVG in a notebook.
tedd.viz_itervar_relationship_graph(s, dot_file_path='/tmp/itervar.dot')
# tedd.viz_itervar_relationship_graph(s, show_svg=True)

######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_itervar_rel.png
#      :align: center
#
# The last one is an IterVar Relationship Graph.  Every subgraph represents a
# stage and contains IterVar nodes and transformation nodes.  For example,
# W.shared has three split nodes and three fuse nodes.  The rest are IterVar
# nodes of the same format as the IterVar rows in Schedule Trees.  Root
# IterVars are those not driven by any transformation node, such as ax0; leaf
# IterVars don't drive any transformation node and have non-negative indices,
# such as ax0.ax1.fused.ax2.fused.ax3.fused.outer with an index of 0.
#


######################################################################
# Summary
# -------
# This tutorial demonstrates the usage of TEDD.  We use an example built
# with TOPI to show the schedules under the hood.  You can also use
# it before and after any schedule primitive to inspect its effect.
#