Unverified Commit 4aff8dc1 by Zhi Committed by GitHub

[Refactor][Relay] Refactor Relay Python to use new FFI (#5077)

* refactor relay python

* revert relay/ir/*.py to relay

* Address comments

* remove direct access to analysis and transform namespace
parent 646cfc63
...@@ -19,10 +19,6 @@ tvm.relay.base ...@@ -19,10 +19,6 @@ tvm.relay.base
-------------- --------------
.. automodule:: tvm.relay.base .. automodule:: tvm.relay.base
.. autofunction:: tvm.relay.base.register_relay_node
.. autofunction:: tvm.relay.base.register_relay_attr_node
.. autoclass:: tvm.relay.base.RelayNode .. autoclass:: tvm.relay.base.RelayNode
:members: :members:
......
...@@ -19,35 +19,37 @@ ...@@ -19,35 +19,37 @@
import os import os
from sys import setrecursionlimit from sys import setrecursionlimit
from . import call_graph
from . import base from . import base
from . import ty from . import ty
from . import expr from . import expr
from . import type_functor from . import type_functor
from . import expr_functor from . import expr_functor
from . import adt from . import adt
from . import analysis from . import prelude
from . import loops
from . import scope_builder
from . import parser
from . import transform from . import transform
from . import analysis
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize from .build_module import build, create_executor, optimize
from .transform import build_config from .transform import build_config
from . import prelude
from . import parser
from . import debug from . import debug
from . import param_dict from . import param_dict
from . import feature
from .backend import vm from .backend import vm
# Root operators # Root operators
from .op import Op from .op import Op
from .op import nn
from .op import image
from .op import annotation
from .op import vision
from .op import contrib
from .op.reduce import * from .op.reduce import *
from .op.tensor import * from .op.tensor import *
from .op.transform import * from .op.transform import *
from .op.algorithm import * from .op.algorithm import *
from . import nn
from . import annotation
from . import vision
from . import contrib
from . import image
from . import frontend from . import frontend
from . import backend from . import backend
from . import quantize from . import quantize
...@@ -55,15 +57,12 @@ from . import quantize ...@@ -55,15 +57,12 @@ from . import quantize
# Dialects # Dialects
from . import qnn from . import qnn
from .scope_builder import ScopeBuilder
# Load Memory pass
from . import memory_alloc
# Required to traverse large programs # Required to traverse large programs
setrecursionlimit(10000) setrecursionlimit(10000)
# Span # Span
Span = base.Span Span = base.Span
SourceName = base.SourceName
# Type # Type
Type = ty.Type Type = ty.Type
...@@ -98,6 +97,7 @@ RefRead = expr.RefRead ...@@ -98,6 +97,7 @@ RefRead = expr.RefRead
RefWrite = expr.RefWrite RefWrite = expr.RefWrite
# ADT # ADT
Pattern = adt.Pattern
PatternWildcard = adt.PatternWildcard PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor PatternConstructor = adt.PatternConstructor
...@@ -111,9 +111,6 @@ Match = adt.Match ...@@ -111,9 +111,6 @@ Match = adt.Match
var = expr.var var = expr.var
const = expr.const const = expr.const
bind = expr.bind bind = expr.bind
module_pass = transform.module_pass
function_pass = transform.function_pass
alpha_equal = analysis.alpha_equal
# TypeFunctor # TypeFunctor
TypeFunctor = type_functor.TypeFunctor TypeFunctor = type_functor.TypeFunctor
...@@ -125,6 +122,15 @@ ExprFunctor = expr_functor.ExprFunctor ...@@ -125,6 +122,15 @@ ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor ExprVisitor = expr_functor.ExprVisitor
ExprMutator = expr_functor.ExprMutator ExprMutator = expr_functor.ExprMutator
# Prelude
Prelude = prelude.Prelude
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder
module_pass = transform.module_pass
function_pass = transform.function_pass
# Parser # Parser
fromtext = parser.fromtext fromtext = parser.fromtext
...@@ -139,9 +145,3 @@ Pass = transform.Pass ...@@ -139,9 +145,3 @@ Pass = transform.Pass
ModulePass = transform.ModulePass ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass FunctionPass = transform.FunctionPass
Sequential = transform.Sequential Sequential = transform.Sequential
# Feature
Feature = feature.Feature
# CallGraph
CallGraph = call_graph.CallGraph
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI exposing the passes for Relay program analysis.""" """FFI APIs for Relay program IR."""
import tvm._ffi import tvm._ffi
tvm._ffi._init_api("relay._analysis", __name__) tvm._ffi._init_api("relay.ir", __name__)
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Algebraic data types in Relay.""" """Algebraic data types in Relay."""
from tvm.ir import Constructor, TypeData from tvm.ir import Constructor, TypeData
from tvm.runtime import Object
import tvm._ffi
from .base import RelayNode, register_relay_node, Object from .base import RelayNode
from . import _make from . import _ffi_api
from .ty import Type from .ty import Type
from .expr import ExprWithOp, RelayExpr, Call from .expr import ExprWithOp, RelayExpr, Call
...@@ -28,7 +30,7 @@ class Pattern(RelayNode): ...@@ -28,7 +30,7 @@ class Pattern(RelayNode):
"""Base type for pattern matching constructs.""" """Base type for pattern matching constructs."""
@register_relay_node @tvm._ffi.register_object("relay.PatternWildcard")
class PatternWildcard(Pattern): class PatternWildcard(Pattern):
"""Wildcard pattern in Relay: Matches any ADT and binds nothing.""" """Wildcard pattern in Relay: Matches any ADT and binds nothing."""
...@@ -44,10 +46,10 @@ class PatternWildcard(Pattern): ...@@ -44,10 +46,10 @@ class PatternWildcard(Pattern):
wildcard: PatternWildcard wildcard: PatternWildcard
a wildcard pattern. a wildcard pattern.
""" """
self.__init_handle_by_constructor__(_make.PatternWildcard) self.__init_handle_by_constructor__(_ffi_api.PatternWildcard)
@register_relay_node @tvm._ffi.register_object("relay.PatternVar")
class PatternVar(Pattern): class PatternVar(Pattern):
"""Variable pattern in Relay: Matches anything and binds it to the variable.""" """Variable pattern in Relay: Matches anything and binds it to the variable."""
...@@ -63,10 +65,10 @@ class PatternVar(Pattern): ...@@ -63,10 +65,10 @@ class PatternVar(Pattern):
pv: PatternVar pv: PatternVar
A variable pattern. A variable pattern.
""" """
self.__init_handle_by_constructor__(_make.PatternVar, var) self.__init_handle_by_constructor__(_ffi_api.PatternVar, var)
@register_relay_node @tvm._ffi.register_object("relay.PatternConstructor")
class PatternConstructor(Pattern): class PatternConstructor(Pattern):
"""Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively.""" """Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively."""
...@@ -88,10 +90,10 @@ class PatternConstructor(Pattern): ...@@ -88,10 +90,10 @@ class PatternConstructor(Pattern):
""" """
if patterns is None: if patterns is None:
patterns = [] patterns = []
self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns) self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns)
@register_relay_node @tvm._ffi.register_object("relay.PatternTuple")
class PatternTuple(Pattern): class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively.""" """Constructor pattern in Relay: Matches a tuple, binds recursively."""
...@@ -111,10 +113,10 @@ class PatternTuple(Pattern): ...@@ -111,10 +113,10 @@ class PatternTuple(Pattern):
""" """
if patterns is None: if patterns is None:
patterns = [] patterns = []
self.__init_handle_by_constructor__(_make.PatternTuple, patterns) self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns)
@register_relay_node @tvm._ffi.register_object("relay.Clause")
class Clause(Object): class Clause(Object):
"""Clause for pattern matching in Relay.""" """Clause for pattern matching in Relay."""
...@@ -133,10 +135,10 @@ class Clause(Object): ...@@ -133,10 +135,10 @@ class Clause(Object):
clause: Clause clause: Clause
The Clause. The Clause.
""" """
self.__init_handle_by_constructor__(_make.Clause, lhs, rhs) self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs)
@register_relay_node @tvm._ffi.register_object("relay.Match")
class Match(ExprWithOp): class Match(ExprWithOp):
"""Pattern matching expression in Relay.""" """Pattern matching expression in Relay."""
...@@ -160,4 +162,4 @@ class Match(ExprWithOp): ...@@ -160,4 +162,4 @@ class Match(ExprWithOp):
match: tvm.relay.Expr match: tvm.relay.Expr
The match expression. The match expression.
""" """
self.__init_handle_by_constructor__(_make.Match, data, clauses, complete) self.__init_handle_by_constructor__(_ffi_api.Match, data, clauses, complete)
...@@ -14,8 +14,15 @@ ...@@ -14,8 +14,15 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the analysis passes."""
# Analysis passes
from .analysis import *
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import # Call graph
"""Annotation related operators.""" from . import call_graph
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import * # Feature
from . import feature
CallGraph = call_graph.CallGraph
...@@ -14,8 +14,7 @@ ...@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable """FFI APIs for Relay program analysis."""
"""The interface of expr function exposed from C++."""
import tvm._ffi import tvm._ffi
tvm._ffi._init_api("relay._base", __name__) tvm._ffi._init_api("relay.analysis", __name__)
...@@ -22,9 +22,9 @@ configuring the passes and scripting them in Python. ...@@ -22,9 +22,9 @@ configuring the passes and scripting them in Python.
""" """
from tvm.ir import RelayExpr, IRModule from tvm.ir import RelayExpr, IRModule
from . import _analysis from . import _ffi_api
from .ty import Type
from .feature import Feature from .feature import Feature
from ..ty import Type
def post_order_visit(expr, fvisit): def post_order_visit(expr, fvisit):
...@@ -40,7 +40,7 @@ def post_order_visit(expr, fvisit): ...@@ -40,7 +40,7 @@ def post_order_visit(expr, fvisit):
fvisit : function fvisit : function
The visitor function to be applied. The visitor function to be applied.
""" """
return _analysis.post_order_visit(expr, fvisit) return _ffi_api.post_order_visit(expr, fvisit)
def well_formed(expr): def well_formed(expr):
...@@ -56,7 +56,7 @@ def well_formed(expr): ...@@ -56,7 +56,7 @@ def well_formed(expr):
well_form : bool well_form : bool
Whether the input expression is well formed Whether the input expression is well formed
""" """
return _analysis.well_formed(expr) return _ffi_api.well_formed(expr)
def check_kind(t, mod=None): def check_kind(t, mod=None):
...@@ -85,9 +85,9 @@ def check_kind(t, mod=None): ...@@ -85,9 +85,9 @@ def check_kind(t, mod=None):
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type
""" """
if mod is not None: if mod is not None:
return _analysis.check_kind(t, mod) return _ffi_api.check_kind(t, mod)
else: else:
return _analysis.check_kind(t) return _ffi_api.check_kind(t)
def check_constant(expr): def check_constant(expr):
...@@ -103,7 +103,7 @@ def check_constant(expr): ...@@ -103,7 +103,7 @@ def check_constant(expr):
result : bool result : bool
Whether the expression is constant. Whether the expression is constant.
""" """
return _analysis.check_constant(expr) return _ffi_api.check_constant(expr)
def free_vars(expr): def free_vars(expr):
...@@ -125,7 +125,7 @@ def free_vars(expr): ...@@ -125,7 +125,7 @@ def free_vars(expr):
neural networks: usually this means weights of previous neural networks: usually this means weights of previous
are ordered first. are ordered first.
""" """
return _analysis.free_vars(expr) return _ffi_api.free_vars(expr)
def bound_vars(expr): def bound_vars(expr):
...@@ -141,7 +141,7 @@ def bound_vars(expr): ...@@ -141,7 +141,7 @@ def bound_vars(expr):
free : List[tvm.relay.Var] free : List[tvm.relay.Var]
The list of bound variables in post-DFS order. The list of bound variables in post-DFS order.
""" """
return _analysis.bound_vars(expr) return _ffi_api.bound_vars(expr)
def all_vars(expr): def all_vars(expr):
...@@ -157,7 +157,7 @@ def all_vars(expr): ...@@ -157,7 +157,7 @@ def all_vars(expr):
free : List[tvm.relay.Var] free : List[tvm.relay.Var]
The list of all variables in post-DFS order. The list of all variables in post-DFS order.
""" """
return _analysis.all_vars(expr) return _ffi_api.all_vars(expr)
def free_type_vars(expr, mod=None): def free_type_vars(expr, mod=None):
...@@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None): ...@@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None):
The list of free type variables in post-DFS order The list of free type variables in post-DFS order
""" """
use_mod = mod if mod is not None else IRModule() use_mod = mod if mod is not None else IRModule()
return _analysis.free_type_vars(expr, use_mod) return _ffi_api.free_type_vars(expr, use_mod)
def bound_type_vars(expr, mod=None): def bound_type_vars(expr, mod=None):
...@@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None): ...@@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None):
The list of bound type variables in post-DFS order The list of bound type variables in post-DFS order
""" """
use_mod = mod if mod is not None else IRModule() use_mod = mod if mod is not None else IRModule()
return _analysis.bound_type_vars(expr, use_mod) return _ffi_api.bound_type_vars(expr, use_mod)
def all_type_vars(expr, mod=None): def all_type_vars(expr, mod=None):
...@@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None): ...@@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None):
The list of all type variables in post-DFS order The list of all type variables in post-DFS order
""" """
use_mod = mod if mod is not None else IRModule() use_mod = mod if mod is not None else IRModule()
return _analysis.all_type_vars(expr, use_mod) return _ffi_api.all_type_vars(expr, use_mod)
def alpha_equal(lhs, rhs): def alpha_equal(lhs, rhs):
...@@ -236,7 +236,7 @@ def alpha_equal(lhs, rhs): ...@@ -236,7 +236,7 @@ def alpha_equal(lhs, rhs):
result : bool result : bool
True iff lhs is alpha equal to rhs. True iff lhs is alpha equal to rhs.
""" """
return bool(_analysis._alpha_equal(lhs, rhs)) return bool(_ffi_api._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs): def assert_alpha_equal(lhs, rhs):
...@@ -250,7 +250,7 @@ def assert_alpha_equal(lhs, rhs): ...@@ -250,7 +250,7 @@ def assert_alpha_equal(lhs, rhs):
rhs : tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
""" """
_analysis._assert_alpha_equal(lhs, rhs) _ffi_api._assert_alpha_equal(lhs, rhs)
def graph_equal(lhs, rhs): def graph_equal(lhs, rhs):
...@@ -272,7 +272,7 @@ def graph_equal(lhs, rhs): ...@@ -272,7 +272,7 @@ def graph_equal(lhs, rhs):
result : bool result : bool
True iff lhs is data-flow equivalent to rhs. True iff lhs is data-flow equivalent to rhs.
""" """
return bool(_analysis._graph_equal(lhs, rhs)) return bool(_ffi_api._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs): def assert_graph_equal(lhs, rhs):
...@@ -289,7 +289,7 @@ def assert_graph_equal(lhs, rhs): ...@@ -289,7 +289,7 @@ def assert_graph_equal(lhs, rhs):
rhs : tvm.relay.Expr rhs : tvm.relay.Expr
One of the input Expression. One of the input Expression.
""" """
_analysis._assert_graph_equal(lhs, rhs) _ffi_api._assert_graph_equal(lhs, rhs)
def collect_device_info(expr): def collect_device_info(expr):
...@@ -303,10 +303,10 @@ def collect_device_info(expr): ...@@ -303,10 +303,10 @@ def collect_device_info(expr):
Returns Returns
------- -------
ret : Dict[tvm.relay.expr, int] ret : Dict[tvm.relay.ir.expr, int]
A dictionary mapping tvm.relay.Expr to device type. A dictionary mapping tvm.relay.Expr to device type.
""" """
return _analysis.CollectDeviceInfo(expr) return _ffi_api.CollectDeviceInfo(expr)
def collect_device_annotation_ops(expr): def collect_device_annotation_ops(expr):
...@@ -319,11 +319,11 @@ def collect_device_annotation_ops(expr): ...@@ -319,11 +319,11 @@ def collect_device_annotation_ops(expr):
Returns Returns
------- -------
ret : Dict[tvm.relay.expr, int] ret : Dict[tvm.relay.Expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions. annotation expressions.
""" """
return _analysis.CollectDeviceAnnotationOps(expr) return _ffi_api.CollectDeviceAnnotationOps(expr)
def get_total_mac_number(expr): def get_total_mac_number(expr):
...@@ -340,7 +340,7 @@ def get_total_mac_number(expr): ...@@ -340,7 +340,7 @@ def get_total_mac_number(expr):
result : int64 result : int64
The number of MACs (multiply-accumulate) of a model The number of MACs (multiply-accumulate) of a model
""" """
return _analysis.GetTotalMacNumber(expr) return _ffi_api.GetTotalMacNumber(expr)
def unmatched_cases(match, mod=None): def unmatched_cases(match, mod=None):
...@@ -360,7 +360,7 @@ def unmatched_cases(match, mod=None): ...@@ -360,7 +360,7 @@ def unmatched_cases(match, mod=None):
missing_patterns : [tvm.relay.Pattern] missing_patterns : [tvm.relay.Pattern]
Patterns that the match expression does not catch. Patterns that the match expression does not catch.
""" """
return _analysis.unmatched_cases(match, mod) return _ffi_api.unmatched_cases(match, mod)
def detect_feature(a, b=None): def detect_feature(a, b=None):
...@@ -383,7 +383,7 @@ def detect_feature(a, b=None): ...@@ -383,7 +383,7 @@ def detect_feature(a, b=None):
""" """
if isinstance(a, IRModule): if isinstance(a, IRModule):
a, b = b, a a, b = b, a
return {Feature(int(x)) for x in _analysis.detect_feature(a, b)} return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}
def structural_hash(value): def structural_hash(value):
...@@ -400,9 +400,9 @@ def structural_hash(value): ...@@ -400,9 +400,9 @@ def structural_hash(value):
The hash value The hash value
""" """
if isinstance(value, RelayExpr): if isinstance(value, RelayExpr):
return int(_analysis._expr_hash(value)) return int(_ffi_api._expr_hash(value))
elif isinstance(value, Type): elif isinstance(value, Type):
return int(_analysis._type_hash(value)) return int(_ffi_api._type_hash(value))
else: else:
msg = ("found value of type {0} expected" + msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value)) "relay.Expr or relay.Type").format(type(value))
...@@ -421,10 +421,10 @@ def extract_fused_functions(mod): ...@@ -421,10 +421,10 @@ def extract_fused_functions(mod):
Returns Returns
------- -------
ret : Dict[int, tvm.relay.expr.Function] ret : Dict[int, tvm.relay.ir.expr.Function]
A module containing only fused primitive functions A module containing only fused primitive functions
""" """
ret_mod = _analysis.ExtractFusedFunctions()(mod) ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
ret = {} ret = {}
for hash_, func in ret_mod.functions.items(): for hash_, func in ret_mod.functions.items():
ret[hash_] = func ret[hash_] = func
......
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
"""Call graph used in Relay.""" """Call graph used in Relay."""
from tvm.ir import IRModule from tvm.ir import IRModule
from .base import Object from tvm.runtime import Object
from .expr import GlobalVar from ..expr import GlobalVar
from . import _analysis from . import _ffi_api
class CallGraph(Object): class CallGraph(Object):
...@@ -39,7 +39,7 @@ class CallGraph(Object): ...@@ -39,7 +39,7 @@ class CallGraph(Object):
call_graph: CallGraph call_graph: CallGraph
A constructed call graph. A constructed call graph.
""" """
self.__init_handle_by_constructor__(_analysis.CallGraph, module) self.__init_handle_by_constructor__(_ffi_api.CallGraph, module)
@property @property
def module(self): def module(self):
...@@ -54,7 +54,7 @@ class CallGraph(Object): ...@@ -54,7 +54,7 @@ class CallGraph(Object):
ret : tvm.ir.IRModule ret : tvm.ir.IRModule
The contained IRModule The contained IRModule
""" """
return _analysis.GetModule(self) return _ffi_api.GetModule(self)
def ref_count(self, var): def ref_count(self, var):
"""Return the number of references to the global var """Return the number of references to the global var
...@@ -69,7 +69,7 @@ class CallGraph(Object): ...@@ -69,7 +69,7 @@ class CallGraph(Object):
The number reference to the global var The number reference to the global var
""" """
var = self._get_global_var(var) var = self._get_global_var(var)
return _analysis.GetRefCountGlobalVar(self, var) return _ffi_api.GetRefCountGlobalVar(self, var)
def global_call_count(self, var): def global_call_count(self, var):
"""Return the number of global function calls from a given global var. """Return the number of global function calls from a given global var.
...@@ -84,7 +84,7 @@ class CallGraph(Object): ...@@ -84,7 +84,7 @@ class CallGraph(Object):
The number of global function calls from the given var. The number of global function calls from the given var.
""" """
var = self._get_global_var(var) var = self._get_global_var(var)
return _analysis.GetGlobalVarCallCount(self, var) return _ffi_api.GetGlobalVarCallCount(self, var)
def is_recursive(self, var): def is_recursive(self, var):
"""Return if the function corresponding to a var is a recursive """Return if the function corresponding to a var is a recursive
...@@ -100,7 +100,7 @@ class CallGraph(Object): ...@@ -100,7 +100,7 @@ class CallGraph(Object):
If the function corresponding to var is recurisve. If the function corresponding to var is recurisve.
""" """
var = self._get_global_var(var) var = self._get_global_var(var)
return _analysis.IsRecursive(self, var) return _ffi_api.IsRecursive(self, var)
def _get_global_var(self, var): def _get_global_var(self, var):
"""Return the global var using a given name or GlobalVar. """Return the global var using a given name or GlobalVar.
...@@ -137,8 +137,8 @@ class CallGraph(Object): ...@@ -137,8 +137,8 @@ class CallGraph(Object):
The call graph represented in string. The call graph represented in string.
""" """
var = self._get_global_var(var) var = self._get_global_var(var)
return _analysis.PrintCallGraphGlobalVar(self, var) return _ffi_api.PrintCallGraphGlobalVar(self, var)
def __str__(self): def __str__(self):
"""Print the call graph in the topological order.""" """Print the call graph in the topological order."""
return _analysis.PrintCallGraph(self) return _ffi_api.PrintCallGraph(self)
...@@ -22,7 +22,7 @@ import logging ...@@ -22,7 +22,7 @@ import logging
import numpy as np import numpy as np
import tvm import tvm
from tvm import te from tvm import te
from ..base import register_relay_node, Object from tvm.runtime import Object
from ... import target as _target from ... import target as _target
from ... import autotvm from ... import autotvm
from .. import expr as _expr from .. import expr as _expr
...@@ -33,7 +33,7 @@ from . import _backend ...@@ -33,7 +33,7 @@ from . import _backend
logger = logging.getLogger('compile_engine') logger = logging.getLogger('compile_engine')
@register_relay_node @tvm._ffi.register_object("relay.LoweredOutput")
class LoweredOutput(Object): class LoweredOutput(Object):
"""Lowered output""" """Lowered output"""
def __init__(self, outputs, implement): def __init__(self, outputs, implement):
...@@ -41,7 +41,7 @@ class LoweredOutput(Object): ...@@ -41,7 +41,7 @@ class LoweredOutput(Object):
_backend._make_LoweredOutput, outputs, implement) _backend._make_LoweredOutput, outputs, implement)
@register_relay_node @tvm._ffi.register_object("relay.CCacheKey")
class CCacheKey(Object): class CCacheKey(Object):
"""Key in the CompileEngine. """Key in the CompileEngine.
...@@ -58,7 +58,7 @@ class CCacheKey(Object): ...@@ -58,7 +58,7 @@ class CCacheKey(Object):
_backend._make_CCacheKey, source_func, target) _backend._make_CCacheKey, source_func, target)
@register_relay_node @tvm._ffi.register_object("relay.CCacheValue")
class CCacheValue(Object): class CCacheValue(Object):
"""Value in the CompileEngine, including usage statistics. """Value in the CompileEngine, including usage statistics.
""" """
...@@ -261,7 +261,7 @@ def lower_call(call, inputs, target): ...@@ -261,7 +261,7 @@ def lower_call(call, inputs, target):
return LoweredOutput(outputs, best_impl) return LoweredOutput(outputs, best_impl)
@register_relay_node @tvm._ffi.register_object("relay.CompileEngine")
class CompileEngine(Object): class CompileEngine(Object):
"""CompileEngine to get lowered code. """CompileEngine to get lowered code.
""" """
......
...@@ -20,25 +20,25 @@ from __future__ import absolute_import ...@@ -20,25 +20,25 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from tvm.runtime import container import tvm._ffi
from tvm.runtime import container, Object
from tvm.ir import IRModule from tvm.ir import IRModule
from . import _backend from . import _backend
from .. import _make, analysis, transform from .. import _make, analysis, transform
from ... import nd from ... import nd
from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder from ..scope_builder import ScopeBuilder
@register_relay_node @tvm._ffi.register_object("relay.ConstructorValue")
class ConstructorValue(Object): class ConstructorValue(Object):
def __init__(self, tag, fields, constructor): def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor) _make.ConstructorValue, tag, fields, constructor)
@register_relay_node @tvm._ffi.register_object("relay.RefValue")
class RefValue(Object): class RefValue(Object):
def __init__(self, value): def __init__(self, value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
......
...@@ -21,47 +21,17 @@ import tvm._ffi ...@@ -21,47 +21,17 @@ import tvm._ffi
from tvm.runtime import Object from tvm.runtime import Object
from tvm.ir import SourceName, Span, Node as RelayNode from tvm.ir import SourceName, Span, Node as RelayNode
from . import _make
from . import _expr
from . import _base
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std") __STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@tvm._ffi.register_func("tvm.relay.std_path") @tvm._ffi.register_func("tvm.relay.std_path")
def _std_path(): def _std_path():
return __STD_PATH__ return __STD_PATH__
def register_relay_node(type_key=None): @tvm._ffi.register_object("relay.Id")
"""Register a Relay node type.
Parameters
----------
type_key : str or cls
The type key of the node.
"""
if not isinstance(type_key, str):
return tvm._ffi.register_object(
"relay." + type_key.__name__)(type_key)
return tvm._ffi.register_object(type_key)
def register_relay_attr_node(type_key=None):
"""Register a Relay attribute node.
Parameters
----------
type_key : str or cls
The type key of the node.
"""
if not isinstance(type_key, str):
return tvm._ffi.register_object(
"relay.attrs." + type_key.__name__)(type_key)
return tvm._ffi.register_object(type_key)
@register_relay_node
class Id(Object): class Id(Object):
"""Unique identifier(name) used in Var. """Unique identifier(name) used in Var.
Guaranteed to be stable across all passes. Guaranteed to be stable across all passes.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Contrib operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.contrib import *
...@@ -20,13 +20,13 @@ from __future__ import absolute_import ...@@ -20,13 +20,13 @@ from __future__ import absolute_import
from numbers import Number as _Number from numbers import Number as _Number
import numpy as _np import numpy as _np
import tvm._ffi
from tvm._ffi import base as _base from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd from tvm.runtime import NDArray, convert, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar, BaseFunc from tvm.ir import RelayExpr, GlobalVar, BaseFunc
from .base import RelayNode, register_relay_node from .base import RelayNode
from . import _make from . import _ffi_api
from . import _expr
from . import ty as _ty from . import ty as _ty
# alias relay expr as Expr. # alias relay expr as Expr.
...@@ -54,7 +54,7 @@ class ExprWithOp(RelayExpr): ...@@ -54,7 +54,7 @@ class ExprWithOp(RelayExpr):
result : tvm.relay.Expr result : tvm.relay.Expr
The result expression. The result expression.
""" """
return _make.cast(self, dtype) return _ffi_api.cast(self, dtype)
def __neg__(self): def __neg__(self):
return _op_make.negative(self) return _op_make.negative(self)
...@@ -160,7 +160,7 @@ class ExprWithOp(RelayExpr): ...@@ -160,7 +160,7 @@ class ExprWithOp(RelayExpr):
""" """
return Call(self, args) return Call(self, args)
@register_relay_node @tvm._ffi.register_object("relay.Constant")
class Constant(ExprWithOp): class Constant(ExprWithOp):
"""A constant expression in Relay. """A constant expression in Relay.
...@@ -170,10 +170,10 @@ class Constant(ExprWithOp): ...@@ -170,10 +170,10 @@ class Constant(ExprWithOp):
The data content of the constant expression. The data content of the constant expression.
""" """
def __init__(self, data): def __init__(self, data):
self.__init_handle_by_constructor__(_make.Constant, data) self.__init_handle_by_constructor__(_ffi_api.Constant, data)
@register_relay_node @tvm._ffi.register_object("relay.Tuple")
class Tuple(ExprWithOp): class Tuple(ExprWithOp):
"""Tuple expression that groups several fields together. """Tuple expression that groups several fields together.
...@@ -183,7 +183,7 @@ class Tuple(ExprWithOp): ...@@ -183,7 +183,7 @@ class Tuple(ExprWithOp):
The fields in the tuple. The fields in the tuple.
""" """
def __init__(self, fields): def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields) self.__init_handle_by_constructor__(_ffi_api.Tuple, fields)
def __getitem__(self, index): def __getitem__(self, index):
if index >= len(self): if index >= len(self):
...@@ -197,7 +197,7 @@ class Tuple(ExprWithOp): ...@@ -197,7 +197,7 @@ class Tuple(ExprWithOp):
raise TypeError("astype cannot be used on tuple") raise TypeError("astype cannot be used on tuple")
@register_relay_node @tvm._ffi.register_object("relay.Var")
class Var(ExprWithOp): class Var(ExprWithOp):
"""A local variable in Relay. """A local variable in Relay.
...@@ -216,7 +216,7 @@ class Var(ExprWithOp): ...@@ -216,7 +216,7 @@ class Var(ExprWithOp):
""" """
def __init__(self, name_hint, type_annotation=None): def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Var, name_hint, type_annotation) _ffi_api.Var, name_hint, type_annotation)
@property @property
def name_hint(self): def name_hint(self):
...@@ -225,7 +225,7 @@ class Var(ExprWithOp): ...@@ -225,7 +225,7 @@ class Var(ExprWithOp):
return name return name
@register_relay_node @tvm._ffi.register_object("relay.Function")
class Function(BaseFunc): class Function(BaseFunc):
"""A function declaration expression. """A function declaration expression.
...@@ -254,7 +254,7 @@ class Function(BaseFunc): ...@@ -254,7 +254,7 @@ class Function(BaseFunc):
type_params = convert([]) type_params = convert([])
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Function, params, body, ret_type, type_params, attrs) _ffi_api.Function, params, body, ret_type, type_params, attrs)
def __call__(self, *args): def __call__(self, *args):
"""Invoke the global function. """Invoke the global function.
...@@ -282,12 +282,12 @@ class Function(BaseFunc): ...@@ -282,12 +282,12 @@ class Function(BaseFunc):
func : Function func : Function
A new copy of the function A new copy of the function
""" """
return _expr.FunctionWithAttr( return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value)) self, attr_key, convert(attr_value))
@register_relay_node @tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp): class Call(ExprWithOp):
"""Function call node in Relay. """Function call node in Relay.
...@@ -313,10 +313,10 @@ class Call(ExprWithOp): ...@@ -313,10 +313,10 @@ class Call(ExprWithOp):
if not type_args: if not type_args:
type_args = [] type_args = []
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, type_args) _ffi_api.Call, op, args, attrs, type_args)
@register_relay_node @tvm._ffi.register_object("relay.Let")
class Let(ExprWithOp): class Let(ExprWithOp):
"""Let variable binding expression. """Let variable binding expression.
...@@ -333,10 +333,10 @@ class Let(ExprWithOp): ...@@ -333,10 +333,10 @@ class Let(ExprWithOp):
""" """
def __init__(self, variable, value, body): def __init__(self, variable, value, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Let, variable, value, body) _ffi_api.Let, variable, value, body)
@register_relay_node @tvm._ffi.register_object("relay.If")
class If(ExprWithOp): class If(ExprWithOp):
"""A conditional expression in Relay. """A conditional expression in Relay.
...@@ -353,10 +353,10 @@ class If(ExprWithOp): ...@@ -353,10 +353,10 @@ class If(ExprWithOp):
""" """
def __init__(self, cond, true_branch, false_branch): def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.If, cond, true_branch, false_branch) _ffi_api.If, cond, true_branch, false_branch)
@register_relay_node @tvm._ffi.register_object("relay.TupleGetItem")
class TupleGetItem(ExprWithOp): class TupleGetItem(ExprWithOp):
"""Get index-th item from a tuple. """Get index-th item from a tuple.
...@@ -370,10 +370,10 @@ class TupleGetItem(ExprWithOp): ...@@ -370,10 +370,10 @@ class TupleGetItem(ExprWithOp):
""" """
def __init__(self, tuple_value, index): def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index) _ffi_api.TupleGetItem, tuple_value, index)
@register_relay_node @tvm._ffi.register_object("relay.RefCreate")
class RefCreate(ExprWithOp): class RefCreate(ExprWithOp):
"""Create a new reference from initial value. """Create a new reference from initial value.
Parameters Parameters
...@@ -382,10 +382,10 @@ class RefCreate(ExprWithOp): ...@@ -382,10 +382,10 @@ class RefCreate(ExprWithOp):
The initial value. The initial value.
""" """
def __init__(self, value): def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefCreate, value) self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
@register_relay_node @tvm._ffi.register_object("relay.RefRead")
class RefRead(ExprWithOp): class RefRead(ExprWithOp):
"""Get the value inside the reference. """Get the value inside the reference.
Parameters Parameters
...@@ -394,10 +394,10 @@ class RefRead(ExprWithOp): ...@@ -394,10 +394,10 @@ class RefRead(ExprWithOp):
The reference. The reference.
""" """
def __init__(self, ref): def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref) self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
@register_relay_node @tvm._ffi.register_object("relay.RefWrite")
class RefWrite(ExprWithOp): class RefWrite(ExprWithOp):
""" """
Update the value inside the reference. Update the value inside the reference.
...@@ -410,7 +410,7 @@ class RefWrite(ExprWithOp): ...@@ -410,7 +410,7 @@ class RefWrite(ExprWithOp):
The new value. The new value.
""" """
def __init__(self, ref, value): def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value) self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
class TempExpr(ExprWithOp): class TempExpr(ExprWithOp):
...@@ -427,7 +427,7 @@ class TempExpr(ExprWithOp): ...@@ -427,7 +427,7 @@ class TempExpr(ExprWithOp):
------- -------
The corresponding normal expression. The corresponding normal expression.
""" """
return _expr.TempExprRealize(self) return _ffi_api.TempExprRealize(self)
class TupleWrapper(object): class TupleWrapper(object):
...@@ -587,4 +587,4 @@ def bind(expr, binds): ...@@ -587,4 +587,4 @@ def bind(expr, binds):
result : tvm.relay.Expr result : tvm.relay.Expr
The expression or function after binding. The expression or function after binding.
""" """
return _expr.Bind(expr, binds) return _ffi_api.Bind(expr, binds)
...@@ -26,8 +26,8 @@ from .. import analysis ...@@ -26,8 +26,8 @@ from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from .. import qnn as _qnn from .. import qnn as _qnn
from ..util import get_scalar_from_constant
from ... import nd as _nd from ... import nd as _nd
from .util import get_scalar_from_constant
from .common import ExprTable from .common import ExprTable
from .common import infer_shape as _infer_shape from .common import infer_shape as _infer_shape
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" Utility functions that are used across many directories. """ """ Utility functions that are used across many directories. """
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
from . import expr as _expr from .. import expr as _expr
def get_scalar_from_constant(expr): def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """ """ Returns scalar value from Relay constant scalar. """
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Image network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.image import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Neural network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.nn import *
...@@ -41,7 +41,6 @@ from . import _tensor_grad ...@@ -41,7 +41,6 @@ from . import _tensor_grad
from . import _transform from . import _transform
from . import _reduce from . import _reduce
from . import _algorithm from . import _algorithm
from ..base import register_relay_node
def _register_op_make(): def _register_op_make():
......
...@@ -19,13 +19,12 @@ ...@@ -19,13 +19,12 @@
import tvm._ffi import tvm._ffi
from tvm.driver import lower, build from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr from ..expr import RelayExpr
from ...target import get_native_generic_func, GenericFunc from ...target import get_native_generic_func, GenericFunc
from ...runtime import Object from ...runtime import Object
from . import _make from . import _make
@register_relay_node @tvm._ffi.register_object("relay.Op")
class Op(RelayExpr): class Op(RelayExpr):
"""A Relay operator definition.""" """A Relay operator definition."""
......
...@@ -38,7 +38,7 @@ def cast(data, dtype): ...@@ -38,7 +38,7 @@ def cast(data, dtype):
result : relay.Expr result : relay.Expr
The casted result. The casted result.
""" """
from .. import _make as _relay_make from .. import _ffi_api as _relay_make
return _relay_make.cast(data, dtype) return _relay_make.cast(data, dtype)
...@@ -55,7 +55,7 @@ def cast_like(data, dtype_like): ...@@ -55,7 +55,7 @@ def cast_like(data, dtype_like):
result : relay.Expr result : relay.Expr
The casted result. The casted result.
""" """
from .. import _make as _relay_make from .. import _ffi_api as _relay_make
return _relay_make.cast_like(data, dtype_like) return _relay_make.cast_like(data, dtype_like)
......
...@@ -21,7 +21,7 @@ from __future__ import absolute_import ...@@ -21,7 +21,7 @@ from __future__ import absolute_import
import tvm import tvm
from tvm import relay from tvm import relay
from .. import op as reg from .. import op as reg
from ...util import get_scalar_from_constant from ...frontend.util import get_scalar_from_constant
################################################# #################################################
# Register the functions for different operators. # Register the functions for different operators.
......
...@@ -24,7 +24,6 @@ from .. import expr as _expr ...@@ -24,7 +24,6 @@ from .. import expr as _expr
from .. import analysis as _analysis from .. import analysis as _analysis
from .. import op as _op from .. import op as _op
from ..op import op as _reg from ..op import op as _reg
from ..base import register_relay_node
from . import _quantize from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op from .quantize import _forward_op
...@@ -58,7 +57,7 @@ _reg.register_pattern("relay.op.annotation.simulated_quantize", ...@@ -58,7 +57,7 @@ _reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.register_injective_schedule("annotation.cast_hint") _reg.register_injective_schedule("annotation.cast_hint")
@register_relay_node @tvm._ffi.register_object("relay.QAnnotateExpr")
class QAnnotateExpr(_expr.TempExpr): class QAnnotateExpr(_expr.TempExpr):
"""A special kind of Expr for Annotating. """A special kind of Expr for Annotating.
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
import tvm import tvm
from .. import expr as _expr from .. import expr as _expr
from .. import analysis as _analysis from .. import analysis as _analysis
from ..base import register_relay_node
from ..op import op as _reg from ..op import op as _reg
from . import _quantize from . import _quantize
from .quantize import _forward_op from .quantize import _forward_op
...@@ -30,7 +29,7 @@ def register_partition_function(op_name, frewrite=None, level=10): ...@@ -30,7 +29,7 @@ def register_partition_function(op_name, frewrite=None, level=10):
return _register(frewrite) if frewrite is not None else _register return _register(frewrite) if frewrite is not None else _register
@register_relay_node @tvm._ffi.register_object("relay.QPartitionExpr")
class QPartitionExpr(_expr.TempExpr): class QPartitionExpr(_expr.TempExpr):
def __init__(self, expr): def __init__(self, expr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
#pylint: disable=unused-argument, not-context-manager #pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit.""" """Automatic quantization toolkit."""
import tvm.ir import tvm.ir
from tvm.runtime import Object
from . import _quantize from . import _quantize
from ._calibrate import calibrate from ._calibrate import calibrate
from .. import expr as _expr from .. import expr as _expr
from .. import transform as _transform from .. import transform as _transform
from ..base import Object, register_relay_node
class QAnnotateKind(object): class QAnnotateKind(object):
...@@ -52,7 +52,7 @@ def _forward_op(ref_call, args): ...@@ -52,7 +52,7 @@ def _forward_op(ref_call, args):
ref_call.op, args, ref_call.attrs, ref_call.type_args) ref_call.op, args, ref_call.attrs, ref_call.type_args)
@register_relay_node("relay.quantize.QConfig") @tvm._ffi.register_object("relay.quantize.QConfig")
class QConfig(Object): class QConfig(Object):
"""Configure the quantization behavior by setting config variables. """Configure the quantization behavior by setting config variables.
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable # pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The interface of expr function exposed from C++.""" """The Relay IR namespace containing transformations."""
import tvm._ffi # transformation passes
from .transform import *
tvm._ffi._init_api("relay._expr", __name__) from . import memory_alloc
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""FFI exposing the Relay type inference and checking.""" """FFI APIs for Relay transformation passes."""
import tvm._ffi import tvm._ffi
tvm._ffi._init_api("relay._transform", __name__) tvm._ffi._init_api("relay._transform", __name__)
...@@ -19,12 +19,13 @@ ...@@ -19,12 +19,13 @@
A pass for manifesting explicit memory allocations. A pass for manifesting explicit memory allocations.
""" """
import numpy as np import numpy as np
from .expr_functor import ExprMutator from ..expr_functor import ExprMutator
from .scope_builder import ScopeBuilder from ..scope_builder import ScopeBuilder
from . import transform from . import transform
from . import op, ty, expr from .. import op
from .. import DataType, register_func from ... import DataType, register_func
from .backend import compile_engine from .. import ty, expr
from ..backend import compile_engine
def is_primitive(call): def is_primitive(call):
......
...@@ -20,10 +20,10 @@ from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar ...@@ -20,10 +20,10 @@ from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar
from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType
from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType
from .base import RelayNode, register_relay_node from .base import RelayNode
from . import _make from . import _ffi_api
Any = _make.Any Any = _ffi_api.Any
def type_has_any(tensor_type): def type_has_any(tensor_type):
"""Check whether type has any as a shape. """Check whether type has any as a shape.
...@@ -36,7 +36,7 @@ def type_has_any(tensor_type): ...@@ -36,7 +36,7 @@ def type_has_any(tensor_type):
has_any : bool has_any : bool
The check result. The check result.
""" """
return _make.IsDynamic(tensor_type) return _ffi_api.IsDynamic(tensor_type)
def ShapeVar(name): def ShapeVar(name):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Vision network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.vision import *
...@@ -581,7 +581,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { ...@@ -581,7 +581,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs); return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
} }
TVM_REGISTER_GLOBAL("relay._analysis._alpha_equal") TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b); return AlphaEqualHandler(false, false).Equal(a, b);
}); });
...@@ -591,18 +591,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal") ...@@ -591,18 +591,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
return AlphaEqual(a, b); return AlphaEqual(a, b);
}); });
TVM_REGISTER_GLOBAL("relay._analysis._assert_alpha_equal") TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
}); });
TVM_REGISTER_GLOBAL("relay._analysis._graph_equal") TVM_REGISTER_GLOBAL("relay.analysis._graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b); return AlphaEqualHandler(true, false).Equal(a, b);
}); });
TVM_REGISTER_GLOBAL("relay._analysis._assert_graph_equal") TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
......
...@@ -299,24 +299,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -299,24 +299,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "CallGraph: \n" << GetRef<CallGraph>(node); p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
}); });
TVM_REGISTER_GLOBAL("relay._analysis.CallGraph") TVM_REGISTER_GLOBAL("relay.analysis.CallGraph")
.set_body_typed([](IRModule module) { .set_body_typed([](IRModule module) {
return CallGraph(module); return CallGraph(module);
}); });
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph") TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph")
.set_body_typed([](CallGraph call_graph) { .set_body_typed([](CallGraph call_graph) {
std::stringstream ss; std::stringstream ss;
ss << call_graph; ss << call_graph;
return ss.str(); return ss.str();
}); });
TVM_REGISTER_GLOBAL("relay._analysis.GetModule") TVM_REGISTER_GLOBAL("relay.analysis.GetModule")
.set_body_typed([](CallGraph call_graph) { .set_body_typed([](CallGraph call_graph) {
return call_graph->module; return call_graph->module;
}); });
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar")
.set_body_typed([](CallGraph call_graph, GlobalVar var) { .set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var]; const auto* entry_node = call_graph[var];
std::stringstream ss; std::stringstream ss;
...@@ -324,19 +324,19 @@ TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") ...@@ -324,19 +324,19 @@ TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar")
return ss.str(); return ss.str();
}); });
TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar") TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar")
.set_body_typed([](CallGraph call_graph, GlobalVar var) { .set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var]; const auto* entry_node = call_graph[var];
return static_cast<int>(entry_node->GetRefCount()); return static_cast<int>(entry_node->GetRefCount());
}); });
TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount") TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount")
.set_body_typed([](CallGraph call_graph, GlobalVar var) { .set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var]; const auto* entry_node = call_graph[var];
return static_cast<int>(entry_node->size()); return static_cast<int>(entry_node->size());
}); });
TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive") TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive")
.set_body_typed([](CallGraph call_graph, GlobalVar var) { .set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var]; const auto* entry_node = call_graph[var];
return entry_node->IsRecursive(); return entry_node->IsRecursive();
......
...@@ -74,7 +74,7 @@ Pass ExtractFusedFunctions() { ...@@ -74,7 +74,7 @@ Pass ExtractFusedFunctions() {
"ExtractFusedFunctions"); "ExtractFusedFunctions");
} }
TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions); TVM_REGISTER_GLOBAL("relay.analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
} // namespace transform } // namespace transform
......
...@@ -104,7 +104,7 @@ Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) { ...@@ -104,7 +104,7 @@ Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
return static_cast<Array<Integer>>(fs); return static_cast<Array<Integer>>(fs);
} }
TVM_REGISTER_GLOBAL("relay._analysis.detect_feature") TVM_REGISTER_GLOBAL("relay.analysis.detect_feature")
.set_body_typed(PyDetectFeature); .set_body_typed(PyDetectFeature);
} // namespace relay } // namespace relay
......
...@@ -186,7 +186,7 @@ Kind KindCheck(const Type& t, const IRModule& mod) { ...@@ -186,7 +186,7 @@ Kind KindCheck(const Type& t, const IRModule& mod) {
return kc.Check(t); return kc.Check(t);
} }
TVM_REGISTER_GLOBAL("relay._analysis.check_kind") TVM_REGISTER_GLOBAL("relay.analysis.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = KindCheck(args[0], IRModule({}, {})); *ret = KindCheck(args[0], IRModule({}, {}));
......
...@@ -206,7 +206,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { ...@@ -206,7 +206,7 @@ int64_t GetTotalMacNumber(const Expr& expr) {
return MacCounter::GetTotalMacNumber(expr); return MacCounter::GetTotalMacNumber(expr);
} }
TVM_REGISTER_GLOBAL("relay._analysis.GetTotalMacNumber") TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber")
.set_body_typed(GetTotalMacNumber); .set_body_typed(GetTotalMacNumber);
} // namespace mac_count } // namespace mac_count
......
...@@ -310,7 +310,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) { ...@@ -310,7 +310,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
} }
// expose for testing only // expose for testing only
TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
.set_body_typed( .set_body_typed(
[](const Match& match, const IRModule& mod_ref) { [](const Match& match, const IRModule& mod_ref) {
IRModule call_mod = mod_ref; IRModule call_mod = mod_ref;
......
...@@ -659,7 +659,7 @@ bool TypeSolver::Solve() { ...@@ -659,7 +659,7 @@ bool TypeSolver::Solve() {
} }
// Expose type solver only for debugging purposes. // Expose type solver only for debugging purposes.
TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver") TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TypedPackedFunc; using runtime::TypedPackedFunc;
......
...@@ -274,10 +274,10 @@ tvm::Array<Var> AllVars(const Expr& expr) { ...@@ -274,10 +274,10 @@ tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr); return VarVisitor().All(expr);
} }
TVM_REGISTER_GLOBAL("relay._analysis.free_vars") TVM_REGISTER_GLOBAL("relay.analysis.free_vars")
.set_body_typed(FreeVars); .set_body_typed(FreeVars);
TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") TVM_REGISTER_GLOBAL("relay.analysis.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
if (x.as<ExprNode>()) { if (x.as<ExprNode>()) {
...@@ -287,10 +287,10 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_vars") ...@@ -287,10 +287,10 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_vars")
} }
}); });
TVM_REGISTER_GLOBAL("relay._analysis.all_vars") TVM_REGISTER_GLOBAL("relay.analysis.all_vars")
.set_body_typed(AllVars); .set_body_typed(AllVars);
TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
IRModule mod = args[1]; IRModule mod = args[1];
...@@ -301,7 +301,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars") ...@@ -301,7 +301,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
} }
}); });
TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
IRModule mod = args[1]; IRModule mod = args[1];
...@@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars") ...@@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
} }
}); });
TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars") TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0]; ObjectRef x = args[0];
IRModule mod = args[1]; IRModule mod = args[1];
......
...@@ -125,7 +125,7 @@ bool WellFormed(const Expr& e) { ...@@ -125,7 +125,7 @@ bool WellFormed(const Expr& e) {
return WellFormedChecker().CheckWellFormed(e); return WellFormedChecker().CheckWellFormed(e);
} }
TVM_REGISTER_GLOBAL("relay._analysis.well_formed") TVM_REGISTER_GLOBAL("relay.analysis.well_formed")
.set_body_typed(WellFormed); .set_body_typed(WellFormed);
} // namespace relay } // namespace relay
......
...@@ -86,7 +86,7 @@ bool IsDynamic(const Type& ty) { ...@@ -86,7 +86,7 @@ bool IsDynamic(const Type& ty) {
} }
// TODO(@jroesch): MOVE ME // TODO(@jroesch): MOVE ME
TVM_REGISTER_GLOBAL("relay._make.IsDynamic") TVM_REGISTER_GLOBAL("relay.ir.IsDynamic")
.set_body_typed(IsDynamic); .set_body_typed(IsDynamic);
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) { Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
......
...@@ -34,7 +34,7 @@ PatternWildcard PatternWildcardNode::make() { ...@@ -34,7 +34,7 @@ PatternWildcard PatternWildcardNode::make() {
TVM_REGISTER_NODE_TYPE(PatternWildcardNode); TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_GLOBAL("relay._make.PatternWildcard") TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard")
.set_body_typed(PatternWildcardNode::make); .set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -50,7 +50,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) { ...@@ -50,7 +50,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) {
TVM_REGISTER_NODE_TYPE(PatternVarNode); TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_GLOBAL("relay._make.PatternVar") TVM_REGISTER_GLOBAL("relay.ir.PatternVar")
.set_body_typed(PatternVarNode::make); .set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -69,7 +69,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor, ...@@ -69,7 +69,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor,
TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.PatternConstructor") TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor")
.set_body_typed(PatternConstructorNode::make); .set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -87,7 +87,7 @@ PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) { ...@@ -87,7 +87,7 @@ PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
TVM_REGISTER_NODE_TYPE(PatternTupleNode); TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_GLOBAL("relay._make.PatternTuple") TVM_REGISTER_GLOBAL("relay.ir.PatternTuple")
.set_body_typed(PatternTupleNode::make); .set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -105,7 +105,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) { ...@@ -105,7 +105,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) {
TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_GLOBAL("relay._make.Clause") TVM_REGISTER_GLOBAL("relay.ir.Clause")
.set_body_typed(ClauseNode::make); .set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -125,7 +125,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) { ...@@ -125,7 +125,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay._make.Match") TVM_REGISTER_GLOBAL("relay.ir.Match")
.set_body_typed(MatchNode::make); .set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
...@@ -38,7 +38,7 @@ Constant ConstantNode::make(runtime::NDArray data) { ...@@ -38,7 +38,7 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_GLOBAL("relay._make.Constant") TVM_REGISTER_GLOBAL("relay.ir.Constant")
.set_body_typed(ConstantNode::make); .set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -71,7 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) { ...@@ -71,7 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay._make.Tuple") TVM_REGISTER_GLOBAL("relay.ir.Tuple")
.set_body_typed(TupleNode::make); .set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -96,7 +96,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { ...@@ -96,7 +96,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_GLOBAL("relay._make.Var") TVM_REGISTER_GLOBAL("relay.ir.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make)); .set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -123,7 +123,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs, ...@@ -123,7 +123,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay._make.Call") TVM_REGISTER_GLOBAL("relay.ir.Call")
.set_body_typed(CallNode::make); .set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -143,7 +143,7 @@ Let LetNode::make(Var var, Expr value, Expr body) { ...@@ -143,7 +143,7 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_GLOBAL("relay._make.Let") TVM_REGISTER_GLOBAL("relay.ir.Let")
.set_body_typed(LetNode::make); .set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -163,7 +163,7 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { ...@@ -163,7 +163,7 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay._make.If") TVM_REGISTER_GLOBAL("relay.ir.If")
.set_body_typed(IfNode::make); .set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -182,7 +182,7 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { ...@@ -182,7 +182,7 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_GLOBAL("relay._make.TupleGetItem") TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem")
.set_body_typed(TupleGetItemNode::make); .set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -199,7 +199,7 @@ RefCreate RefCreateNode::make(Expr value) { ...@@ -199,7 +199,7 @@ RefCreate RefCreateNode::make(Expr value) {
TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_GLOBAL("relay._make.RefCreate") TVM_REGISTER_GLOBAL("relay.ir.RefCreate")
.set_body_typed(RefCreateNode::make); .set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -216,7 +216,7 @@ RefRead RefReadNode::make(Expr ref) { ...@@ -216,7 +216,7 @@ RefRead RefReadNode::make(Expr ref) {
TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_GLOBAL("relay._make.RefRead") TVM_REGISTER_GLOBAL("relay.ir.RefRead")
.set_body_typed(RefReadNode::make); .set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -234,7 +234,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { ...@@ -234,7 +234,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_GLOBAL("relay._make.RefWrite") TVM_REGISTER_GLOBAL("relay.ir.RefWrite")
.set_body_typed(RefWriteNode::make); .set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
...@@ -243,12 +243,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -243,12 +243,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
}); });
TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize") TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize")
.set_body_typed([](TempExpr temp) { .set_body_typed([](TempExpr temp) {
return temp->Realize(); return temp->Realize();
}); });
TVM_REGISTER_GLOBAL("relay._make.Any") TVM_REGISTER_GLOBAL("relay.ir.Any")
.set_body_typed([]() { return Any::make(); }); .set_body_typed([]() { return Any::make(); });
} // namespace relay } // namespace relay
......
...@@ -347,7 +347,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) { ...@@ -347,7 +347,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
ExprApplyVisit(fvisit).VisitExpr(e); ExprApplyVisit(fvisit).VisitExpr(e);
} }
TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit") TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit")
.set_body_typed([](Expr expr, PackedFunc f) { .set_body_typed([](Expr expr, PackedFunc f) {
PostOrderVisit(expr, [f](const Expr& n) { PostOrderVisit(expr, [f](const Expr& n) {
f(n); f(n);
...@@ -443,7 +443,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { ...@@ -443,7 +443,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
} }
} }
TVM_REGISTER_GLOBAL("relay._expr.Bind") TVM_REGISTER_GLOBAL("relay.ir.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef input = args[0]; ObjectRef input = args[0];
if (input->IsInstance<ExprNode>()) { if (input->IsInstance<ExprNode>()) {
......
...@@ -62,7 +62,7 @@ bool FunctionNode::UseDefaultCompiler() const { ...@@ -62,7 +62,7 @@ bool FunctionNode::UseDefaultCompiler() const {
TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function") TVM_REGISTER_GLOBAL("relay.ir.Function")
.set_body_typed([](tvm::Array<Var> params, .set_body_typed([](tvm::Array<Var> params,
Expr body, Expr body,
Type ret_type, Type ret_type,
...@@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->attrs << ")"; << node->attrs << ")";
}); });
TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr") TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
.set_body_typed( .set_body_typed(
[](Function func, std::string name, ObjectRef ref) { [](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref); return WithAttr(std::move(func), name, ref);
......
...@@ -423,12 +423,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { ...@@ -423,12 +423,12 @@ size_t StructuralHash::operator()(const Expr& expr) const {
return RelayHashHandler().ExprHash(expr); return RelayHashHandler().ExprHash(expr);
} }
TVM_REGISTER_GLOBAL("relay._analysis._expr_hash") TVM_REGISTER_GLOBAL("relay.analysis._expr_hash")
.set_body_typed([](ObjectRef ref) { .set_body_typed([](ObjectRef ref) {
return static_cast<int64_t>(RelayHashHandler().Hash(ref)); return static_cast<int64_t>(RelayHashHandler().Hash(ref));
}); });
TVM_REGISTER_GLOBAL("relay._analysis._type_hash") TVM_REGISTER_GLOBAL("relay.analysis._type_hash")
.set_body_typed([](Type type) { .set_body_typed([](Type type) {
return static_cast<int64_t>(RelayHashHandler().TypeHash(type)); return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
}); });
......
...@@ -82,7 +82,7 @@ Expr MakeCast(Expr data, ...@@ -82,7 +82,7 @@ Expr MakeCast(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
TVM_REGISTER_GLOBAL("relay._make.cast") TVM_REGISTER_GLOBAL("relay.ir.cast")
.set_body_typed(MakeCast); .set_body_typed(MakeCast);
RELAY_REGISTER_OP("cast") RELAY_REGISTER_OP("cast")
...@@ -138,7 +138,7 @@ Expr MakeCastLike(Expr data, ...@@ -138,7 +138,7 @@ Expr MakeCastLike(Expr data,
} }
TVM_REGISTER_GLOBAL("relay._make.cast_like") TVM_REGISTER_GLOBAL("relay.ir.cast_like")
.set_body_typed(MakeCastLike); .set_body_typed(MakeCastLike);
RELAY_REGISTER_OP("cast_like") RELAY_REGISTER_OP("cast_like")
......
...@@ -560,10 +560,10 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) { ...@@ -560,10 +560,10 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
return AnnotatationVisitor::GetAnnotations(expr); return AnnotatationVisitor::GetAnnotations(expr);
} }
TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo") TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo")
.set_body_typed(CollectDeviceInfo); .set_body_typed(CollectDeviceInfo);
TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps") TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps")
.set_body_typed(CollectDeviceAnnotationOps); .set_body_typed(CollectDeviceAnnotationOps);
namespace transform { namespace transform {
......
...@@ -73,7 +73,7 @@ bool ConstantCheck(const Expr& e) { ...@@ -73,7 +73,7 @@ bool ConstantCheck(const Expr& e) {
return ConstantChecker().Check(e); return ConstantChecker().Check(e);
} }
TVM_REGISTER_GLOBAL("relay._analysis.check_constant") TVM_REGISTER_GLOBAL("relay.analysis.check_constant")
.set_body_typed(ConstantCheck); .set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder. // TODO(tvm-team) consider combine dead-code with constant folder.
......
...@@ -18,9 +18,8 @@ ...@@ -18,9 +18,8 @@
import tvm import tvm
from tvm import te from tvm import te
from tvm import relay from tvm import relay
from tvm.relay.analysis import detect_feature from tvm.relay.analysis import detect_feature, Feature
from tvm.relay.transform import gradient from tvm.relay.transform import gradient
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.testing import run_infer_type from tvm.relay.testing import run_infer_type
......
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import te
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add from tvm.relay.op import add
from tvm.relay.testing.config import ctx_list from tvm.relay.testing.config import ctx_list
......
...@@ -25,7 +25,7 @@ def test_callgraph_construct(): ...@@ -25,7 +25,7 @@ def test_callgraph_construct():
x = relay.var("x", shape=(2, 3)) x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3)) y = relay.var("y", shape=(2, 3))
mod["g1"] = relay.Function([x, y], x + y) mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
assert "g1" in str(call_graph) assert "g1" in str(call_graph)
assert relay.alpha_equal(mod, call_graph.module) assert relay.alpha_equal(mod, call_graph.module)
...@@ -38,7 +38,7 @@ def test_print_element(): ...@@ -38,7 +38,7 @@ def test_print_element():
x1 = relay.var("x1", shape=(2, 3)) x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3)) y1 = relay.var("y1", shape=(2, 3))
mod["g1"] = relay.Function([x1, y1], x1 - y1) mod["g1"] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
assert "#refs = 0" in str(call_graph.print_var("g0")) assert "#refs = 0" in str(call_graph.print_var("g0"))
assert "#refs = 0" in str(call_graph.print_var("g1")) assert "#refs = 0" in str(call_graph.print_var("g1"))
...@@ -54,13 +54,13 @@ def test_global_call_count(): ...@@ -54,13 +54,13 @@ def test_global_call_count():
y1 = relay.var("y1", shape=(2, 3)) y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1") g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1)) mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3)) p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func mod["main"] = func
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
assert call_graph.global_call_count(g0) == 0 assert call_graph.global_call_count(g0) == 0
assert call_graph.global_call_count(g1) == 1 assert call_graph.global_call_count(g1) == 1
...@@ -77,13 +77,13 @@ def test_ref_count(): ...@@ -77,13 +77,13 @@ def test_ref_count():
y1 = relay.var("y1", shape=(2, 3)) y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1") g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], x1 - y1) mod[g1] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3)) p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func mod["main"] = func
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
assert call_graph.ref_count(g0) == 1 assert call_graph.ref_count(g0) == 1
assert call_graph.ref_count(g1) == 1 assert call_graph.ref_count(g1) == 1
...@@ -100,13 +100,13 @@ def test_nested_ref(): ...@@ -100,13 +100,13 @@ def test_nested_ref():
y1 = relay.var("y1", shape=(2, 3)) y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1") g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1)) mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3)) p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3)) p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func mod["main"] = func
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
assert call_graph.ref_count(g0) == 2 assert call_graph.ref_count(g0) == 2
assert call_graph.ref_count(g1) == 1 assert call_graph.ref_count(g1) == 1
...@@ -138,7 +138,7 @@ def test_recursive_func(): ...@@ -138,7 +138,7 @@ def test_recursive_func():
mod[sum_up] = func mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32') iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg)) mod["main"] = relay.Function([iarg], sum_up(iarg))
call_graph = relay.CallGraph(mod) call_graph = relay.analysis.CallGraph(mod)
assert call_graph.is_recursive(sum_up) assert call_graph.is_recursive(sum_up)
assert call_graph.ref_count(sum_up) == 2 assert call_graph.ref_count(sum_up) == 2
......
...@@ -19,7 +19,6 @@ import json ...@@ -19,7 +19,6 @@ import json
import numpy as np import numpy as np
import tvm import tvm
from tvm import te
from tvm import relay from tvm import relay
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
......
...@@ -18,7 +18,7 @@ import tvm ...@@ -18,7 +18,7 @@ import tvm
from tvm import te from tvm import te
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay import memory_alloc from tvm.relay.transform import memory_alloc
def check_vm_alloc(func, check_fn): def check_vm_alloc(func, check_fn):
mod = tvm.IRModule() mod = tvm.IRModule()
......
...@@ -25,7 +25,7 @@ from tvm import relay ...@@ -25,7 +25,7 @@ from tvm import relay
from tvm import runtime from tvm import runtime
from tvm.relay import transform from tvm.relay import transform
from tvm.contrib import util from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator from tvm.relay.expr_functor import ExprMutator
# Leverage the pass manager to write a simple white list based annotator # Leverage the pass manager to write a simple white list based annotator
......
...@@ -22,7 +22,7 @@ from tvm.relay.analysis import alpha_equal, detect_feature ...@@ -22,7 +22,7 @@ from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay import op, create_executor, transform from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count from tvm.relay.testing import add_nat_definitions, count
from tvm.relay.feature import Feature from tvm.relay.analysis import Feature
def run_opt_pass(expr, passes): def run_opt_pass(expr, passes):
......
...@@ -16,15 +16,14 @@ ...@@ -16,15 +16,14 @@
# under the License. # under the License.
import numpy as np import numpy as np
import tvm import tvm
from tvm import te
from tvm import relay from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature from tvm.relay.analysis import Feature
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
from tvm.relay import create_executor from tvm.relay import create_executor
from tvm.relay import Function, transform from tvm.relay import transform
def test_id(): def test_id():
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
# under the License. # under the License.
import numpy as np import numpy as np
import tvm import tvm
from tvm import te
from tvm import relay from tvm import relay
from tvm.relay import op, create_executor, transform, Feature from tvm.relay import op, create_executor, transform
from tvm.relay.analysis import Feature
from tvm.relay.analysis import detect_feature from tvm.relay.analysis import detect_feature
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
from tvm import te
from tvm import relay from tvm import relay
import pytest import pytest
...@@ -27,7 +26,7 @@ def make_rel(name, args, num_inputs=None, attrs=None): ...@@ -27,7 +26,7 @@ def make_rel(name, args, num_inputs=None, attrs=None):
return relay.ty.TypeRelation(func, args, num_inputs, attrs) return relay.ty.TypeRelation(func, args, num_inputs, attrs)
def make_solver(): def make_solver():
solver = relay._analysis._test_type_solver() solver = relay.analysis._ffi_api._test_type_solver()
solver.Solve = solver("Solve") solver.Solve = solver("Solve")
solver.Unify = solver("Unify") solver.Unify = solver("Unify")
solver.Resolve = solver("Resolve") solver.Resolve = solver("Resolve")
......
...@@ -18,7 +18,6 @@ import numpy as np ...@@ -18,7 +18,6 @@ import numpy as np
import pytest import pytest
import tvm import tvm
from tvm import te
from tvm import runtime from tvm import runtime
from tvm import relay from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment