Unverified Commit b0b1e7da by yongfeng-nv Committed by GitHub

Tensor Expression Debug Display (TEDD) (#4651)

* Initial TEDD for publishing.

* 1. Fix lint issues. 2. Print intrin.body instead of intrin.name in Schedule Tree.  3. Add examples to top level APIs' comments.  4. Top level APIs don't print Dot string by default, unless outputdotstring is True.

* Fix more lint issues.

* Update top level API argument names and use raw strings to avoid Python lint warnings in the tests.

* Disable TEDD verification, but keep TE construction.

* Stop importing tedd to avoid failure.

* Separate data extraction and visualization. 1. Add API tedd.dump_json(schedule) to dump a json string for the schedule data for visualization.  2. Update tests.  3. Add a tutorial.  4. Add range information to IterVars.

* Update TEDD about InferBound failure.  1. TEDD doesn't call inferbound for DFG. 2. Update tutorial about the InferBound failure.

* 1. Import IPython only if SVG is requested.  This is required to fix a tutorial publishing faliure.  2. Fix test about IPython availability check.
parent b422f6a9
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel
"""Tensor Expression Debug Display (TEDD), visualizing Tensor Expression"""
import html
import json
import warnings
from graphviz import Digraph
from graphviz import Source
import tvm
TVMDD_TABLE_BODY_WIDTH = 30
# Must match enum IterVarType defined in include/tvm/expr.h
ITERVAR_TYPE_STRING_MAP = {
0: ('kDataPar', '#FFFFFF'),
1: ('kThreadIndex', '#2980B9'),
2: ('kCommReduce', '#FAD7A0'),
3: ('kOrdered', '#D35400'),
4: ('kOpaque', '#ABB2B9'),
5: ('kUnrolled', '#D2B4DE'),
6: ('kVectorized', '#AED6F1'),
7: ('kParallelized', '#F5B7B1'),
8: ('kTensorized', '#A9DFBF'),
}
def dom_path_to_string(dom_path, prefix=""):
path_string = prefix
for index in dom_path:
path_string = path_string + '_' + str(index)
return path_string
def insert_dot_id(sch):
"""Insert unique ID for each node in the DOM tree.
They are used as Dot node ID.
"""
for stage_idx, stage in enumerate(sch["stages"]):
dom_path = [stage_idx]
stage["id"] = dom_path_to_string(dom_path, stage["type"])
for itervar_idx, itervar in enumerate(stage["all_itervars"]):
dom_path = [stage_idx, itervar_idx]
itervar["id"] = dom_path_to_string(dom_path, itervar["type"])
for rel_idx, rel in enumerate(stage["relations"]):
dom_path = [stage_idx, rel_idx]
rel["id"] = dom_path_to_string(dom_path, rel["type"])
for tensor_idx, tensor in enumerate(stage["output_tensors"]):
dom_path = [stage_idx, tensor_idx]
tensor["id"] = dom_path_to_string(dom_path, tensor["type"])
return sch
class ObjectManager:
"""A helper class tracking schedule objects, e.g. stage, IterVar,
relationship, and tensor, to their DOM path."""
def __init__(self, sch):
self.dict = {}
for stage_idx, stage in enumerate(sch.stages):
self.dict[stage] = [stage_idx]
for itervar_idx, itervar in enumerate(stage.all_iter_vars):
self.dict[itervar] = [stage_idx, itervar_idx]
for rel_idx, rel in enumerate(stage.relations):
self.dict[rel] = [stage_idx, rel_idx]
for tensor_idx in range(stage.op.num_outputs):
self.dict[frozenset({stage.op.name,
tensor_idx})] = [stage_idx, tensor_idx]
def get_dom_path(self, obj):
if obj is None:
return None
assert obj in self.dict, 'Node is no found.'
return self.dict[obj]
def get_or_create_dot_id(obj, prefix="", assert_on_missing=False):
"""If obj's ID has been registered, return it.
If not, either assert or create a unique and legal ID, register and
return it, according to assert_on_missing.
ID must be a unique and legal Dotty ID.
Parameters
----------
obj : objet
Serve as the key to the ID.
prefix : string
Prefix to attach to the ID. Usually use obj's non-unique
name as prefix.
assert_on_missing : bool
Assert or not if object doesn't have a registered ID.
"""
prefix = prefix.replace('.', '_')
if not hasattr(get_or_create_dot_id, "obj_id_dict"):
get_or_create_dot_id.obj_id_dict = {}
if obj not in get_or_create_dot_id.obj_id_dict:
if assert_on_missing:
assert False, 'dot_id ' + str(obj) + ' has not been registered.'
else:
get_or_create_dot_id.obj_id_dict[obj] = prefix + hex(id(obj))
return get_or_create_dot_id.obj_id_dict[obj]
def get_port_id(is_input, index):
return 'I_' + str(index) if is_input else 'O_' + str(index)
def get_itervar_type_info(iter_type):
assert iter_type < len(
ITERVAR_TYPE_STRING_MAP), 'Unknown IterVar type: ' + str(iter_type)
return ITERVAR_TYPE_STRING_MAP[iter_type]
def get_itervar_label_color(itervar, iv_type):
type_info = get_itervar_type_info(iv_type)
return linebrk(
str(itervar["name"]) + '(' + type_info[0] + ')',
TVMDD_TABLE_BODY_WIDTH), type_info[1]
def linebrk(s, n):
""" Break input string s with <br/> for every n charactors."""
result = ''
j = 0
for i, c in enumerate(s):
if j == n and i != len(s) - 1:
result = result + '\n'
j = 0
j = j + 1
result = result + c
result = html.escape(str(result), quote=True)
result = result.replace('\n', '<br/>')
return result
def create_graph(name="", rankdir='BT'):
graph = Digraph(name=name)
graph.graph_attr['rankdir'] = rankdir
return graph
def itervar_label(itervar, index, index_color, label):
return '<TR><TD PORT="' + itervar[
"id"] + '" BGCOLOR="' + index_color + '">' + str(
index
) + '</TD><TD BGCOLOR="white" PORT="itervar">' + label + '<br/>' + str(
itervar["properties"]["range"]) + '</TD></TR>'
def stage_label(stage):
return stage['name'] + '<br/>Scope: ' + stage['properties']['scope']
def legend_label():
label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">'
for iter_type in ITERVAR_TYPE_STRING_MAP:
name, color = ITERVAR_TYPE_STRING_MAP[iter_type]
label += '<TR><TD BGCOLOR="' + color + '"></TD>' \
+ '<TD BGCOLOR="white">' + name + '</TD></TR>'
label += '</TABLE>>'
return label
def leaf_itervars(stage):
filtered = filter(lambda x: (x["index"] >= 0), stage["all_itervars"])
return sorted(filtered, key=lambda x: x["index"])
def legend_dot(g):
with g.subgraph(name='cluster_legend') as subgraph:
subgraph.attr(label='Legend')
label = legend_label()
subgraph.node('legend', label, shape='none', margin='0')
def extract_dom_for_viz(sch, need_range=True):
json_str = dump_json(sch, need_range)
s = json.loads(json_str)
s = insert_dot_id(s)
return s
def dump_graph(dot_string,
show_svg=True,
dot_file_path='',
output_dot_string=False):
"""Output dot_string in various formats."""
if dot_file_path:
try:
dot_file = open(dot_file_path, "w+")
dot_file.write(dot_string)
dot_file.close()
except IOError:
print('Cannot open file: ' + dot_file_path)
if show_svg:
from IPython.display import display
from IPython.display import SVG
src = Source(dot_string)
display(SVG(src.pipe(format='svg')))
if output_dot_string:
return dot_string
return None
def dump_json(sch, need_range):
"""Serialize data for visualization from a schedule in JSON format.
Parameters
----------
sch : schedule
The schedule object to serialize
Returns
-------
json : string
Serialized JSON string
"""
def encode_itervar(itervar, stage, index, range_map):
"""Extract and encode IterVar visualization data to a dictionary"""
ivrange = range_map[
itervar] if range_map is not None and itervar in range_map else None
bind_thread = None
tensor_intrin = None
if itervar in stage.iter_var_attrs:
attr = stage.iter_var_attrs[itervar]
iv_type = attr.iter_type
# binding
bind_thread = str(
attr.bind_thread.var) if attr.bind_thread is not None else None
# tensorization
if attr.tensor_intrin is not None:
tensor_intrin = str(attr.tensor_intrin.body)
# remove the final \n
tensor_intrin = tensor_intrin[0:-1] if tensor_intrin[
-1] == "\n" else tensor_intrin
else:
tensor_intrin = None
else:
iv_type = itervar.iter_type
itervar_dict = {
"type": "IterVar",
"index": index,
"name": str(itervar.var),
"itervar_type": iv_type,
"properties": {
"thread": bind_thread,
"intrin": tensor_intrin,
"range": str(ivrange) if ivrange is not None else 'range(N/A)',
}
}
return itervar_dict
def encode_itervars(stage, range_map):
"""Extract and encode IterVars visualization data from a stage to a dictionary"""
def get_leaf_itervar_index(itervar, leaf_iv):
for leaf_index, ivar in enumerate(leaf_iv):
if ivar == itervar:
return leaf_index
return -1
itervars = []
for itervar in stage.all_iter_vars:
leaf_index = get_leaf_itervar_index(itervar, stage.leaf_iter_vars)
itervars.append(
encode_itervar(itervar, stage, leaf_index, range_map))
return itervars
def encode_itervar_relation(obj_manager, rel):
"""Extract and encode IterVar Relationship visualization data to a dictionary"""
rel_type = type(rel)
if rel_type is tvm.schedule.Split:
node_type = 'Split_Relation'
rel_dict = {
"type": node_type,
"parent": obj_manager.get_dom_path(rel.parent),
"outer": obj_manager.get_dom_path(rel.outer),
"inner": obj_manager.get_dom_path(rel.inner),
}
elif rel_type is tvm.schedule.Fuse:
node_type = 'Fuse_Relation'
rel_dict = {
"type": node_type,
"fused": obj_manager.get_dom_path(rel.fused),
"outer": obj_manager.get_dom_path(rel.outer),
"inner": obj_manager.get_dom_path(rel.inner),
}
elif rel_type is tvm.schedule.Singleton:
node_type = 'Singleton_Relation'
rel_dict = {
"type": node_type,
"iter": obj_manager.get_dom_path(rel.iter),
}
else:
return None
return rel_dict
def encode_itervar_relations(obj_manager, stage):
relations = []
for i in range(len(stage.relations)):
rel = encode_itervar_relation(obj_manager, stage.relations[i])
if rel is not None:
relations.append(rel)
return relations
def encode_tensor(obj_manager, tensor, stage):
"""Extract and encode tensor visualization data to a dictionary"""
tensor_dict = {
"type": "Tensor",
"source": obj_manager.get_dom_path(stage),
"value_index": tensor.value_index,
"shape": str(tensor.op.output(tensor.value_index).shape),
"data_type": tensor.op.output(tensor.value_index).dtype,
}
return tensor_dict
def encode_tensors(obj_manager, stage):
tensors = []
for i in range(stage.op.num_outputs):
tensor = stage.op.output(i)
tensors.append(encode_tensor(obj_manager, tensor, stage))
tensors.sort(key=lambda tensor: tensor["value_index"])
return tensors
def encode_stage(obj_manager, stage, range_map):
"""Extract and encode stage visualization data to a dictionary"""
stage_dict = {
"type":
"Stage",
"name":
stage.op.name,
"attaching_to":
obj_manager.get_dom_path(stage.attach_ivar),
"compute":
str(stage.op.body) if hasattr(stage.op, 'body') else None,
"properties": {
"scope": stage.scope,
},
"all_itervars":
encode_itervars(stage, range_map),
"relations":
encode_itervar_relations(obj_manager, stage),
"input_tensors": [
obj_manager.get_dom_path(
frozenset({tensor.op.name, tensor.value_index}))
for tensor in stage.op.input_tensors
],
"output_tensors":
encode_tensors(obj_manager, stage),
}
return stage_dict
def encode_schedule(sch, need_range):
"""Extract and encode data from a schedule for visualization to a nested dictionary.
It is useful for JSON to serialize schedule.
Parameters
----------
sch : schedule
The schedule object to extract
Returns
-------
dict : dictionary
A nested dictionary
"""
assert isinstance(sch, tvm.schedule.Schedule
), 'Input is not a tvm.schedule.Schedule object.'
range_map = None
if need_range:
try:
range_map = tvm.schedule.InferBound(sch)
except tvm._ffi.base.TVMError as expt:
warnings.warn(
'Ranges are not available, because InferBound fails with the following error:\n'
+ str(expt))
obj_manager = ObjectManager(sch)
stages = []
for stage in sch.stages:
stages.append(encode_stage(obj_manager, stage, range_map))
return {
"type": "Schedule",
"stages": stages,
}
return json.dumps(sch, default=lambda s: encode_schedule(s, need_range))
def viz_schedule_tree(sch,
show_svg=False,
dot_file_path='',
output_dot_string=False):
"""Top level API to render schedule tree
Parameters
----------
sch : schedule
The schedule object to visualize
show_svg : bool
Display graph as SVG, useful for Jupyter notebooks.
dot_file_path : string
Dot file to save the graph.
output_dot_string : bool
Return dot file content or an empty string.
Returns
-------
dot_string : string
Dot file content or an empty string according to output_dot_string
Examples
--------
The following code writes a schedule tree to a dot file.
.. code-block:: python
tedd.viz_schedule_tree(s, dot_file_path = '/tmp/example.dot')
Use the following code to render a SVG graph in a Jupyter notebook.
.. code-block:: python
tedd.viz_schedule_tree(s, show_svg = True)
"""
def create_schedule_tree_graph(name=""):
return create_graph(name=name, rankdir='BT')
def root_dot(g):
g.node('ROOT', 'ROOT', shape='oval', margin='0')
def stage_node_dot(g, stage):
node_label = stage_node_label(stage)
g.node(stage['id'], node_label, shape='none', margin='0')
def stage_node_label(stage):
"""Return a html format label for the given stage."""
label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
'CELLPADDING="4"> <TR><TD BGCOLOR="lightgrey" ' \
'COLSPAN="2" PORT="stage">' + stage_label(stage) + '</TD></TR>'
for leafiv in leaf_itervars(stage):
iv_type = leafiv["itervar_type"]
var_attr_label = ''
if "thread" in leafiv["properties"] and \
leafiv["properties"]["thread"] is not None:
var_attr_label = var_attr_label + "<br/>(" + str(
leafiv["properties"]["thread"]) + ")"
if "intrin" in leafiv["properties"] and \
leafiv["properties"]["intrin"] is not None:
var_attr_label = var_attr_label + "<br/>" + \
linebrk("(tensor_intrin:" + str(
leafiv["properties"]["intrin"]) + ")", TVMDD_TABLE_BODY_WIDTH)
var_label, color = get_itervar_label_color(leafiv, iv_type)
label += itervar_label(leafiv, leafiv["index"], color,
var_label + var_attr_label)
if stage["compute"] is not None:
label += '<TR><TD COLSPAN="2">' + linebrk(str(
stage["compute"]), TVMDD_TABLE_BODY_WIDTH) + '</TD></TR>'
label += '</TABLE>>'
return label
def compute_at_dot(g, stage):
"""If the given stage attaches to another stage, create an edge from it
stage to its attach point; otherwise, create an edge to the ROOT.
"""
src = stage["id"]
dst = dom_path_to_string(
[stage["attaching_to"][0]], "Stage") + ":" + dom_path_to_string(
stage["attaching_to"],
"IterVar") if stage["attaching_to"] is not None else "ROOT"
g.edge(src, dst)
graph = create_schedule_tree_graph("Schedule Tree")
s = extract_dom_for_viz(sch)
legend_dot(graph)
for stage in s['stages']:
stage_node_dot(graph, stage)
for stage in s['stages']:
compute_at_dot(graph, stage)
root_dot(graph)
return dump_graph(graph.source, show_svg, dot_file_path, output_dot_string)
def viz_itervar_relationship_graph(sch,
show_svg=False,
dot_file_path='',
output_dot_string=False):
"""Top level API to render IterVar relationship graph
Parameters
----------
sch : schedule
The schedule object to visualize
show_svg : bool
Display graph as SVG, useful for Jupyter notebooks.
dot_file_path : string
Dot file to save the graph.
output_dot_string : bool
Return dot file content or an empty string.
Examples
--------
The following code writes Ian tervar relationship graph to a dot file.
.. code-block:: python
tedd.viz_def viz_itervar_relationship_graph(sch,
(s, dot_file_path = '/tmp/example.dot')
Use the following code to render a SVG graph in a Jupyter notebook.
.. code-block:: python
tedd.viz_def viz_itervar_relationship_graph(sch,
(s, show_svg = True)
"""
def create_itervar_relation_graph(name=""):
return create_graph(name=name, rankdir='TB')
def itervar_node_dot(g, itervar, iv_type, index):
label = itervar_node_label(itervar, iv_type, index)
g.node(itervar["id"], label, shape='none', margin='0')
def itervar_node_label(itervar, iv_type, index):
label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
'CELLPADDING="4">' + itervar_label(
itervar, index,
get_itervar_label_color(itervar, iv_type)[1],
get_itervar_label_color(itervar, iv_type)[0]) + '</TABLE>>'
return label
def itervar_relation_node_dot(g, node_id, node_label, input_ports,
output_ports):
label = itervar_relation_node_label(node_label, input_ports,
output_ports)
g.node(node_id, label, shape='none', margin='0')
def itervar_relation_node_label(node_label, input_ports, output_ports):
"""Return a html format label for an itervar relationship node
including node_label and input/output ports.
"""
label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
'CELLPADDING="4">' + '<TR>'
max_port_num = max(len(input_ports), len(output_ports))
for i in range(max_port_num):
if i < len(input_ports):
input_port = input_ports[i]
label += '<TD BGCOLOR="lightgrey" PORT="' + input_port + '">' \
+ input_port + '</TD>'
else:
label += '<TD BGCOLOR="white"></TD>'
label += '</TR>'
label += '<TR><TD BGCOLOR="white" COLSPAN="' + str(
max_port_num) + '" PORT="relation">' + node_label + '</TD></TR>'
label += '<TR>'
for i in range(max_port_num):
if i < len(output_ports):
output_port = output_ports[i]
label += '<TD BGCOLOR="lightgrey" PORT="' + output_port + '">' \
+ output_port + '</TD>'
else:
label += '<TD BGCOLOR="white"></TD>'
label += '</TR>'
label += '</TABLE>>'
return label
def itervar_relation_dot(g, node, node_id):
"""Create an itervar relationship node."""
node_type = node["type"]
if node_type == "Split_Relation":
node_type = 'Split'
itervar_relation_node_dot(g, node_id, node_type, ['Input'],
['Outer', 'Inner'])
parent = dom_path_to_string(node["parent"], "IterVar")
outer = dom_path_to_string(node["outer"], "IterVar")
inner = dom_path_to_string(node["inner"], "IterVar")
g.edge(parent + ':itervar', node_id + ':Input')
g.edge(node_id + ':Outer', outer + ':itervar')
g.edge(node_id + ':Inner', inner + ':itervar')
elif node_type == "Fuse_Relation":
node_type = 'Fuse'
itervar_relation_node_dot(g, node_id, node_type,
['Outer', 'Inner'], ['Fused'])
fused = dom_path_to_string(node["fused"], "IterVar")
outer = dom_path_to_string(node["outer"], "IterVar")
inner = dom_path_to_string(node["inner"], "IterVar")
g.edge(outer + ':itervar', node_id + ':Outer')
g.edge(inner + ':itervar', node_id + ':Inner')
g.edge(node_id + ':Fused', fused + ':itervar')
elif node_type == "Singleton_Relation":
node_type = 'Singleton'
itervar_relation_node_dot(g, node_id, node_type, [], ['Iter'])
itervar = dom_path_to_string(node["inner"], "IterVar")
g.edge(node_id + ':Iter', itervar + ':itervar')
else:
assert False, 'Unknown IterVarRelationNode: ' + node_type
def stage_node_dot(g, stage):
"""Create a stage node."""
with g.subgraph(name='cluster_' + stage["id"]) as subgraph:
subgraph.attr(label=stage["name"])
if stage["all_itervars"]:
for itervar in stage["all_itervars"]:
iv_type = itervar["itervar_type"]
itervar_node_dot(subgraph, itervar, iv_type,
itervar["index"])
for rel in stage["relations"]:
node_id = rel["id"]
itervar_relation_dot(subgraph, rel, node_id)
else:
subgraph.node(stage["name"] + '_placeholder', style='invis')
graph = create_itervar_relation_graph("IterVar Relationship Graph")
s = extract_dom_for_viz(sch)
legend_dot(graph)
for stage in s['stages']:
stage_node_dot(graph, stage)
return dump_graph(graph.source, show_svg, dot_file_path, output_dot_string)
def viz_dataflow_graph(sch,
show_svg=False,
dot_file_path='',
output_dot_string=False):
"""Top level API to render dataflow graph
Parameters
----------
sch : schedule
The schedule object to visualize
show_svg : bool
Display graph as SVG, useful for Jupyter notebooks.
dot_file_path : string
Dot file to save the graph.
output_dot_string : bool
Return dot file content or an empty string.
Examples
--------
The following code writes a dataflow graph to a dot file.
.. code-block:: python
tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/example.dot')
Use the following code to render a SVG graph in a Jupyter notebook.
.. code-block:: python
tedd.viz_dataflow_graph(s, show_svg = True) """
def create_dataflow_graph(name=""):
return create_graph(name=name, rankdir='LR')
def tensor_node_dot(g, tensor):
"""Create a tensor node."""
label = tensor_node_label(tensor)
g.node(tensor["id"], label, shape='oval', margin='0')
def tensor_node_label(tensor):
"""Return a html format label for the given tensor."""
label = str(tensor["shape"]) + '\n' + str(tensor["data_type"])
return label
def stage_node_dot(g, stage):
"""Create a stage node."""
label = stage_node_label(stage)
g.node(stage["id"], label, shape='none', margin='0')
def stage_node_label(stage):
"""Return a html format label for the given stage."""
rows = max(
1, max(len(stage["output_tensors"]), len(stage["input_tensors"])))
label = '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" ' \
'CELLPADDING="4">'
for i in range(rows):
label += '<TR>'
if i < len(stage["input_tensors"]):
port_id = get_port_id(True, i)
label += '<TD BGCOLOR="lightgrey" COLSPAN="2" PORT="' \
+ port_id + '">' + str(
i) + '</TD>'
else:
label += '<TD BGCOLOR="white" COLSPAN="2"></TD>'
if i == 0:
label += '<TD BGCOLOR="white" COLSPAN="2" ROWSPAN="' + str(
rows) + '">' + stage_label(stage) + '</TD>'
if i < len(stage["output_tensors"]):
port_id = get_port_id(False, i)
label += '<TD BGCOLOR="lightgrey" COLSPAN="2" PORT="' \
+ port_id + '">' + str(
i) + '</TD>'
else:
label += '<TD BGCOLOR="white" COLSPAN="2"></TD>'
label += '</TR>'
label += '</TABLE>>'
return label
def dfg_dot(g, sch):
"""Create edges among stages."""
stages = sch['stages']
for stage in stages:
for i in range(len(stage["input_tensors"])):
src = dom_path_to_string(stage["input_tensors"][i], "Tensor")
dst = stage["id"] + ':' + get_port_id(True, i)
g.edge(src, dst)
for i in range(len(stage["output_tensors"])):
src = stage["id"] + ':' + get_port_id(False, i)
dst = stage["output_tensors"][i]["id"]
g.edge(src, dst)
graph = create_dataflow_graph("Dataflow Graph")
s = extract_dom_for_viz(sch, need_range=False)
for stage in s['stages']:
stage_node_dot(graph, stage)
for tensor in stage["output_tensors"]:
tensor_node_dot(graph, tensor)
dfg_dot(graph, s)
return dump_graph(graph.source, show_svg, dot_file_path, output_dot_string)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy as np
import re
import topi
def findany(pattern, str):
matches = re.findall(pattern, str)
assert (len(matches) >
0), 'Pattern not found.\nPattern: ' + pattern + '\nString: ' + str
def checkdepdency():
import pkg_resources
return not {'graphviz', 'ipython'} - {pkg.key for pkg in pkg_resources.working_set}
def test_dfg():
A = tvm.placeholder((1024, 4096), dtype='float32', name='A')
B = topi.nn.softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
def verify():
from tvm.contrib import tedd
str = tedd.viz_dataflow_graph(s, False, '', True)
# Check all edges are available
findany(r"digraph \"Dataflow Graph\"", str)
findany(r"Stage_0:O_0 -> Tensor_0_0", str)
findany(r"Tensor_0_0 -> Stage_1:I_0", str)
findany(r"Stage_1:O_0 -> Tensor_1_0", str)
findany(r"Tensor_0_0 -> Stage_2:I_0", str)
findany(r"Tensor_1_0 -> Stage_2:I_1", str)
findany(r"Stage_2:O_0 -> Tensor_2_0", str)
findany(r"Tensor_2_0 -> Stage_3:I_0", str)
findany(r"Stage_3:O_0 -> Tensor_3_0", str)
findany(r"Tensor_2_0 -> Stage_4:I_0", str)
findany(r"Tensor_3_0 -> Stage_4:I_1", str)
findany(r"Stage_4:O_0 -> Tensor_4_0", str)
if checkdepdency():
verify()
def test_itervar_relationship_graph():
n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), "k")
B = tvm.compute((n, ), lambda i: tvm.sum(A[i, k], axis=k), name="B")
s = tvm.create_schedule(B.op)
s[B].split(B.op.reduce_axis[0], factor=16)
def verify():
from tvm.contrib import tedd
str = tedd.viz_itervar_relationship_graph(s, False, '', True)
findany(r"digraph \"IterVar Relationship Graph\"", str)
findany(r"subgraph cluster_legend", str)
# Check subgraphs for stages
findany(r"subgraph cluster_Stage_0", str)
findany(r"subgraph cluster_Stage_1", str)
# Check itervars and their types
findany(r"i\(kDataPar\)\<br/\>range\(min=0, ext=n\)", str)
findany(r"k\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", str)
# Check the split node
findany(r"Split_Relation_1_0 +.+\>Split", str)
# Check all edges to/from the split node
findany(r"IterVar_1_1:itervar -> Split_Relation_1_0:Input", str)
findany(r"Split_Relation_1_0:Outer -> IterVar_1_2:itervar", str)
findany(r"Split_Relation_1_0:Inner -> IterVar_1_3:itervar", str)
if checkdepdency():
verify()
def test_schedule_tree():
block_x = tvm.thread_axis('blockIdx.x')
thread_x = tvm.thread_axis('threadIdx.x')
n = tvm.var("n")
m = tvm.var("m")
l = tvm.var("l")
A = tvm.placeholder((n, m, l), name='A')
B = tvm.compute((n, m, l), lambda bi, bj, bk: A[bi, bj, bk] + 1, name='B')
r = tvm.reduce_axis((0, m), "r")
C = tvm.compute((n, m,),
lambda ci, cj: tvm.sum(B[ci, cj, r], axis=r),
name="C")
s = tvm.create_schedule(C.op)
s.cache_read(A, 'shared', [B])
s[B].vectorize(B.op.axis[-1])
s[C].reorder(C.op.reduce_axis[0], C.op.axis[0])
_, ki = s[C].split(C.op.reduce_axis[0], factor=16)
Cr = s.rfactor(C, ki)
s[Cr].compute_at(s[C], s[C].op.axis[-1])
s[C].bind(s[C].op.axis[0], block_x)
s[C].bind(s[C].op.axis[1], thread_x)
def verify():
from tvm.contrib import tedd
str = tedd.viz_schedule_tree(s, False, '', True)
findany(r"digraph \"Schedule Tree\"", str)
findany(r"subgraph cluster_legend", str)
# Check the A_shared stage, including memory scope, itervars,
# and compute
findany(r"Stage_1.*A\.shared<br/>Scope: shared.+>0.+>" \
r"ax0\(kDataPar\).+>1.+ax1\(kDataPar\).+>2.+>ax2\(kDataPar\).+>" \
r"\[A\(ax0, ax1, ax2\)\]", str)
# Check itervars of types different from KDataPar
findany(r"bk\(kVectorized\)", str)
findany(r"r.outer\(kCommReduce\)", str)
findany(r"label=ROOT", str)
# Check the compute_at edge
findany(r"Stage_1", str)
if checkdepdency():
verify()
if __name__ == "__main__":
test_dfg()
test_itervar_relationship_graph()
test_schedule_tree()
\ No newline at end of file
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Use Tensor Expression Debug Display (TEDD) for Visualization
============================================================
**Author**: `Yongfeng Gu <https://github.com/yongfeng-nv>`_
This is an introduction about using TEDD to visualize tensor expressions.
Tensor Expressions are scheduled with primitives. Although individual
primitives are usually easy to understand, they become complicated quickly
when you put them together. We have introduced an operational model of
schedule primitives in Tensor Expression in this document
(https://docs.google.com/document/d/1nmz00_n4Ju-SpYN0QFl3abTHTlR_P0dRyo5zsWC0Q1k/edit?usp=sharing)
to make it easier to understand
* the interactions between different schedule primitives,
* the impact of the schedule primitives on the final code generation.
The operational model is based on a Dataflow Graph, a Schedule Tree and an
IterVar Relationship Graph. Schedule primitives perform operations on these
graphs.
TEDD renders these three graphs from a given schedule. This tutorial demonstrates
how to use TEDD and how to interpret the rendered graphs.
"""
from __future__ import absolute_import, print_function
import tvm
import topi
from tvm.contrib import tedd
######################################################################
# Define and Schedule Convolution with Bias and ReLU
# --------------------------------------------------
# Let's build an example Tensor Expression for a convolution followed by Bias and ReLU.
# We first connect conv2d, add, and relu TOPIs. Then, we create a TOPI generic schedule.
#
batch = 1
in_channel = 256
in_size = 32
num_filter = 256
kernel = 3
stride = 1
padding = "SAME"
dilation=1
A = tvm.placeholder((in_size, in_size, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = tvm.placeholder((1, num_filter, 1), name='bias')
with tvm.target.create("cuda"):
t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
t_bias = topi.add(t_conv, B)
t_relu = topi.nn.relu(t_bias)
s = topi.generic.schedule_conv2d_hwcn([t_relu])
######################################################################
# Render Graphs with TEDD
# -----------------------
# We render graphs to see the computation
# and how it is scheduled.
# If you run the tutorial in a Jupyter notebook, you can use the following commented lines
# to render SVG figures showing in notebook directly.
#
tedd.viz_dataflow_graph(s, dot_file_path = '/tmp/dfg.dot')
#tedd.viz_dataflow_graph(s, show_svg = True)
######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_dfg.png
# :align: center
# :scale: 100%
#
# The first one is a dataflow graph. Every node represents a stage with name and memory
# scope shown in the middle and inputs/outputs information on the sides.
# Edges show nodes' dependency.
#
tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree.dot')
#tedd.viz_schedule_tree(s, show_svg = True)
######################################################################
# We just rendered the schedule tree graph. You may notice an warning about ranges not
# available.
# The message also suggests to call normalize() to infer range information. We will
# skip inspecting the first schedule tree and encourage you to compare the graphs before
# and after normalize() for its impact.
#
s = s.normalize()
tedd.viz_schedule_tree(s, dot_file_path = '/tmp/scheduletree2.dot')
#tedd.viz_schedule_tree(s, show_svg = True)
######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_st.png
# :align: center
# :scale: 100%
#
# Now, let us take a close look at the second schedule tree. Every block under ROOT
# represents a
# stage. Stage name shows in the top row and compute shows in the bottom row.
# The middle rows are for IterVars, the higher the outer, the lower the inner.
# An IterVar row contains its index, name, type, and other optional information.
# Let's use the W.shared stage as an example. The top row tells
# its name, "W.shared", and memory scope, "Shared". Its compute is
# :code:`W(ax0, ax1, ax2, ax3)`.
# Its outer most loop IterVar is ax0.ax1.fused.ax2.fused.ax3.fused.outer,
# indexed with 0, of kDataPar, bound to threadIdx.y, and with range(min=0, ext=8).
# You can also tell
# IterVar type with the index box color, shown in the legend.
#
# If a stage doesn't compute_at any other stage, it has an edge directly to the
# ROOT node. Otherwise, it has an edge pointing to the IterVar it attaches to,
# such as W.shared attaches to rx.outer in the middle compute stage.
#
######################################################################
# .. note::
#
# By definition, IterVars are internal nodes and computes are leaf nodes in
# a schedule tree. The edges among IterVars and compute within one stage are
# omitted, making every stage a block, for better readability.
#
tedd.viz_itervar_relationship_graph(s, dot_file_path = '/tmp/itervar.dot')
#tedd.viz_itervar_relationship_graph(s, show_svg = True)
######################################################################
# .. image:: https://github.com/dmlc/web-data/raw/master/tvm/tutorial/tedd_itervar_rel.png
# :align: center
# :scale: 100%
#
# The last one is an IterVar Relationship Graph. Every subgraph represents a
# stage and contains IterVar nodes and transformation nodes. For example,
# W.shared has three split nodes and three fuse nodes. The rest are IterVar
# nodes of the same format as the IterVar rows in Schedule Trees. Root
# IterVars are those not driven by any transformation node, such as ax0; leaf
# IterVars don't drive any transformation node and have non-negative indices,
# such as ax0.ax1.fused.ax2.fused.ax3.fused.outer with index of 0.
#
######################################################################
# Summary
# -------
# This tutorial demonstrates the usage of TEDD. We use an example built
# with TOPI to show the schedules under the hood. You can also use
# it before and after any schedule primitive to inspect its effect.
#
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment