Unverified Commit 4aff8dc1 by Zhi Committed by GitHub

[Refactor][Relay] Refactor Relay Python to use new FFI (#5077)

* refactor relay python

* revert relay/ir/*.py to relay

* Address comments

* remove direct access to analysis and transform namespace
parent 646cfc63
......@@ -19,10 +19,6 @@ tvm.relay.base
--------------
.. automodule:: tvm.relay.base
.. autofunction:: tvm.relay.base.register_relay_node
.. autofunction:: tvm.relay.base.register_relay_attr_node
.. autoclass:: tvm.relay.base.RelayNode
:members:
......
......@@ -19,35 +19,37 @@
import os
from sys import setrecursionlimit
from . import call_graph
from . import base
from . import ty
from . import expr
from . import type_functor
from . import expr_functor
from . import adt
from . import analysis
from . import prelude
from . import loops
from . import scope_builder
from . import parser
from . import transform
from . import analysis
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import prelude
from . import parser
from . import debug
from . import param_dict
from . import feature
from .backend import vm
# Root operators
from .op import Op
from .op import nn
from .op import image
from .op import annotation
from .op import vision
from .op import contrib
from .op.reduce import *
from .op.tensor import *
from .op.transform import *
from .op.algorithm import *
from . import nn
from . import annotation
from . import vision
from . import contrib
from . import image
from . import frontend
from . import backend
from . import quantize
......@@ -55,15 +57,12 @@ from . import quantize
# Dialects
from . import qnn
from .scope_builder import ScopeBuilder
# Load Memory pass
from . import memory_alloc
# Required to traverse large programs
setrecursionlimit(10000)
# Span
Span = base.Span
SourceName = base.SourceName
# Type
Type = ty.Type
......@@ -98,6 +97,7 @@ RefRead = expr.RefRead
RefWrite = expr.RefWrite
# ADT
Pattern = adt.Pattern
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
......@@ -111,9 +111,6 @@ Match = adt.Match
var = expr.var
const = expr.const
bind = expr.bind
module_pass = transform.module_pass
function_pass = transform.function_pass
alpha_equal = analysis.alpha_equal
# TypeFunctor
TypeFunctor = type_functor.TypeFunctor
......@@ -125,6 +122,15 @@ ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor
ExprMutator = expr_functor.ExprMutator
# Prelude
Prelude = prelude.Prelude
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder
module_pass = transform.module_pass
function_pass = transform.function_pass
# Parser
fromtext = parser.fromtext
......@@ -139,9 +145,3 @@ Pass = transform.Pass
ModulePass = transform.ModulePass
FunctionPass = transform.FunctionPass
Sequential = transform.Sequential
# Feature
Feature = feature.Feature
# CallGraph
CallGraph = call_graph.CallGraph
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the passes for Relay program analysis."""
"""FFI APIs for Relay program IR."""
import tvm._ffi
tvm._ffi._init_api("relay._analysis", __name__)
tvm._ffi._init_api("relay.ir", __name__)
......@@ -17,9 +17,11 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Algebraic data types in Relay."""
from tvm.ir import Constructor, TypeData
from tvm.runtime import Object
import tvm._ffi
from .base import RelayNode, register_relay_node, Object
from . import _make
from .base import RelayNode
from . import _ffi_api
from .ty import Type
from .expr import ExprWithOp, RelayExpr, Call
......@@ -28,7 +30,7 @@ class Pattern(RelayNode):
"""Base type for pattern matching constructs."""
@register_relay_node
@tvm._ffi.register_object("relay.PatternWildcard")
class PatternWildcard(Pattern):
"""Wildcard pattern in Relay: Matches any ADT and binds nothing."""
......@@ -44,10 +46,10 @@ class PatternWildcard(Pattern):
wildcard: PatternWildcard
a wildcard pattern.
"""
self.__init_handle_by_constructor__(_make.PatternWildcard)
self.__init_handle_by_constructor__(_ffi_api.PatternWildcard)
@register_relay_node
@tvm._ffi.register_object("relay.PatternVar")
class PatternVar(Pattern):
"""Variable pattern in Relay: Matches anything and binds it to the variable."""
......@@ -63,10 +65,10 @@ class PatternVar(Pattern):
pv: PatternVar
A variable pattern.
"""
self.__init_handle_by_constructor__(_make.PatternVar, var)
self.__init_handle_by_constructor__(_ffi_api.PatternVar, var)
@register_relay_node
@tvm._ffi.register_object("relay.PatternConstructor")
class PatternConstructor(Pattern):
"""Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively."""
......@@ -88,10 +90,10 @@ class PatternConstructor(Pattern):
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns)
self.__init_handle_by_constructor__(_ffi_api.PatternConstructor, constructor, patterns)
@register_relay_node
@tvm._ffi.register_object("relay.PatternTuple")
class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively."""
......@@ -111,10 +113,10 @@ class PatternTuple(Pattern):
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternTuple, patterns)
self.__init_handle_by_constructor__(_ffi_api.PatternTuple, patterns)
@register_relay_node
@tvm._ffi.register_object("relay.Clause")
class Clause(Object):
"""Clause for pattern matching in Relay."""
......@@ -133,10 +135,10 @@ class Clause(Object):
clause: Clause
The Clause.
"""
self.__init_handle_by_constructor__(_make.Clause, lhs, rhs)
self.__init_handle_by_constructor__(_ffi_api.Clause, lhs, rhs)
@register_relay_node
@tvm._ffi.register_object("relay.Match")
class Match(ExprWithOp):
"""Pattern matching expression in Relay."""
......@@ -160,4 +162,4 @@ class Match(ExprWithOp):
match: tvm.relay.Expr
The match expression.
"""
self.__init_handle_by_constructor__(_make.Match, data, clauses, complete)
self.__init_handle_by_constructor__(_ffi_api.Match, data, clauses, complete)
......@@ -14,8 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the analysis passes."""
# Analysis passes
from .analysis import *
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Annotation related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import *
# Call graph
from . import call_graph
# Feature
from . import feature
CallGraph = call_graph.CallGraph
......@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
"""FFI APIs for Relay program analysis."""
import tvm._ffi
tvm._ffi._init_api("relay._base", __name__)
tvm._ffi._init_api("relay.analysis", __name__)
......@@ -22,9 +22,9 @@ configuring the passes and scripting them in Python.
"""
from tvm.ir import RelayExpr, IRModule
from . import _analysis
from .ty import Type
from . import _ffi_api
from .feature import Feature
from ..ty import Type
def post_order_visit(expr, fvisit):
......@@ -40,7 +40,7 @@ def post_order_visit(expr, fvisit):
fvisit : function
The visitor function to be applied.
"""
return _analysis.post_order_visit(expr, fvisit)
return _ffi_api.post_order_visit(expr, fvisit)
def well_formed(expr):
......@@ -56,7 +56,7 @@ def well_formed(expr):
well_form : bool
Whether the input expression is well formed
"""
return _analysis.well_formed(expr)
return _ffi_api.well_formed(expr)
def check_kind(t, mod=None):
......@@ -85,9 +85,9 @@ def check_kind(t, mod=None):
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type
"""
if mod is not None:
return _analysis.check_kind(t, mod)
return _ffi_api.check_kind(t, mod)
else:
return _analysis.check_kind(t)
return _ffi_api.check_kind(t)
def check_constant(expr):
......@@ -103,7 +103,7 @@ def check_constant(expr):
result : bool
Whether the expression is constant.
"""
return _analysis.check_constant(expr)
return _ffi_api.check_constant(expr)
def free_vars(expr):
......@@ -125,7 +125,7 @@ def free_vars(expr):
neural networks: usually this means weights of previous
are ordered first.
"""
return _analysis.free_vars(expr)
return _ffi_api.free_vars(expr)
def bound_vars(expr):
......@@ -141,7 +141,7 @@ def bound_vars(expr):
free : List[tvm.relay.Var]
The list of bound variables in post-DFS order.
"""
return _analysis.bound_vars(expr)
return _ffi_api.bound_vars(expr)
def all_vars(expr):
......@@ -157,7 +157,7 @@ def all_vars(expr):
free : List[tvm.relay.Var]
The list of all variables in post-DFS order.
"""
return _analysis.all_vars(expr)
return _ffi_api.all_vars(expr)
def free_type_vars(expr, mod=None):
......@@ -177,7 +177,7 @@ def free_type_vars(expr, mod=None):
The list of free type variables in post-DFS order
"""
use_mod = mod if mod is not None else IRModule()
return _analysis.free_type_vars(expr, use_mod)
return _ffi_api.free_type_vars(expr, use_mod)
def bound_type_vars(expr, mod=None):
......@@ -197,7 +197,7 @@ def bound_type_vars(expr, mod=None):
The list of bound type variables in post-DFS order
"""
use_mod = mod if mod is not None else IRModule()
return _analysis.bound_type_vars(expr, use_mod)
return _ffi_api.bound_type_vars(expr, use_mod)
def all_type_vars(expr, mod=None):
......@@ -217,7 +217,7 @@ def all_type_vars(expr, mod=None):
The list of all type variables in post-DFS order
"""
use_mod = mod if mod is not None else IRModule()
return _analysis.all_type_vars(expr, use_mod)
return _ffi_api.all_type_vars(expr, use_mod)
def alpha_equal(lhs, rhs):
......@@ -236,7 +236,7 @@ def alpha_equal(lhs, rhs):
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_analysis._alpha_equal(lhs, rhs))
return bool(_ffi_api._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs):
......@@ -250,7 +250,7 @@ def assert_alpha_equal(lhs, rhs):
rhs : tvm.relay.Expr
One of the input Expression.
"""
_analysis._assert_alpha_equal(lhs, rhs)
_ffi_api._assert_alpha_equal(lhs, rhs)
def graph_equal(lhs, rhs):
......@@ -272,7 +272,7 @@ def graph_equal(lhs, rhs):
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_analysis._graph_equal(lhs, rhs))
return bool(_ffi_api._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs):
......@@ -289,7 +289,7 @@ def assert_graph_equal(lhs, rhs):
rhs : tvm.relay.Expr
One of the input Expression.
"""
_analysis._assert_graph_equal(lhs, rhs)
_ffi_api._assert_graph_equal(lhs, rhs)
def collect_device_info(expr):
......@@ -303,10 +303,10 @@ def collect_device_info(expr):
Returns
-------
ret : Dict[tvm.relay.expr, int]
ret : Dict[tvm.relay.ir.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _analysis.CollectDeviceInfo(expr)
return _ffi_api.CollectDeviceInfo(expr)
def collect_device_annotation_ops(expr):
......@@ -319,11 +319,11 @@ def collect_device_annotation_ops(expr):
Returns
-------
ret : Dict[tvm.relay.expr, int]
ret : Dict[tvm.relay.Expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _analysis.CollectDeviceAnnotationOps(expr)
return _ffi_api.CollectDeviceAnnotationOps(expr)
def get_total_mac_number(expr):
......@@ -340,7 +340,7 @@ def get_total_mac_number(expr):
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _analysis.GetTotalMacNumber(expr)
return _ffi_api.GetTotalMacNumber(expr)
def unmatched_cases(match, mod=None):
......@@ -360,7 +360,7 @@ def unmatched_cases(match, mod=None):
missing_patterns : [tvm.relay.Pattern]
Patterns that the match expression does not catch.
"""
return _analysis.unmatched_cases(match, mod)
return _ffi_api.unmatched_cases(match, mod)
def detect_feature(a, b=None):
......@@ -383,7 +383,7 @@ def detect_feature(a, b=None):
"""
if isinstance(a, IRModule):
a, b = b, a
return {Feature(int(x)) for x in _analysis.detect_feature(a, b)}
return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}
def structural_hash(value):
......@@ -400,9 +400,9 @@ def structural_hash(value):
The hash value
"""
if isinstance(value, RelayExpr):
return int(_analysis._expr_hash(value))
return int(_ffi_api._expr_hash(value))
elif isinstance(value, Type):
return int(_analysis._type_hash(value))
return int(_ffi_api._type_hash(value))
else:
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
......@@ -421,10 +421,10 @@ def extract_fused_functions(mod):
Returns
-------
ret : Dict[int, tvm.relay.expr.Function]
ret : Dict[int, tvm.relay.ir.expr.Function]
A module containing only fused primitive functions
"""
ret_mod = _analysis.ExtractFusedFunctions()(mod)
ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
ret = {}
for hash_, func in ret_mod.functions.items():
ret[hash_] = func
......
......@@ -18,9 +18,9 @@
"""Call graph used in Relay."""
from tvm.ir import IRModule
from .base import Object
from .expr import GlobalVar
from . import _analysis
from tvm.runtime import Object
from ..expr import GlobalVar
from . import _ffi_api
class CallGraph(Object):
......@@ -39,7 +39,7 @@ class CallGraph(Object):
call_graph: CallGraph
A constructed call graph.
"""
self.__init_handle_by_constructor__(_analysis.CallGraph, module)
self.__init_handle_by_constructor__(_ffi_api.CallGraph, module)
@property
def module(self):
......@@ -54,7 +54,7 @@ class CallGraph(Object):
ret : tvm.ir.IRModule
The contained IRModule
"""
return _analysis.GetModule(self)
return _ffi_api.GetModule(self)
def ref_count(self, var):
"""Return the number of references to the global var
......@@ -69,7 +69,7 @@ class CallGraph(Object):
The number reference to the global var
"""
var = self._get_global_var(var)
return _analysis.GetRefCountGlobalVar(self, var)
return _ffi_api.GetRefCountGlobalVar(self, var)
def global_call_count(self, var):
"""Return the number of global function calls from a given global var.
......@@ -84,7 +84,7 @@ class CallGraph(Object):
The number of global function calls from the given var.
"""
var = self._get_global_var(var)
return _analysis.GetGlobalVarCallCount(self, var)
return _ffi_api.GetGlobalVarCallCount(self, var)
def is_recursive(self, var):
"""Return if the function corresponding to a var is a recursive
......@@ -100,7 +100,7 @@ class CallGraph(Object):
If the function corresponding to var is recurisve.
"""
var = self._get_global_var(var)
return _analysis.IsRecursive(self, var)
return _ffi_api.IsRecursive(self, var)
def _get_global_var(self, var):
"""Return the global var using a given name or GlobalVar.
......@@ -137,8 +137,8 @@ class CallGraph(Object):
The call graph represented in string.
"""
var = self._get_global_var(var)
return _analysis.PrintCallGraphGlobalVar(self, var)
return _ffi_api.PrintCallGraphGlobalVar(self, var)
def __str__(self):
"""Print the call graph in the topological order."""
return _analysis.PrintCallGraph(self)
return _ffi_api.PrintCallGraph(self)
......@@ -22,7 +22,7 @@ import logging
import numpy as np
import tvm
from tvm import te
from ..base import register_relay_node, Object
from tvm.runtime import Object
from ... import target as _target
from ... import autotvm
from .. import expr as _expr
......@@ -33,7 +33,7 @@ from . import _backend
logger = logging.getLogger('compile_engine')
@register_relay_node
@tvm._ffi.register_object("relay.LoweredOutput")
class LoweredOutput(Object):
"""Lowered output"""
def __init__(self, outputs, implement):
......@@ -41,7 +41,7 @@ class LoweredOutput(Object):
_backend._make_LoweredOutput, outputs, implement)
@register_relay_node
@tvm._ffi.register_object("relay.CCacheKey")
class CCacheKey(Object):
"""Key in the CompileEngine.
......@@ -58,7 +58,7 @@ class CCacheKey(Object):
_backend._make_CCacheKey, source_func, target)
@register_relay_node
@tvm._ffi.register_object("relay.CCacheValue")
class CCacheValue(Object):
"""Value in the CompileEngine, including usage statistics.
"""
......@@ -261,7 +261,7 @@ def lower_call(call, inputs, target):
return LoweredOutput(outputs, best_impl)
@register_relay_node
@tvm._ffi.register_object("relay.CompileEngine")
class CompileEngine(Object):
"""CompileEngine to get lowered code.
"""
......
......@@ -20,25 +20,25 @@ from __future__ import absolute_import
import numpy as np
from tvm.runtime import container
import tvm._ffi
from tvm.runtime import container, Object
from tvm.ir import IRModule
from . import _backend
from .. import _make, analysis, transform
from ... import nd
from ..base import Object, register_relay_node
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
@register_relay_node
@tvm._ffi.register_object("relay.ConstructorValue")
class ConstructorValue(Object):
def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor)
@register_relay_node
@tvm._ffi.register_object("relay.RefValue")
class RefValue(Object):
def __init__(self, value):
self.__init_handle_by_constructor__(
......
......@@ -21,47 +21,17 @@ import tvm._ffi
from tvm.runtime import Object
from tvm.ir import SourceName, Span, Node as RelayNode
from . import _make
from . import _expr
from . import _base
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
@tvm._ffi.register_func("tvm.relay.std_path")
def _std_path():
return __STD_PATH__
def register_relay_node(type_key=None):
"""Register a Relay node type.
Parameters
----------
type_key : str or cls
The type key of the node.
"""
if not isinstance(type_key, str):
return tvm._ffi.register_object(
"relay." + type_key.__name__)(type_key)
return tvm._ffi.register_object(type_key)
def register_relay_attr_node(type_key=None):
"""Register a Relay attribute node.
Parameters
----------
type_key : str or cls
The type key of the node.
"""
if not isinstance(type_key, str):
return tvm._ffi.register_object(
"relay.attrs." + type_key.__name__)(type_key)
return tvm._ffi.register_object(type_key)
@register_relay_node
@tvm._ffi.register_object("relay.Id")
class Id(Object):
"""Unique identifier(name) used in Var.
Guaranteed to be stable across all passes.
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Contrib operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.contrib import *
......@@ -20,13 +20,13 @@ from __future__ import absolute_import
from numbers import Number as _Number
import numpy as _np
import tvm._ffi
from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar, BaseFunc
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from .base import RelayNode
from . import _ffi_api
from . import ty as _ty
# alias relay expr as Expr.
......@@ -54,7 +54,7 @@ class ExprWithOp(RelayExpr):
result : tvm.relay.Expr
The result expression.
"""
return _make.cast(self, dtype)
return _ffi_api.cast(self, dtype)
def __neg__(self):
return _op_make.negative(self)
......@@ -160,7 +160,7 @@ class ExprWithOp(RelayExpr):
"""
return Call(self, args)
@register_relay_node
@tvm._ffi.register_object("relay.Constant")
class Constant(ExprWithOp):
"""A constant expression in Relay.
......@@ -170,10 +170,10 @@ class Constant(ExprWithOp):
The data content of the constant expression.
"""
def __init__(self, data):
self.__init_handle_by_constructor__(_make.Constant, data)
self.__init_handle_by_constructor__(_ffi_api.Constant, data)
@register_relay_node
@tvm._ffi.register_object("relay.Tuple")
class Tuple(ExprWithOp):
"""Tuple expression that groups several fields together.
......@@ -183,7 +183,7 @@ class Tuple(ExprWithOp):
The fields in the tuple.
"""
def __init__(self, fields):
self.__init_handle_by_constructor__(_make.Tuple, fields)
self.__init_handle_by_constructor__(_ffi_api.Tuple, fields)
def __getitem__(self, index):
if index >= len(self):
......@@ -197,7 +197,7 @@ class Tuple(ExprWithOp):
raise TypeError("astype cannot be used on tuple")
@register_relay_node
@tvm._ffi.register_object("relay.Var")
class Var(ExprWithOp):
"""A local variable in Relay.
......@@ -216,7 +216,7 @@ class Var(ExprWithOp):
"""
def __init__(self, name_hint, type_annotation=None):
self.__init_handle_by_constructor__(
_make.Var, name_hint, type_annotation)
_ffi_api.Var, name_hint, type_annotation)
@property
def name_hint(self):
......@@ -225,7 +225,7 @@ class Var(ExprWithOp):
return name
@register_relay_node
@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
"""A function declaration expression.
......@@ -254,7 +254,7 @@ class Function(BaseFunc):
type_params = convert([])
self.__init_handle_by_constructor__(
_make.Function, params, body, ret_type, type_params, attrs)
_ffi_api.Function, params, body, ret_type, type_params, attrs)
def __call__(self, *args):
"""Invoke the global function.
......@@ -282,12 +282,12 @@ class Function(BaseFunc):
func : Function
A new copy of the function
"""
return _expr.FunctionWithAttr(
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))
@register_relay_node
@tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp):
"""Function call node in Relay.
......@@ -313,10 +313,10 @@ class Call(ExprWithOp):
if not type_args:
type_args = []
self.__init_handle_by_constructor__(
_make.Call, op, args, attrs, type_args)
_ffi_api.Call, op, args, attrs, type_args)
@register_relay_node
@tvm._ffi.register_object("relay.Let")
class Let(ExprWithOp):
"""Let variable binding expression.
......@@ -333,10 +333,10 @@ class Let(ExprWithOp):
"""
def __init__(self, variable, value, body):
self.__init_handle_by_constructor__(
_make.Let, variable, value, body)
_ffi_api.Let, variable, value, body)
@register_relay_node
@tvm._ffi.register_object("relay.If")
class If(ExprWithOp):
"""A conditional expression in Relay.
......@@ -353,10 +353,10 @@ class If(ExprWithOp):
"""
def __init__(self, cond, true_branch, false_branch):
self.__init_handle_by_constructor__(
_make.If, cond, true_branch, false_branch)
_ffi_api.If, cond, true_branch, false_branch)
@register_relay_node
@tvm._ffi.register_object("relay.TupleGetItem")
class TupleGetItem(ExprWithOp):
"""Get index-th item from a tuple.
......@@ -370,10 +370,10 @@ class TupleGetItem(ExprWithOp):
"""
def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index)
_ffi_api.TupleGetItem, tuple_value, index)
@register_relay_node
@tvm._ffi.register_object("relay.RefCreate")
class RefCreate(ExprWithOp):
"""Create a new reference from initial value.
Parameters
......@@ -382,10 +382,10 @@ class RefCreate(ExprWithOp):
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefCreate, value)
self.__init_handle_by_constructor__(_ffi_api.RefCreate, value)
@register_relay_node
@tvm._ffi.register_object("relay.RefRead")
class RefRead(ExprWithOp):
"""Get the value inside the reference.
Parameters
......@@ -394,10 +394,10 @@ class RefRead(ExprWithOp):
The reference.
"""
def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref)
self.__init_handle_by_constructor__(_ffi_api.RefRead, ref)
@register_relay_node
@tvm._ffi.register_object("relay.RefWrite")
class RefWrite(ExprWithOp):
"""
Update the value inside the reference.
......@@ -410,7 +410,7 @@ class RefWrite(ExprWithOp):
The new value.
"""
def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
self.__init_handle_by_constructor__(_ffi_api.RefWrite, ref, value)
class TempExpr(ExprWithOp):
......@@ -427,7 +427,7 @@ class TempExpr(ExprWithOp):
-------
The corresponding normal expression.
"""
return _expr.TempExprRealize(self)
return _ffi_api.TempExprRealize(self)
class TupleWrapper(object):
......@@ -587,4 +587,4 @@ def bind(expr, binds):
result : tvm.relay.Expr
The expression or function after binding.
"""
return _expr.Bind(expr, binds)
return _ffi_api.Bind(expr, binds)
......@@ -26,8 +26,8 @@ from .. import analysis
from .. import expr as _expr
from .. import op as _op
from .. import qnn as _qnn
from ..util import get_scalar_from_constant
from ... import nd as _nd
from .util import get_scalar_from_constant
from .common import ExprTable
from .common import infer_shape as _infer_shape
......
......@@ -18,7 +18,7 @@
""" Utility functions that are used across many directories. """
from __future__ import absolute_import
import numpy as np
from . import expr as _expr
from .. import expr as _expr
def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Image network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.image import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Neural network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.nn import *
......@@ -41,7 +41,6 @@ from . import _tensor_grad
from . import _transform
from . import _reduce
from . import _algorithm
from ..base import register_relay_node
def _register_op_make():
......
......@@ -19,13 +19,12 @@
import tvm._ffi
from tvm.driver import lower, build
from ..base import register_relay_node
from ..expr import RelayExpr
from ...target import get_native_generic_func, GenericFunc
from ...runtime import Object
from . import _make
@register_relay_node
@tvm._ffi.register_object("relay.Op")
class Op(RelayExpr):
"""A Relay operator definition."""
......
......@@ -15,32 +15,31 @@
# specific language governing permissions and limitations
# under the License.
"""The attributes node used for Relay operators"""
from tvm.ir import Attrs
from ..base import register_relay_attr_node
import tvm._ffi
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.Conv1DAttrs")
class Conv1DAttrs(Attrs):
"""Attributes for nn.conv1d"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.Conv2DAttrs")
class Conv2DAttrs(Attrs):
"""Attributes for nn.conv2d"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradAttrs")
class Conv2DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_without_weight_transform"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs")
class Conv2DWinogradWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_weight_transform"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform"""
......@@ -48,285 +47,285 @@ class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs):
class Dilation2DAttrs(Attrs):
"""Attributes for nn.dilation2d"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.GlobalPool2DAttrs")
class GlobalPool2DAttrs(Attrs):
"""Attributes for nn.global_pool"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.BiasAddAttrs")
class BiasAddAttrs(Attrs):
"""Atttribute of nn.bias_add"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.DenseAttrs")
class DenseAttrs(Attrs):
"""Attributes for nn.dense"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.FIFOBufferAttrs")
class FIFOBufferAttrs(Attrs):
"""Attributes for nn.fifo_buffer"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.UpSamplingAttrs")
class UpSamplingAttrs(Attrs):
"""Attributes for nn.upsampling"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.UpSampling3DAttrs")
class UpSampling3DAttrs(Attrs):
"""Attributes for nn.upsampling3d"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.PadAttrs")
class PadAttrs(Attrs):
"""Attributes for nn.pad"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.MirrorPadAttrs")
class MirrorPadAttrs(Attrs):
"""Attributes for nn.mirror_pad"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.LeakyReluAttrs")
class LeakyReluAttrs(Attrs):
"""Attributes for nn.leaky_relu"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.PReluAttrs")
class PReluAttrs(Attrs):
"""Attributes for nn.prelu"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.DropoutAttrs")
class DropoutAttrs(Attrs):
"""Attributes for nn.dropout"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.BatchNormAttrs")
class BatchNormAttrs(Attrs):
"""Attributes for nn.batch_norm"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.LRNAttrs")
class LRNAttrs(Attrs):
"""Attributes for nn.lrn"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.L2NormalizeAttrs")
class L2NormalizeAttrs(Attrs):
"""Attributes for nn.l2_normalize"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.DeformableConv2DAttrs")
class DeformableConv2DAttrs(Attrs):
"""Attributes for nn.deformable_conv2d"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ResizeAttrs")
class ResizeAttrs(Attrs):
"""Attributes for image.resize"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.CropAndResizeAttrs")
class CropAndResizeAttrs(Attrs):
"""Attributes for image.crop_and_resize"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ArgsortAttrs")
class ArgsortAttrs(Attrs):
"""Attributes for algorithm.argsort"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.OnDeviceAttrs")
class OnDeviceAttrs(Attrs):
"""Attributes for annotation.on_device"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.DebugAttrs")
class DebugAttrs(Attrs):
"""Attributes for debug"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.OnDeviceAttrs")
class DeviceCopyAttrs(Attrs):
"""Attributes for tensor.device_copy"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.CastAttrs")
class CastAttrs(Attrs):
"""Attributes for transform.cast"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ConcatenateAttrs")
class ConcatenateAttrs(Attrs):
"""Attributes for tensor.concatenate"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.TransposeAttrs")
class TransposeAttrs(Attrs):
"""Attributes for transform.transpose"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ReshapeAttrs")
class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.TakeAttrs")
class TakeAttrs(Attrs):
"""Attributes for transform.take"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.InitOpAttrs")
class InitOpAttrs(Attrs):
"""Attributes for ops specifying a tensor"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ArangeAttrs")
class ArangeAttrs(Attrs):
"""Attributes used in arange operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.StackAttrs")
class StackAttrs(Attrs):
"""Attributes used in stack operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.RepeatAttrs")
class RepeatAttrs(Attrs):
"""Attributes used in repeat operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.TileAttrs")
class TileAttrs(Attrs):
"""Attributes used in tile operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ReverseAttrs")
class ReverseAttrs(Attrs):
"""Attributes used in reverse operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.SqueezeAttrs")
class SqueezeAttrs(Attrs):
"""Attributes used in squeeze operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.SplitAttrs")
class SplitAttrs(Attrs):
"""Attributes for transform.split"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.StridedSliceAttrs")
class StridedSliceAttrs(Attrs):
"""Attributes for transform.stranded_slice"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.SliceLikeAttrs")
class SliceLikeAttrs(Attrs):
"""Attributes for transform.slice_like"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ClipAttrs")
class ClipAttrs(Attrs):
"""Attributes for transform.clip"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes for transform.layout_transform"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ShapeOfAttrs")
class ShapeOfAttrs(Attrs):
"""Attributes for tensor.shape_of"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.MultiBoxPriorAttrs")
class MultiBoxPriorAttrs(Attrs):
"""Attributes for vision.multibox_prior"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.MultiBoxTransformLocAttrs")
class MultiBoxTransformLocAttrs(Attrs):
"""Attributes for vision.multibox_transform_loc"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.GetValidCountsAttrs")
class GetValidCountsAttrs(Attrs):
"""Attributes for vision.get_valid_counts"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.NonMaximumSuppressionAttrs")
class NonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.non_maximum_suppression"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ROIAlignAttrs")
class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ROIPoolAttrs")
class ROIPoolAttrs(Attrs):
"""Attributes for vision.roi_pool"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.YoloReorgAttrs")
class YoloReorgAttrs(Attrs):
"""Attributes for vision.yolo_reorg"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.ProposalAttrs")
class ProposalAttrs(Attrs):
"""Attributes used in proposal operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.MaxPool2DAttrs")
class MaxPool2DAttrs(Attrs):
"""Attributes used in max_pool2d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.AvgPool2DAttrs")
class AvgPool2DAttrs(Attrs):
"""Attributes used in avg_pool2d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.MaxPool1DAttrs")
class MaxPool1DAttrs(Attrs):
"""Attributes used in max_pool1d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.AvgPool1DAttrs")
class AvgPool1DAttrs(Attrs):
"""Attributes used in avg_pool1d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.MaxPool3DAttrs")
class MaxPool3DAttrs(Attrs):
"""Attributes used in max_pool3d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.AvgPool3DAttrs")
class AvgPool3DAttrs(Attrs):
"""Attributes used in avg_pool3d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.BitPackAttrs")
class BitPackAttrs(Attrs):
"""Attributes used in bitpack operator"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.BinaryConv2DAttrs")
class BinaryConv2DAttrs(Attrs):
"""Attributes used in bitserial conv2d operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.BinaryDenseAttrs")
class BinaryDenseAttrs(Attrs):
"""Attributes used in bitserial dense operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.Conv2DTransposeAttrs")
class Conv2DTransposeAttrs(Attrs):
"""Attributes used in Transposed Conv2D operators"""
@register_relay_attr_node
@tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
class SubPixelAttrs(Attrs):
"""Attributes used in depth to space and space to depth operators"""
......@@ -38,7 +38,7 @@ def cast(data, dtype):
result : relay.Expr
The casted result.
"""
from .. import _make as _relay_make
from .. import _ffi_api as _relay_make
return _relay_make.cast(data, dtype)
......@@ -55,7 +55,7 @@ def cast_like(data, dtype_like):
result : relay.Expr
The casted result.
"""
from .. import _make as _relay_make
from .. import _ffi_api as _relay_make
return _relay_make.cast_like(data, dtype_like)
......
......@@ -21,7 +21,7 @@ from __future__ import absolute_import
import tvm
from tvm import relay
from .. import op as reg
from ...util import get_scalar_from_constant
from ...frontend.util import get_scalar_from_constant
#################################################
# Register the functions for different operators.
......
......@@ -24,7 +24,6 @@ from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op
......@@ -58,7 +57,7 @@ _reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.register_injective_schedule("annotation.cast_hint")
@register_relay_node
@tvm._ffi.register_object("relay.QAnnotateExpr")
class QAnnotateExpr(_expr.TempExpr):
"""A special kind of Expr for Annotating.
......
......@@ -19,7 +19,6 @@
import tvm
from .. import expr as _expr
from .. import analysis as _analysis
from ..base import register_relay_node
from ..op import op as _reg
from . import _quantize
from .quantize import _forward_op
......@@ -30,7 +29,7 @@ def register_partition_function(op_name, frewrite=None, level=10):
return _register(frewrite) if frewrite is not None else _register
@register_relay_node
@tvm._ffi.register_object("relay.QPartitionExpr")
class QPartitionExpr(_expr.TempExpr):
def __init__(self, expr):
self.__init_handle_by_constructor__(
......
......@@ -17,12 +17,12 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
from tvm.runtime import Object
from . import _quantize
from ._calibrate import calibrate
from .. import expr as _expr
from .. import transform as _transform
from ..base import Object, register_relay_node
class QAnnotateKind(object):
......@@ -52,7 +52,7 @@ def _forward_op(ref_call, args):
ref_call.op, args, ref_call.attrs, ref_call.type_args)
@register_relay_node("relay.quantize.QConfig")
@tvm._ffi.register_object("relay.quantize.QConfig")
class QConfig(Object):
"""Configure the quantization behavior by setting config variables.
......
......@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
import tvm._ffi
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing transformations."""
# transformation passes
from .transform import *
tvm._ffi._init_api("relay._expr", __name__)
from . import memory_alloc
......@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the Relay type inference and checking."""
"""FFI APIs for Relay transformation passes."""
import tvm._ffi
tvm._ffi._init_api("relay._transform", __name__)
......@@ -19,12 +19,13 @@
A pass for manifesting explicit memory allocations.
"""
import numpy as np
from .expr_functor import ExprMutator
from .scope_builder import ScopeBuilder
from ..expr_functor import ExprMutator
from ..scope_builder import ScopeBuilder
from . import transform
from . import op, ty, expr
from .. import DataType, register_func
from .backend import compile_engine
from .. import op
from ... import DataType, register_func
from .. import ty, expr
from ..backend import compile_engine
def is_primitive(call):
......
......@@ -28,8 +28,7 @@ from tvm.runtime import ndarray as _nd
from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
from tvm import relay
from . import _transform
from .base import register_relay_node
from . import _ffi_api
def build_config(opt_level=2,
......@@ -83,7 +82,7 @@ def build_config(opt_level=2,
disabled_pass, trace)
@register_relay_node
@tvm._ffi.register_object("relay.FunctionPass")
class FunctionPass(Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
......@@ -98,7 +97,7 @@ def InferType():
ret : tvm.relay.Pass
The registered type inference pass.
"""
return _transform.InferType()
return _ffi_api.InferType()
def FoldScaleAxis():
......@@ -116,7 +115,7 @@ def FoldScaleAxis():
forward_fold_scale_axis as backward folding targets the common conv->bn
pattern.
"""
return _transform.FoldScaleAxis()
return _ffi_api.FoldScaleAxis()
def BackwardFoldScaleAxis():
......@@ -133,7 +132,7 @@ def BackwardFoldScaleAxis():
before using forward_fold_scale_axis as backward folding targets the common
conv->bn pattern.
"""
return _transform.BackwardFoldScaleAxis()
return _ffi_api.BackwardFoldScaleAxis()
def RemoveUnusedFunctions(entry_functions=None):
"""Remove unused global relay functions in a relay module.
......@@ -150,7 +149,7 @@ def RemoveUnusedFunctions(entry_functions=None):
"""
if entry_functions is None:
entry_functions = ['main']
return _transform.RemoveUnusedFunctions(entry_functions)
return _ffi_api.RemoveUnusedFunctions(entry_functions)
def ForwardFoldScaleAxis():
"""Fold the scaling of axis into weights of conv2d/dense.
......@@ -166,7 +165,7 @@ def ForwardFoldScaleAxis():
before using forward_fold_scale_axis, as backward folding targets the
common conv->bn pattern.
"""
return _transform.ForwardFoldScaleAxis()
return _ffi_api.ForwardFoldScaleAxis()
def SimplifyInference():
......@@ -178,7 +177,7 @@ def SimplifyInference():
ret: tvm.relay.Pass
The registered pass to perform operator simplification.
"""
return _transform.SimplifyInference()
return _ffi_api.SimplifyInference()
def FastMath():
......@@ -189,7 +188,7 @@ def FastMath():
ret: tvm.relay.Pass
The registered pass to perform fast math operations.
"""
return _transform.FastMath()
return _ffi_api.FastMath()
def CanonicalizeOps():
......@@ -202,7 +201,7 @@ def CanonicalizeOps():
ret: tvm.relay.Pass
The registered pass performing the canonicalization.
"""
return _transform.CanonicalizeOps()
return _ffi_api.CanonicalizeOps()
def DeadCodeElimination(inline_once=False):
......@@ -218,7 +217,7 @@ def DeadCodeElimination(inline_once=False):
ret: tvm.relay.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _transform.DeadCodeElimination(inline_once)
return _ffi_api.DeadCodeElimination(inline_once)
def FoldConstant():
......@@ -229,7 +228,7 @@ def FoldConstant():
ret : tvm.relay.Pass
The registered pass for constant folding.
"""
return _transform.FoldConstant()
return _ffi_api.FoldConstant()
def FuseOps(fuse_opt_level=-1):
......@@ -246,7 +245,7 @@ def FuseOps(fuse_opt_level=-1):
ret : tvm.relay.Pass
The registered pass for operator fusion.
"""
return _transform.FuseOps(fuse_opt_level)
return _ffi_api.FuseOps(fuse_opt_level)
def CombineParallelConv2D(min_num_branches=3):
......@@ -263,7 +262,7 @@ def CombineParallelConv2D(min_num_branches=3):
ret: tvm.relay.Pass
The registered pass that combines parallel conv2d operators.
"""
return _transform.CombineParallelConv2D(min_num_branches)
return _ffi_api.CombineParallelConv2D(min_num_branches)
def CombineParallelDense(min_num_branches=3):
......@@ -295,7 +294,7 @@ def CombineParallelDense(min_num_branches=3):
ret: tvm.relay.Pass
The registered pass that combines parallel dense operators.
"""
return _transform.CombineParallelDense(min_num_branches)
return _ffi_api.CombineParallelDense(min_num_branches)
def AlterOpLayout():
......@@ -309,7 +308,7 @@ def AlterOpLayout():
ret : tvm.relay.Pass
The registered pass that alters the layout of operators.
"""
return _transform.AlterOpLayout()
return _ffi_api.AlterOpLayout()
def ConvertLayout(desired_layout):
......@@ -337,7 +336,7 @@ def ConvertLayout(desired_layout):
pass: FunctionPass
The pass.
"""
return _transform.ConvertLayout(desired_layout)
return _ffi_api.ConvertLayout(desired_layout)
def Legalize(legalize_map_attr_name="FTVMLegalize"):
......@@ -357,7 +356,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
ret : tvm.relay.Pass
The registered pass that rewrites an expr.
"""
return _transform.Legalize(legalize_map_attr_name)
return _ffi_api.Legalize(legalize_map_attr_name)
def MergeComposite(pattern_table):
......@@ -382,7 +381,7 @@ def MergeComposite(pattern_table):
pattern_names.append(pattern_name)
patterns.append(pattern)
return _transform.MergeComposite(pattern_names, patterns)
return _ffi_api.MergeComposite(pattern_names, patterns)
def RewriteAnnotatedOps(fallback_device):
......@@ -403,7 +402,7 @@ def RewriteAnnotatedOps(fallback_device):
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
return _transform.RewriteDeviceAnnotation(fallback_device)
return _ffi_api.RewriteDeviceAnnotation(fallback_device)
def ToANormalForm():
......@@ -417,7 +416,7 @@ def ToANormalForm():
ret: Union[tvm.relay.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return _transform.ToANormalForm()
return _ffi_api.ToANormalForm()
def ToCPS(expr, mod=None):
......@@ -431,7 +430,7 @@ def ToCPS(expr, mod=None):
result: tvm.relay.Pass
The registered pass that transforms an expression into CPS.
"""
return _transform.to_cps(expr, mod)
return _ffi_api.to_cps(expr, mod)
def EtaExpand(expand_constructor=False, expand_global_var=False):
......@@ -450,7 +449,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False):
ret: tvm.relay.Pass
The registered pass that eta expands an expression.
"""
return _transform.EtaExpand(expand_constructor, expand_global_var)
return _ffi_api.EtaExpand(expand_constructor, expand_global_var)
def ToGraphNormalForm():
......@@ -461,7 +460,7 @@ def ToGraphNormalForm():
ret : tvm.relay.Pass
The registered pass that transforms an expression into Graph Normal Form.
"""
return _transform.ToGraphNormalForm()
return _ffi_api.ToGraphNormalForm()
def EliminateCommonSubexpr(fskip=None):
......@@ -478,7 +477,7 @@ def EliminateCommonSubexpr(fskip=None):
ret : tvm.relay.Pass
The registered pass that eliminates common subexpressions.
"""
return _transform.EliminateCommonSubexpr(fskip)
return _ffi_api.EliminateCommonSubexpr(fskip)
def PartialEvaluate():
......@@ -496,7 +495,7 @@ def PartialEvaluate():
ret: tvm.relay.Pass
The registered pass that performs partial evaluation on an expression.
"""
return _transform.PartialEvaluate()
return _ffi_api.PartialEvaluate()
def CanonicalizeCast():
......@@ -508,7 +507,7 @@ def CanonicalizeCast():
ret : tvm.relay.Pass
The registered pass that canonicalizes cast expression.
"""
return _transform.CanonicalizeCast()
return _ffi_api.CanonicalizeCast()
def LambdaLift():
......@@ -520,7 +519,7 @@ def LambdaLift():
ret : tvm.relay.Pass
The registered pass that lifts the lambda function.
"""
return _transform.LambdaLift()
return _ffi_api.LambdaLift()
def PrintIR(show_meta_data=True):
......@@ -537,7 +536,7 @@ def PrintIR(show_meta_data=True):
ret : tvm.relay.Pass
The registered pass that prints the module IR.
"""
return _transform.PrintIR(show_meta_data)
return _ffi_api.PrintIR(show_meta_data)
def PartitionGraph():
......@@ -549,7 +548,7 @@ def PartitionGraph():
ret: tvm.relay.Pass
The registered pass that partitions the Relay program.
"""
return _transform.PartitionGraph()
return _ffi_api.PartitionGraph()
......@@ -568,7 +567,7 @@ def AnnotateTarget(target):
The annotated pass that wrapps ops with subgraph_start and
subgraph_end.
"""
return _transform.AnnotateTarget(target)
return _ffi_api.AnnotateTarget(target)
def Inline():
......@@ -581,7 +580,7 @@ def Inline():
ret: tvm.relay.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return _transform.Inline()
return _ffi_api.Inline()
def gradient(expr, mod=None, mode='higher_order'):
......@@ -609,9 +608,9 @@ def gradient(expr, mod=None, mode='higher_order'):
The transformed expression.
"""
if mode == 'first_order':
return _transform.first_order_gradient(expr, mod)
return _ffi_api.first_order_gradient(expr, mod)
if mode == 'higher_order':
return _transform.gradient(expr, mod)
return _ffi_api.gradient(expr, mod)
raise Exception('unknown mode')
......@@ -634,7 +633,7 @@ def to_cps(func, mod=None):
result: tvm.relay.Function
The output function.
"""
return _transform.to_cps(func, mod)
return _ffi_api.to_cps(func, mod)
def un_cps(func):
......@@ -654,7 +653,7 @@ def un_cps(func):
result: tvm.relay.Function
The output function
"""
return _transform.un_cps(func)
return _ffi_api.un_cps(func)
def _wrap_class_function_pass(pass_cls, pass_info):
......@@ -670,7 +669,7 @@ def _wrap_class_function_pass(pass_cls, pass_info):
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeFunctionPass, _pass_func, pass_info)
_ffi_api.MakeFunctionPass, _pass_func, pass_info)
self._inst = inst
def __getattr__(self, name):
......@@ -778,7 +777,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeFunctionPass(pass_arg, info)
return _ffi_api.MakeFunctionPass(pass_arg, info)
if pass_func:
return create_function_pass(pass_func)
......
......@@ -20,10 +20,10 @@ from tvm.ir import Type, TypeKind, TypeVar, GlobalTypeVar
from tvm.ir import TypeConstraint, FuncType, TupleType, IncompleteType
from tvm.ir import TypeCall, TypeRelation, TensorType, RelayRefType as RefType
from .base import RelayNode, register_relay_node
from . import _make
from .base import RelayNode
from . import _ffi_api
Any = _make.Any
Any = _ffi_api.Any
def type_has_any(tensor_type):
"""Check whether type has any as a shape.
......@@ -36,7 +36,7 @@ def type_has_any(tensor_type):
has_any : bool
The check result.
"""
return _make.IsDynamic(tensor_type)
return _ffi_api.IsDynamic(tensor_type)
def ShapeVar(name):
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Vision network related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.vision import *
......@@ -581,7 +581,7 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
}
TVM_REGISTER_GLOBAL("relay._analysis._alpha_equal")
TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b);
});
......@@ -591,18 +591,18 @@ TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
return AlphaEqual(a, b);
});
TVM_REGISTER_GLOBAL("relay._analysis._assert_alpha_equal")
TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});
TVM_REGISTER_GLOBAL("relay._analysis._graph_equal")
TVM_REGISTER_GLOBAL("relay.analysis._graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b);
});
TVM_REGISTER_GLOBAL("relay._analysis._assert_graph_equal")
TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal")
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
......
......@@ -299,24 +299,24 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
});
TVM_REGISTER_GLOBAL("relay._analysis.CallGraph")
TVM_REGISTER_GLOBAL("relay.analysis.CallGraph")
.set_body_typed([](IRModule module) {
return CallGraph(module);
});
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph")
TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph")
.set_body_typed([](CallGraph call_graph) {
std::stringstream ss;
ss << call_graph;
return ss.str();
});
TVM_REGISTER_GLOBAL("relay._analysis.GetModule")
TVM_REGISTER_GLOBAL("relay.analysis.GetModule")
.set_body_typed([](CallGraph call_graph) {
return call_graph->module;
});
TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar")
TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
std::stringstream ss;
......@@ -324,19 +324,19 @@ TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar")
return ss.str();
});
TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar")
TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
return static_cast<int>(entry_node->GetRefCount());
});
TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount")
TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
return static_cast<int>(entry_node->size());
});
TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive")
TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive")
.set_body_typed([](CallGraph call_graph, GlobalVar var) {
const auto* entry_node = call_graph[var];
return entry_node->IsRecursive();
......
......@@ -74,7 +74,7 @@ Pass ExtractFusedFunctions() {
"ExtractFusedFunctions");
}
TVM_REGISTER_GLOBAL("relay._analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
TVM_REGISTER_GLOBAL("relay.analysis.ExtractFusedFunctions").set_body_typed(ExtractFusedFunctions);
} // namespace transform
......
......@@ -104,7 +104,7 @@ Array<Integer> PyDetectFeature(const Expr& expr, const IRModule& mod) {
return static_cast<Array<Integer>>(fs);
}
TVM_REGISTER_GLOBAL("relay._analysis.detect_feature")
TVM_REGISTER_GLOBAL("relay.analysis.detect_feature")
.set_body_typed(PyDetectFeature);
} // namespace relay
......
......@@ -186,7 +186,7 @@ Kind KindCheck(const Type& t, const IRModule& mod) {
return kc.Check(t);
}
TVM_REGISTER_GLOBAL("relay._analysis.check_kind")
TVM_REGISTER_GLOBAL("relay.analysis.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(args[0], IRModule({}, {}));
......
......@@ -206,7 +206,7 @@ int64_t GetTotalMacNumber(const Expr& expr) {
return MacCounter::GetTotalMacNumber(expr);
}
TVM_REGISTER_GLOBAL("relay._analysis.GetTotalMacNumber")
TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber")
.set_body_typed(GetTotalMacNumber);
} // namespace mac_count
......
......@@ -310,7 +310,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
}
// expose for testing only
TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases")
TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases")
.set_body_typed(
[](const Match& match, const IRModule& mod_ref) {
IRModule call_mod = mod_ref;
......
......@@ -659,7 +659,7 @@ bool TypeSolver::Solve() {
}
// Expose type solver only for debugging purposes.
TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
......
......@@ -274,10 +274,10 @@ tvm::Array<Var> AllVars(const Expr& expr) {
return VarVisitor().All(expr);
}
TVM_REGISTER_GLOBAL("relay._analysis.free_vars")
TVM_REGISTER_GLOBAL("relay.analysis.free_vars")
.set_body_typed(FreeVars);
TVM_REGISTER_GLOBAL("relay._analysis.bound_vars")
TVM_REGISTER_GLOBAL("relay.analysis.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
if (x.as<ExprNode>()) {
......@@ -287,10 +287,10 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_vars")
}
});
TVM_REGISTER_GLOBAL("relay._analysis.all_vars")
TVM_REGISTER_GLOBAL("relay.analysis.all_vars")
.set_body_typed(AllVars);
TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
IRModule mod = args[1];
......@@ -301,7 +301,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.free_type_vars")
}
});
TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
IRModule mod = args[1];
......@@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.bound_type_vars")
}
});
TVM_REGISTER_GLOBAL("relay._analysis.all_type_vars")
TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef x = args[0];
IRModule mod = args[1];
......
......@@ -125,7 +125,7 @@ bool WellFormed(const Expr& e) {
return WellFormedChecker().CheckWellFormed(e);
}
TVM_REGISTER_GLOBAL("relay._analysis.well_formed")
TVM_REGISTER_GLOBAL("relay.analysis.well_formed")
.set_body_typed(WellFormed);
} // namespace relay
......
......@@ -86,7 +86,7 @@ bool IsDynamic(const Type& ty) {
}
// TODO(@jroesch): MOVE ME
TVM_REGISTER_GLOBAL("relay._make.IsDynamic")
TVM_REGISTER_GLOBAL("relay.ir.IsDynamic")
.set_body_typed(IsDynamic);
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
......
......@@ -34,7 +34,7 @@ PatternWildcard PatternWildcardNode::make() {
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_GLOBAL("relay._make.PatternWildcard")
TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -50,7 +50,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) {
TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_GLOBAL("relay._make.PatternVar")
TVM_REGISTER_GLOBAL("relay.ir.PatternVar")
.set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -69,7 +69,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor,
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay._make.PatternConstructor")
TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -87,7 +87,7 @@ PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_GLOBAL("relay._make.PatternTuple")
TVM_REGISTER_GLOBAL("relay.ir.PatternTuple")
.set_body_typed(PatternTupleNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -105,7 +105,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) {
TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_GLOBAL("relay._make.Clause")
TVM_REGISTER_GLOBAL("relay.ir.Clause")
.set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -125,7 +125,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay._make.Match")
TVM_REGISTER_GLOBAL("relay.ir.Match")
.set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
......@@ -38,7 +38,7 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_GLOBAL("relay._make.Constant")
TVM_REGISTER_GLOBAL("relay.ir.Constant")
.set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -71,7 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay._make.Tuple")
TVM_REGISTER_GLOBAL("relay.ir.Tuple")
.set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -96,7 +96,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_GLOBAL("relay._make.Var")
TVM_REGISTER_GLOBAL("relay.ir.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -123,7 +123,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay._make.Call")
TVM_REGISTER_GLOBAL("relay.ir.Call")
.set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -143,7 +143,7 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_GLOBAL("relay._make.Let")
TVM_REGISTER_GLOBAL("relay.ir.Let")
.set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -163,7 +163,7 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay._make.If")
TVM_REGISTER_GLOBAL("relay.ir.If")
.set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -182,7 +182,7 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_GLOBAL("relay._make.TupleGetItem")
TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem")
.set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -199,7 +199,7 @@ RefCreate RefCreateNode::make(Expr value) {
TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_GLOBAL("relay._make.RefCreate")
TVM_REGISTER_GLOBAL("relay.ir.RefCreate")
.set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -216,7 +216,7 @@ RefRead RefReadNode::make(Expr ref) {
TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_GLOBAL("relay._make.RefRead")
TVM_REGISTER_GLOBAL("relay.ir.RefRead")
.set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -234,7 +234,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_GLOBAL("relay._make.RefWrite")
TVM_REGISTER_GLOBAL("relay.ir.RefWrite")
.set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -243,12 +243,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});
TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize")
.set_body_typed([](TempExpr temp) {
return temp->Realize();
});
TVM_REGISTER_GLOBAL("relay._make.Any")
TVM_REGISTER_GLOBAL("relay.ir.Any")
.set_body_typed([]() { return Any::make(); });
} // namespace relay
......
......@@ -347,7 +347,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
ExprApplyVisit(fvisit).VisitExpr(e);
}
TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit")
TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit")
.set_body_typed([](Expr expr, PackedFunc f) {
PostOrderVisit(expr, [f](const Expr& n) {
f(n);
......@@ -443,7 +443,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
}
TVM_REGISTER_GLOBAL("relay._expr.Bind")
TVM_REGISTER_GLOBAL("relay.ir.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
ObjectRef input = args[0];
if (input->IsInstance<ExprNode>()) {
......
......@@ -62,7 +62,7 @@ bool FunctionNode::UseDefaultCompiler() const {
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
TVM_REGISTER_GLOBAL("relay.ir.Function")
.set_body_typed([](tvm::Array<Var> params,
Expr body,
Type ret_type,
......@@ -80,7 +80,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->attrs << ")";
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
TVM_REGISTER_GLOBAL("relay.ir.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
......
......@@ -423,12 +423,12 @@ size_t StructuralHash::operator()(const Expr& expr) const {
return RelayHashHandler().ExprHash(expr);
}
TVM_REGISTER_GLOBAL("relay._analysis._expr_hash")
TVM_REGISTER_GLOBAL("relay.analysis._expr_hash")
.set_body_typed([](ObjectRef ref) {
return static_cast<int64_t>(RelayHashHandler().Hash(ref));
});
TVM_REGISTER_GLOBAL("relay._analysis._type_hash")
TVM_REGISTER_GLOBAL("relay.analysis._type_hash")
.set_body_typed([](Type type) {
return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
});
......
......@@ -82,7 +82,7 @@ Expr MakeCast(Expr data,
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay._make.cast")
TVM_REGISTER_GLOBAL("relay.ir.cast")
.set_body_typed(MakeCast);
RELAY_REGISTER_OP("cast")
......@@ -138,7 +138,7 @@ Expr MakeCastLike(Expr data,
}
TVM_REGISTER_GLOBAL("relay._make.cast_like")
TVM_REGISTER_GLOBAL("relay.ir.cast_like")
.set_body_typed(MakeCastLike);
RELAY_REGISTER_OP("cast_like")
......
......@@ -560,10 +560,10 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
return AnnotatationVisitor::GetAnnotations(expr);
}
TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceInfo")
TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo")
.set_body_typed(CollectDeviceInfo);
TVM_REGISTER_GLOBAL("relay._analysis.CollectDeviceAnnotationOps")
TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps")
.set_body_typed(CollectDeviceAnnotationOps);
namespace transform {
......
......@@ -73,7 +73,7 @@ bool ConstantCheck(const Expr& e) {
return ConstantChecker().Check(e);
}
TVM_REGISTER_GLOBAL("relay._analysis.check_constant")
TVM_REGISTER_GLOBAL("relay.analysis.check_constant")
.set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder.
......
......@@ -18,9 +18,8 @@
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import detect_feature
from tvm.relay.analysis import detect_feature, Feature
from tvm.relay.transform import gradient
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import run_infer_type
......
......@@ -17,10 +17,8 @@
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.testing.config import ctx_list
......
......@@ -25,7 +25,7 @@ def test_callgraph_construct():
x = relay.var("x", shape=(2, 3))
y = relay.var("y", shape=(2, 3))
mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert "g1" in str(call_graph)
assert relay.alpha_equal(mod, call_graph.module)
......@@ -38,7 +38,7 @@ def test_print_element():
x1 = relay.var("x1", shape=(2, 3))
y1 = relay.var("y1", shape=(2, 3))
mod["g1"] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert "#refs = 0" in str(call_graph.print_var("g0"))
assert "#refs = 0" in str(call_graph.print_var("g1"))
......@@ -54,13 +54,13 @@ def test_global_call_count():
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert call_graph.global_call_count(g0) == 0
assert call_graph.global_call_count(g1) == 1
......@@ -77,13 +77,13 @@ def test_ref_count():
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], x1 - y1)
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert call_graph.ref_count(g0) == 1
assert call_graph.ref_count(g1) == 1
......@@ -100,13 +100,13 @@ def test_nested_ref():
y1 = relay.var("y1", shape=(2, 3))
g1 = relay.GlobalVar("g1")
mod[g1] = relay.Function([x1, y1], g0(x1, y1))
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
p0 = relay.var("p0", shape=(2, 3))
p1 = relay.var("p1", shape=(2, 3))
func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1))
mod["main"] = func
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert call_graph.ref_count(g0) == 2
assert call_graph.ref_count(g1) == 1
......@@ -138,7 +138,7 @@ def test_recursive_func():
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
call_graph = relay.CallGraph(mod)
call_graph = relay.analysis.CallGraph(mod)
assert call_graph.is_recursive(sum_up)
assert call_graph.ref_count(sum_up) == 2
......
......@@ -19,7 +19,6 @@ import json
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.expr_functor import ExprMutator
......
......@@ -18,7 +18,7 @@ import tvm
from tvm import te
import numpy as np
from tvm import relay
from tvm.relay import memory_alloc
from tvm.relay.transform import memory_alloc
def check_vm_alloc(func, check_fn):
mod = tvm.IRModule()
......
......@@ -25,7 +25,7 @@ from tvm import relay
from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
# Leverage the pass manager to write a simple white list based annotator
......
......@@ -22,7 +22,7 @@ from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count
from tvm.relay.feature import Feature
from tvm.relay.analysis import Feature
def run_opt_pass(expr, passes):
......
......@@ -16,15 +16,14 @@
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.analysis import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform
from tvm.relay import transform
def test_id():
......
......@@ -16,9 +16,9 @@
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, create_executor, transform, Feature
from tvm.relay import op, create_executor, transform
from tvm.relay.analysis import Feature
from tvm.relay.analysis import detect_feature
......
......@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm import relay
import pytest
......@@ -27,7 +26,7 @@ def make_rel(name, args, num_inputs=None, attrs=None):
return relay.ty.TypeRelation(func, args, num_inputs, attrs)
def make_solver():
solver = relay._analysis._test_type_solver()
solver = relay.analysis._ffi_api._test_type_solver()
solver.Solve = solver("Solve")
solver.Unify = solver("Unify")
solver.Resolve = solver("Resolve")
......
......@@ -18,7 +18,6 @@ import numpy as np
import pytest
import tvm
from tvm import te
from tvm import runtime
from tvm import relay
from tvm.relay.scope_builder import ScopeBuilder
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment