Unverified Commit eba50ad8 by Zhi Committed by GitHub

[Relay][pass] call graph for relay (#4922)

* call graph for relay

* CallGraphEntryNode->CallGraphEntry, __getitem__->print_var

* fix typos
parent 61bea507
......@@ -19,6 +19,7 @@
import os
from sys import setrecursionlimit
from ..api import register_func
from . import call_graph
from . import base
from . import ty
from . import expr
......@@ -141,3 +142,6 @@ Sequential = transform.Sequential
# Feature
Feature = feature.Feature
# CallGraph
CallGraph = call_graph.CallGraph
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""
from tvm.ir import IRModule
from .base import Object
from .expr import GlobalVar
from . import _analysis
class CallGraph(Object):
"""Class to represent a call graph."""
def __init__(self, module):
"""Construct a call graph.
Parameters
----------
module : tvm.ir.IRModule
The IR module used to create a call graph
Returns
-------
call_graph: CallGraph
A constructed call graph.
"""
self.__init_handle_by_constructor__(_analysis.CallGraph, module)
@property
def module(self):
"""Return the contained Relay IR module.
Parameters
----------
None
Returns
-------
ret : tvm.ir.IRModule
The contained IRModule
"""
return _analysis.GetModule(self)
def ref_count(self, var):
"""Return the number of references to the global var
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : int
The number reference to the global var
"""
var = self._get_global_var(var)
return _analysis.GetRefCountGlobalVar(self, var)
def global_call_count(self, var):
"""Return the number of global function calls from a given global var.
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : int
The number of global function calls from the given var.
"""
var = self._get_global_var(var)
return _analysis.GetGlobalVarCallCount(self, var)
def is_recursive(self, var):
"""Return if the function corresponding to a var is a recursive
function.
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : Boolean
If the function corresponding to var is recurisve.
"""
var = self._get_global_var(var)
return _analysis.IsRecursive(self, var)
def _get_global_var(self, var):
"""Return the global var using a given name or GlobalVar.
Parameters
----------
var : Union[String, tvm.relay.GlobalVar]
Returns
-------
ret : tvm.relay.GlobalVar
The global var.
"""
if isinstance(var, str):
mod = self.module
var = mod.get_global_var(var)
if isinstance(var, GlobalVar):
return var
else:
raise TypeError("var should be either a string or GlobalVar")
def print_var(self, var):
"""Print a call graph of a global function by name or by variable.
Parameters
----------
var: Union[String, tvm.relay.GlobalVar]
The name or global variable.
Returns
-------
ret : String
The call graph represented in string.
"""
var = self._get_global_var(var)
return _analysis.PrintCallGraphGlobalVar(self, var)
def __str__(self):
"""Print the call graph in the topological order."""
return _analysis.PrintCallGraph(self)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
import pytest
import tvm
from tvm import relay
def test_callgraph_construct():
mod = tvm.IRModule({})
x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))
mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.CallGraph(mod)
assert "g1" in str(call_graph)
assert relay.alpha_equal(mod, call_graph.module)
def test_print_element():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
mod["g0"] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
mod["g1"] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
assert "#refs = 0" in str(call_graph.print_var("g0"))
assert "#refs = 0" in str(call_graph.print_var("g1"))
def test_global_call_count():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
g0 = relay.GlobalVar("g0")
mod[g0] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
assert call_graph.global_call_count(g0) == 0
assert call_graph.global_call_count(g1) == 1
assert call_graph.global_call_count("main") == 2
def test_ref_count():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
g0 = relay.GlobalVar("g0")
mod[g0] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
assert call_graph.ref_count(g0) == 1
assert call_graph.ref_count(g1) == 1
assert call_graph.ref_count("main") == 0
def test_nested_ref():
mod = tvm.IRModule({})
x0 = relay.var("x0", shape=(2, 3))
y0 = relay.var("y0", shape=(2, 3))
g0 = relay.GlobalVar("g0")
mod[g0] = relay.Function([x0, y0], x0 + y0)
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
assert call_graph.ref_count(g0) == 2
assert call_graph.ref_count(g1) == 1
assert call_graph.ref_count("main") == 0
def test_recursive_func():
mod = tvm.IRModule({})
x = relay.var('x', shape=[], dtype='int32')
fn0 = relay.Function([x], x)
gx = relay.GlobalVar("gx")
mod[gx] = fn0
sum_up = relay.GlobalVar('sum_up')
i = relay.var('i', shape=[], dtype='int32')
sb = relay.ScopeBuilder()
with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))):
sb.ret(i)
with sb.else_scope():
one_less = relay.subtract(i, relay.const(1, dtype='int32'))
global_call = gx(i)
rec_call = relay.Call(sum_up, [one_less]) + global_call
sb.ret(relay.add(rec_call, i))
func = relay.Function([i],
sb.get(),
ret_type=relay.TensorType([], 'int32'))
func = func.set_attribute("Compiler", tvm.tir.StringImm("a"))
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
call_graph = relay.CallGraph(mod)
assert call_graph.is_recursive(sum_up)
assert call_graph.ref_count(sum_up) == 2
assert call_graph.ref_count(gx) == 1
assert call_graph.ref_count("main") == 0
if __name__ == "__main__":
pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment