Unverified Commit b0b1e7da by yongfeng-nv Committed by GitHub

Tensor Expression Debug Display (TEDD) (#4651)

* Initial TEDD for publishing.

* 1. Fix lint issues. 2. Print intrin.body instead of intrin.name in the Schedule Tree. 3. Add examples to the top-level APIs' comments. 4. Top-level APIs don't print the DOT string by default, unless outputdotstring is True.

* Fix more lint issues.

* Update top-level API argument names and use raw strings to avoid Python lint warnings in the tests.

* Disable TEDD verification, but keep TE construction.

* Stop importing tedd to avoid failure.

* Separate data extraction and visualization. 1. Add the API tedd.dump_json(schedule) to dump a JSON string of the schedule data for visualization. 2. Update tests. 3. Add a tutorial. 4. Add range information to IterVars.

* Update TEDD for the InferBound failure. 1. TEDD doesn't call InferBound for the DFG. 2. Update the tutorial about the InferBound failure.

* 1. Import IPython only if SVG is requested. This is required to fix a tutorial publishing failure. 2. Fix the test's IPython availability check.
parent b422f6a9
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
import re
import topi


def findany(pattern, text):
    """Assert that a regex pattern appears at least once in text."""
    matches = re.findall(pattern, text)
    assert len(matches) > 0, \
        'Pattern not found.\nPattern: ' + pattern + '\nString: ' + text


def checkdependency():
    # TEDD needs both graphviz and ipython; skip verification otherwise.
    import pkg_resources
    return not {'graphviz', 'ipython'} - {pkg.key for pkg in pkg_resources.working_set}


def test_dfg():
    A = tvm.placeholder((1024, 4096), dtype='float32', name='A')
    B = topi.nn.softmax(A)
    # confirm lower works
    s = tvm.create_schedule([B.op])

    def verify():
        from tvm.contrib import tedd
        dot_string = tedd.viz_dataflow_graph(s, False, '', True)
        # Check all edges are available
        findany(r"digraph \"Dataflow Graph\"", dot_string)
        findany(r"Stage_0:O_0 -> Tensor_0_0", dot_string)
        findany(r"Tensor_0_0 -> Stage_1:I_0", dot_string)
        findany(r"Stage_1:O_0 -> Tensor_1_0", dot_string)
        findany(r"Tensor_0_0 -> Stage_2:I_0", dot_string)
        findany(r"Tensor_1_0 -> Stage_2:I_1", dot_string)
        findany(r"Stage_2:O_0 -> Tensor_2_0", dot_string)
        findany(r"Tensor_2_0 -> Stage_3:I_0", dot_string)
        findany(r"Stage_3:O_0 -> Tensor_3_0", dot_string)
        findany(r"Tensor_2_0 -> Stage_4:I_0", dot_string)
        findany(r"Tensor_3_0 -> Stage_4:I_1", dot_string)
        findany(r"Stage_4:O_0 -> Tensor_4_0", dot_string)

    if checkdependency():
        verify()


def test_itervar_relationship_graph():
    n = tvm.var("n")
    m = tvm.var("m")
    A = tvm.placeholder((n, m), name='A')
    k = tvm.reduce_axis((0, m), "k")
    B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
    s = tvm.create_schedule(B.op)
    s[B].split(B.op.reduce_axis[0], factor=16)

    def verify():
        from tvm.contrib import tedd
        dot_string = tedd.viz_itervar_relationship_graph(s, False, '', True)
        findany(r"digraph \"IterVar Relationship Graph\"", dot_string)
        findany(r"subgraph cluster_legend", dot_string)
        # Check subgraphs for stages
        findany(r"subgraph cluster_Stage_0", dot_string)
        findany(r"subgraph cluster_Stage_1", dot_string)
        # Check itervars and their types
        findany(r"i\(kDataPar\)\<br/\>range\(min=0, ext=n\)", dot_string)
        findany(r"k\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", dot_string)
        # Check the split node
        findany(r"Split_Relation_1_0 +.+\>Split", dot_string)
        # Check all edges to/from the split node
        findany(r"IterVar_1_1:itervar -> Split_Relation_1_0:Input", dot_string)
        findany(r"Split_Relation_1_0:Outer -> IterVar_1_2:itervar", dot_string)
        findany(r"Split_Relation_1_0:Inner -> IterVar_1_3:itervar", dot_string)

    if checkdependency():
        verify()


def test_schedule_tree():
    block_x = tvm.thread_axis('blockIdx.x')
    thread_x = tvm.thread_axis('threadIdx.x')
    n = tvm.var("n")
    m = tvm.var("m")
    l = tvm.var("l")
    A = tvm.placeholder((n, m, l), name='A')
    B = tvm.compute((n, m, l), lambda bi, bj, bk: A[bi, bj, bk] + 1, name='B')
    r = tvm.reduce_axis((0, m), "r")
    C = tvm.compute((n, m,),
                    lambda ci, cj: tvm.sum(B[ci, cj, r], axis=r),
                    name="C")
    s = tvm.create_schedule(C.op)
    s.cache_read(A, 'shared', [B])
    s[B].vectorize(B.op.axis[-1])
    s[C].reorder(C.op.reduce_axis[0], C.op.axis[0])
    _, ki = s[C].split(C.op.reduce_axis[0], factor=16)
    Cr = s.rfactor(C, ki)
    s[Cr].compute_at(s[C], s[C].op.axis[-1])
    s[C].bind(s[C].op.axis[0], block_x)
    s[C].bind(s[C].op.axis[1], thread_x)

    def verify():
        from tvm.contrib import tedd
        dot_string = tedd.viz_schedule_tree(s, False, '', True)
        findany(r"digraph \"Schedule Tree\"", dot_string)
        findany(r"subgraph cluster_legend", dot_string)
        # Check the A_shared stage, including memory scope, itervars,
        # and compute
        findany(r"Stage_1.*A\.shared<br/>Scope: shared.+>0.+>"
                r"ax0\(kDataPar\).+>1.+ax1\(kDataPar\).+>2.+>ax2\(kDataPar\).+>"
                r"\[A\(ax0, ax1, ax2\)\]", dot_string)
        # Check itervars of types different from kDataPar
        findany(r"bk\(kVectorized\)", dot_string)
        findany(r"r.outer\(kCommReduce\)", dot_string)
        findany(r"label=ROOT", dot_string)
        # Check the compute_at edge
        findany(r"Stage_1", dot_string)

    if checkdependency():
        verify()


if __name__ == "__main__":
    test_dfg()
    test_itervar_relationship_graph()
    test_schedule_tree()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Use Tensor Expression Debug Display (TEDD) for Visualization
============================================================
**Author**: `Yongfeng Gu <https://github.com/yongfeng-nv>`_

This is an introduction to using TEDD to visualize tensor expressions.

Tensor Expressions are scheduled with primitives. Although individual
primitives are usually easy to understand, they become complicated quickly
when you put them together. We have introduced an operational model of
schedule primitives in Tensor Expression in `this document
<https://docs.google.com/document/d/1nmz00_n4Ju-SpYN0QFl3abTHTlR_P0dRyo5zsWC0Q1k/edit?usp=sharing>`_
to make it easier to understand:

* the interactions between different schedule primitives,
* the impact of the schedule primitives on the final code generation.

The operational model is based on a Dataflow Graph, a Schedule Tree, and an
IterVar Relationship Graph. Schedule primitives perform operations on these
graphs.

TEDD renders these three graphs from a given schedule. This tutorial demonstrates
how to use TEDD and how to interpret the rendered graphs.
"""
from __future__ import absolute_import, print_function
import tvm
import topi
from tvm.contrib import tedd
######################################################################
# Define and Schedule Convolution with Bias and ReLU
# --------------------------------------------------
# Let's build an example Tensor Expression for a convolution followed by bias
# addition and ReLU. We first connect the conv2d, add, and relu TOPI operators.
# Then, we create a TOPI generic schedule.
#
batch = 1
in_channel = 256
in_size = 32
num_filter = 256
kernel = 3
stride = 1
padding = "SAME"
dilation = 1
A = tvm.placeholder((in_size, in_size, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = tvm.placeholder((1, num_filter, 1), name='bias')
with tvm.target.create("cuda"):
    t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
    t_bias = topi.add(t_conv, B)
    t_relu = topi.nn.relu(t_bias)
    s = topi.generic.schedule_conv2d_hwcn([t_relu])
######################################################################
# Render Graphs with TEDD
# -----------------------
# We render graphs to see the computation and how it is scheduled.
# If you run the tutorial in a Jupyter notebook, you can use the commented
# lines below to render SVG figures directly in the notebook.
#
tedd.viz_dataflow_graph(s, dot_file_path='/tmp/dfg.dot')
# tedd.viz_dataflow_graph(s, show_svg=True)
######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_dfg.png
#      :align: center
#      :scale: 100%
#
# The first one is a dataflow graph. Every node represents a stage, with its
# name and memory scope shown in the middle and input/output information on
# the sides. Edges show the dependencies between nodes.
#
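######################################################################
# Each ``viz_*`` API can also return the raw DOT string instead of (or in
# addition to) rendering it, which is handy if you want to post-process the
# graph yourself. A minimal sketch, assuming the ``graphviz`` Python package
# is available (TEDD itself depends on it): the fourth positional argument
# requests the DOT string, mirroring the calls in the test file above.
#
import graphviz

# Fetch the DOT string without writing a file or showing an SVG.
dot_string = tedd.viz_dataflow_graph(s, False, '', True)
# Render it to an SVG ourselves; the output path is just an example.
graphviz.Source(dot_string).render('/tmp/dfg_manual', format='svg')
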
tedd.viz_schedule_tree(s, dot_file_path='/tmp/scheduletree.dot')
# tedd.viz_schedule_tree(s, show_svg=True)
######################################################################
# We just rendered the schedule tree graph. You may notice a warning that
# ranges are not available. The message also suggests calling normalize() to
# infer the range information. We will skip inspecting the first schedule
# tree and encourage you to compare the graphs before and after normalize()
# to see its impact.
#
s = s.normalize()
tedd.viz_schedule_tree(s, dot_file_path='/tmp/scheduletree2.dot')
# tedd.viz_schedule_tree(s, show_svg=True)
######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_st.png
#      :align: center
#      :scale: 100%
#
# Now, let us take a close look at the second schedule tree. Every block
# under ROOT represents a stage. The stage name shows in the top row and the
# compute shows in the bottom row. The middle rows are for IterVars: outer
# ones appear higher, inner ones lower. An IterVar row contains its index,
# name, type, and other optional information.
# Let's use the W.shared stage as an example. The top row tells
# its name, "W.shared", and memory scope, "Shared". Its compute is
# :code:`W(ax0, ax1, ax2, ax3)`.
# Its outermost loop IterVar is ax0.ax1.fused.ax2.fused.ax3.fused.outer,
# indexed with 0, of type kDataPar, bound to threadIdx.y, and with
# range(min=0, ext=8). You can also tell the IterVar type from the color of
# the index box, as shown in the legend.
#
# If a stage doesn't compute_at any other stage, it has an edge directly to
# the ROOT node. Otherwise, it has an edge pointing to the IterVar it
# attaches to, such as W.shared attaching to rx.outer in the middle compute
# stage.
#
######################################################################
# .. note::
#
#    By definition, IterVars are internal nodes and computes are leaf nodes in
#    a schedule tree. The edges among IterVars and compute within one stage are
#    omitted, making every stage a block, for better readability.
#
tedd.viz_itervar_relationship_graph(s, dot_file_path='/tmp/itervar.dot')
# tedd.viz_itervar_relationship_graph(s, show_svg=True)
######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_itervar_rel.png
#      :align: center
#      :scale: 100%
#
# The last one is an IterVar Relationship Graph. Every subgraph represents a
# stage and contains IterVar nodes and transformation nodes. For example,
# W.shared has three split nodes and three fuse nodes. The rest are IterVar
# nodes in the same format as the IterVar rows in the Schedule Tree. Root
# IterVars are those not driven by any transformation node, such as ax0; leaf
# IterVars don't drive any transformation node and have non-negative indices,
# such as ax0.ax1.fused.ax2.fused.ax3.fused.outer with index 0.
#
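######################################################################
# If you want the underlying data rather than a rendered graph, this change
# also adds ``tedd.dump_json(schedule)``, which dumps the schedule data as a
# JSON string. A minimal sketch; the call below follows the signature given
# in the commit message, and the exact arguments may differ in later
# revisions.
#
import json

# Dump the schedule data and pretty-print a short excerpt of it.
schedule_json = tedd.dump_json(s)
print(json.dumps(json.loads(schedule_json), indent=2)[:400])
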
######################################################################
# Summary
# -------
# This tutorial demonstrates the usage of TEDD. We use an example built
# with TOPI to show the schedules under the hood. You can also invoke TEDD
# before and after any schedule primitive to inspect its effect, as sketched
# below.
#
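######################################################################
# For instance, a sketch of that before-and-after workflow (hypothetical
# names: ``sch``, ``C``, and ``ki`` stand for whatever schedule, stage, and
# IterVar you are working with):
#
# .. code-block:: python
#
#    tedd.viz_schedule_tree(sch, dot_file_path='/tmp/before.dot')
#    sch[C].unroll(ki)  # apply any schedule primitive here
#    tedd.viz_schedule_tree(sch, dot_file_path='/tmp/after.dot')
#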