Commit c93235d7 by Zhi Committed by Tianqi Chen

[relay][pass manager] Open transform namespace (#3226)

parent 3272e6cb
...@@ -20,46 +20,12 @@ ...@@ -20,46 +20,12 @@
/*! /*!
* \file tvm/relay/pass.h * \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++. * \brief The set of Relay passes written in C++.
*
* This file also implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the Relay.Module -> Relay.Module transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/ */
#ifndef TVM_RELAY_PASS_H_ #ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
...@@ -72,174 +38,6 @@ ...@@ -72,174 +38,6 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace pass {
/*
* \brief The context of pass.
*/
class PassContext;
/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
}
TVM_DLL static PassContext make();
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;
/*!
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassInfoNode : public RelayNode {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::Expr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}
TVM_DLL static PassInfo make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
class Pass;
/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
/*!
* \brief Set the context information for a pass.
*
* \param pass_ctx The context information for a certain pass.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
/*!
* \brief Execute the optimization pass using a functor.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};
class Pass : public NodeRef {
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
}
using ContainerType = PassNode;
};
/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a sequential pass.
*
* \param passes The optimization passes will be performed.
* \param opt_level The optimization level of the sequential pass.
* \param name The name of the sequential pass.
* \param required The list of the passes that the sequential pass is dependent on.
* \param disabled The disabled passes.
*
* \return The created sequential pass.
*/
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled);
} // namespace pass
/*! /*!
* \brief Infer the type of an expression. * \brief Infer the type of an expression.
* *
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/transform.h
*
* This file implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the Relay.Module -> Relay.Module transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
namespace transform {
/*
* \brief The context of pass.
*/
class PassContext;
/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
}
TVM_DLL static PassContext make();
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;
/*!
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassInfoNode : public RelayNode {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::Expr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}
TVM_DLL static PassInfo make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
class Pass;
/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
/*!
* \brief Set the context information for a pass.
*
* \param pass_ctx The context information for a certain pass.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
/*!
* \brief Execute the optimization pass using a functor.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};
class Pass : public NodeRef {
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
}
using ContainerType = PassNode;
};
class SequentialNode;
class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = Sequential;
};
/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
} // namespace transform
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORM_H_
...@@ -25,6 +25,7 @@ from . import expr_functor ...@@ -25,6 +25,7 @@ from . import expr_functor
from . import module from . import module
from . import adt from . import adt
from . import ir_pass from . import ir_pass
from . import transform
from .build_module import build, build_config, create_executor from .build_module import build, build_config, create_executor
from . import prelude from . import prelude
from . import parser from . import parser
...@@ -97,9 +98,8 @@ Match = adt.Match ...@@ -97,9 +98,8 @@ Match = adt.Match
var = expr.var var = expr.var
const = expr.const const = expr.const
bind = expr.bind bind = expr.bind
module_pass = ir_pass.module_pass module_pass = transform.module_pass
function_pass = ir_pass.function_pass function_pass = transform.function_pass
sequential_pass = ir_pass.sequential_pass
# ExprFunctor # ExprFunctor
ExprFunctor = expr_functor.ExprFunctor ExprFunctor = expr_functor.ExprFunctor
...@@ -114,9 +114,9 @@ save_param_dict = param_dict.save_param_dict ...@@ -114,9 +114,9 @@ save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict load_param_dict = param_dict.load_param_dict
# Pass manager # Pass manager
PassInfo = ir_pass.PassInfo PassInfo = transform.PassInfo
PassContext = ir_pass.PassContext PassContext = transform.PassContext
Pass = ir_pass.Pass Pass = transform.Pass
ModulePass = ir_pass.ModulePass ModulePass = transform.ModulePass
FunctionPass = ir_pass.FunctionPass FunctionPass = transform.FunctionPass
SequentialPass = ir_pass.SequentialPass Sequential = transform.Sequential
...@@ -17,62 +17,8 @@ ...@@ -17,62 +17,8 @@
import tvm import tvm
from . import ir from . import ir
from .base import NodeBase
from .env import Module from .env import Module
class PassContext(NodeBase):
def __init__(self):
...
class PassInfo(NodeBase):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None
class Pass(NodeBase):
def __init__(self):
...
class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class SequentialPass(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list
def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...
def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the Relay type inference and checking."""
from tvm._ffi.function import _init_api
_init_api("relay._transform", __name__)
...@@ -17,324 +17,16 @@ ...@@ -17,324 +17,16 @@
# pylint: disable=no-else-return # pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
""" """
This file contains: This file contains the set of passes for Relay, which exposes an interface for
1. The set of passes for Relay, which exposes an interface for configuring the configuring the passes and scripting them in Python.
passes and scripting them in Python.
2. The pass manager for Relay which exposes different granularity of interfaces
for users to implement and use passes more conveniently.
""" """
import types
from . import _ir_pass from . import _ir_pass
from . import _make from . import _make
from .expr import Expr from .expr import Expr
from .ty import Type from .ty import Type
from .base import RelayNode, register_relay_node
from .module import Module from .module import Module
@register_relay_node
class PassInfo(RelayNode):
"""The class that contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
This class can be extended by adding new members when more meta data is
needed.
Parameters
----------
name : str
The pass name.
opt_level : int
The optimization level of this pass.
required : List[str]
The list of passes that are required by a certain pass.
"""
def __init__(self, name, opt_level, required=None):
self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level,
required)
@register_relay_node
class PassContext(RelayNode):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
to record the errors of during the optimization, etc.
"""
def __init__(self):
self.__init_handle_by_constructor__(_ir_pass.PassContext)
@register_relay_node
class Pass(RelayNode):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
conveniently interact with the base class.
"""
def set_pass_context(self, pass_ctx):
"""Setup the pass context for analysis and optimizations. This context
could be shared by different passes for sequential passes.
Parameters
----------
pass_ctx : PassContext
The context that is used to help perform a certain pass or a series
of passes.
"""
if not isinstance(pass_ctx, PassContext):
raise TypeError("pass_ctx is expected to be the PassContext type")
_ir_pass.SetContext(self, pass_ctx)
@property
def info(self):
"""Get the pass meta."""
return _ir_pass.Info(self)
def __call__(self, mod):
"""Execute the pass. Note that for sequential pass, the dependency among
different passes will be resolved in the backend.
Parameters
----------
mod : tvm.relay.Module
The module that a certain optimization is performed on.
Returns
-------
mod : tvm.relay.Module
The updated module after applying this pass.
"""
return _ir_pass.RunPass(self, mod)
@register_relay_node
class ModulePass(Pass):
"""A pass that works on tvm.relay.Module. Users don't need to interact with
this class directly. Instead, a module pass should be created through
`module_pass`, because the design of the `module_pass` API is flexible
enough to handle the creation of a module pass in different manners. In
addition, all members of a module pass can be accessed from the base class.
The same rule applies to FunctionPass and SequentialPass as well.
"""
@register_relay_node
class FunctionPass(Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
@register_relay_node
class SequentialPass(Pass):
"""A pass that works on a sequence of pass objects. A sequential pass class
should be created through `sequential_pass`.
"""
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.ir_pass.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, ir_pass.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ir_pass.CreateModulePass(pass_func, opt_level,
name if name else pass_func.__name__,
required)
if pass_func:
return create_module_pass(pass_func)
return create_module_pass
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.ir_pass.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, ir_pass.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ir_pass.CreateFunctionPass(pass_func, opt_level,
name if name else pass_func.__name__,
required)
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
def sequential_pass(passes=None, opt_level=2, name="sequential_pass",
required=None, disabled=None):
"""Create a sequential pass using a defined optimization function from
Python. Some typical usage of the sequential pass are:
1. Users provide a list of passes for optimization.
2. Only an optimization level is provided so that the backend system has
to glob all passes at this level and below to perform the optimizations.
Note that users can also provide a series of passes that they don't want to
apply when running a sequential pass. Pass dependency will be resolved in
the backend as well.
Parameters
----------
passes : Optional[List[Pass]]
A sequence of passes candidate for optimization.
opt_level : Optional[int]
The optimization level of this sequential pass.
name : Optional[str]
The name of the sequential pass.
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
disabled : Optional[List[str]]
A list of disabled passes.
Returns
-------
ret : Pass
A sequential pass built through pass_func.
"""
passes = passes if passes else []
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a list of Pass objects.")
disabled = disabled if disabled else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled must be a list or tuple of pass names")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of list/tuple.")
return _ir_pass.CreateSequentialPass(passes, opt_level, name, required,
disabled)
def post_order_visit(expr, fvisit): def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node, """Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited apply fvisit. Each node is guaranteed to be visited
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""
This file contains the pass manager for Relay which exposes different
granularity of interfaces for users to implement and use passes more
conveniently.
"""
import types
from . import _transform
from .base import RelayNode, register_relay_node
@register_relay_node
class PassInfo(RelayNode):
"""The class that contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
This class can be extended by adding new members when more meta data is
needed.
Parameters
----------
name : str
The pass name.
opt_level : int
The optimization level of this pass.
required : List[str]
The list of passes that are required by a certain pass.
"""
def __init__(self, name, opt_level, required=None):
self.__init_handle_by_constructor__(_transform.PassInfo, name, opt_level,
required)
@register_relay_node
class PassContext(RelayNode):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
to help an optimization pass. Such information includes the error reporter
to record the errors of during the optimization, etc.
"""
def __init__(self):
self.__init_handle_by_constructor__(_transform.PassContext)
@register_relay_node
class Pass(RelayNode):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
conveniently interact with the base class.
"""
def set_pass_context(self, pass_ctx):
"""Setup the pass context for analysis and optimizations. This context
could be shared by different passes for sequential passes.
Parameters
----------
pass_ctx : PassContext
The context that is used to help perform a certain pass or a series
of passes.
"""
if not isinstance(pass_ctx, PassContext):
raise TypeError("pass_ctx is expected to be the PassContext type")
_transform.SetContext(self, pass_ctx)
@property
def info(self):
"""Get the pass meta."""
return _transform.Info(self)
def __call__(self, mod):
"""Execute the pass. Note that for sequential pass, the dependency among
different passes will be resolved in the backend.
Parameters
----------
mod : tvm.relay.Module
The module that a certain optimization is performed on.
Returns
-------
mod : tvm.relay.Module
The updated module after applying this pass.
"""
return _transform.RunPass(self, mod)
@register_relay_node
class ModulePass(Pass):
"""A pass that works on tvm.relay.Module. Users don't need to interact with
this class directly. Instead, a module pass should be created through
`module_pass`, because the design of the `module_pass` API is flexible
enough to handle the creation of a module pass in different manners. In
addition, all members of a module pass can be accessed from the base class.
The same rule applies to FunctionPass and Sequential as well.
"""
@register_relay_node
class FunctionPass(Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
@register_relay_node
class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
Some typical usage of the sequential pass are:
1. Users provide a list of passes for optimization.
2. Only an optimization level is provided so that the backend system has
to glob all passes at this level and below to perform the optimizations.
Note that users can also provide a series of passes that they don't want to
apply when running a sequential pass. Pass dependency will be resolved in
the backend as well.
Parameters
----------
passes : Optional[List[Pass]]
A sequence of passes candidate for optimization.
opt_level : Optional[int]
The optimization level of this sequential pass.
name : Optional[str]
The name of the sequential pass.
required : Optional[List[str]]
The list of passes that the sequential pass is dependent on.
disabled : Optional[List[str]]
A list of disabled passes.
"""
def __init__(self,
passes=None,
opt_level=2,
name="sequential",
required=None,
disabled=None):
passes = passes if passes else []
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a list of Pass objects.")
disabled = disabled if disabled else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled must be a list or tuple of pass names")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of list/tuple.")
self.__init_handle_by_constructor__(_transform.Sequential,
passes, opt_level, name, required,
disabled)
def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the module pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
.. code-block:: python
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("var")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_module_pass(pass_func):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.CreateModulePass(
pass_func, opt_level, name if name else pass_func.__name__,
required)
if pass_func:
return create_module_pass(pass_func)
return create_module_pass
def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the module pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
Examples
--------
The following code creates a function level pass that performs constant
folding.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_function_pass(pass_func):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.CreateFunctionPass(
pass_func, opt_level, name if name else pass_func.__name__,
required)
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from .base import NodeBase
class PassContext(NodeBase):
def __init__(self):
...
class PassInfo(NodeBase):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None
class Pass(NodeBase):
def __init__(self):
...
class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class Sequential(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list
def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...
...@@ -23,11 +23,11 @@ ...@@ -23,11 +23,11 @@
* \brief Relay pass manager implementation. * \brief Relay pass manager implementation.
*/ */
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h> #include <tvm/relay/transform.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace pass { namespace transform {
using tvm::IRPrinter; using tvm::IRPrinter;
...@@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode { ...@@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode {
RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
class SequentialPass;
/*! /*!
* \brief The SequentialPassNode contains a set of passes that transform Relay * \brief The SequentialNode contains a set of passes that transform Relay
* programs from one AST to another semantically equivalent one. * programs from one AST to another semantically equivalent one.
* *
* One example of this level of pass is that the pass manager needs to correctly * One example of this level of pass is that the pass manager needs to correctly
* perform a host of optimizations with a given optimization level and disabled * perform a host of optimizations with a given optimization level and disabled
* passes. * passes.
*/ */
class SequentialPassNode : public PassNode { class SequentialNode : public PassNode {
public: public:
/* \brief The pass meta data.*/ /* \brief The pass meta data.*/
PassInfo pass_info; PassInfo pass_info;
...@@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode { ...@@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode {
passes.push_back(pass); passes.push_back(pass);
} }
TVM_DLL static SequentialPass make(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
/*! /*!
* \brief Resolve the pass dependency. It globs all required passes by * \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them. * a given pass and executes them.
...@@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode { ...@@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode {
*/ */
void SetContext(const PassContext& pass_ctx) final; void SetContext(const PassContext& pass_ctx) final;
static constexpr const char* _type_key = "relay.SequentialPass"; static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode); TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
private: private:
/*! /*!
...@@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode { ...@@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode {
PassContext pass_ctx_; PassContext pass_ctx_;
}; };
RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass);
PassInfo PassInfoNode::make(int opt_level, std::string name, PassInfo PassInfoNode::make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required) { tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>(); auto pass_info = make_node<PassInfoNode>();
...@@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { ...@@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
return pval && pval->value != 0; return pval && pval->value != 0;
} }
SequentialPass SequentialPassNode::make(tvm::Array<Pass> passes, Sequential::Sequential(tvm::Array<Pass> passes,
PassInfo pass_info, PassInfo pass_info,
tvm::Array<tvm::Expr> disabled) { tvm::Array<tvm::Expr> disabled) {
auto n = make_node<SequentialPassNode>(); auto n = make_node<SequentialNode>();
n->passes = std::move(passes); n->passes = std::move(passes);
n->pass_info = std::move(pass_info); n->pass_info = std::move(pass_info);
n->disabled = std::move(disabled); n->disabled = std::move(disabled);
return SequentialPass(n); node_ = std::move(n);
}
const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
} }
// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in // TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
// a SequentialPass without the consideration of their orders. The phase // a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future. // ordering problem needed to be handled in the future.
Module SequentialPassNode::operator()(const Module& module) const { Module SequentialNode::operator()(const Module& module) const {
Module mod = module; Module mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
...@@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const { ...@@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const {
return mod; return mod;
} }
void SequentialPassNode::ResolveDependency(const Module& mod) { void SequentialNode::ResolveDependency(const Module& mod) {
// TODO(zhiics) Implement it. // TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass. // 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes. // 2. Only resolve the enabled passes.
...@@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) { ...@@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) {
<< "\n"; << "\n";
} }
std::vector<std::string> SequentialPassNode::DisabledPasses() const { std::vector<std::string> SequentialNode::DisabledPasses() const {
std::vector<std::string> ret; std::vector<std::string> ret;
for (const auto& it : disabled) { for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>(); const auto* str = it.as<tvm::ir::StringImm>();
...@@ -392,7 +388,7 @@ std::vector<std::string> SequentialPassNode::DisabledPasses() const { ...@@ -392,7 +388,7 @@ std::vector<std::string> SequentialPassNode::DisabledPasses() const {
return ret; return ret;
} }
void SequentialPassNode::SetContext(const PassContext& pass_ctx) { void SequentialNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx; pass_ctx_ = pass_ctx;
} }
...@@ -414,21 +410,12 @@ Pass CreateFunctionPass( ...@@ -414,21 +410,12 @@ Pass CreateFunctionPass(
return FunctionPassNode::make(pass_func, pass_info); return FunctionPassNode::make(pass_func, pass_info);
} }
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return SequentialPassNode::make(passes, pass_info, disabled);
}
TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_API("relay._ir_pass.PassInfo") TVM_REGISTER_API("relay._transform.PassInfo")
.set_body_typed(PassInfoNode::make); .set_body_typed(PassInfoNode::make);
TVM_REGISTER_API("relay._ir_pass.Info") TVM_REGISTER_API("relay._transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
*ret = pass->Info(); *ret = pass->Info();
...@@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._ir_pass.CreateModulePass") TVM_REGISTER_API("relay._transform.CreateModulePass")
.set_body_typed(CreateModulePass); .set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._ir_pass.RunPass") TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
Module mod = args[1]; Module mod = args[1];
...@@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") TVM_REGISTER_API("relay._transform.CreateFunctionPass")
.set_body_typed(CreateFunctionPass); .set_body_typed(CreateFunctionPass);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< " at the optimization level " << pn->opt_level; << " at the optimization level " << pn->opt_level;
}); });
TVM_REGISTER_NODE_TYPE(SequentialPassNode); TVM_REGISTER_NODE_TYPE(SequentialNode);
TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") TVM_REGISTER_API("relay._transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0]; tvm::Array<Pass> passes = args[0];
int opt_level = args[1]; int opt_level = args[1];
...@@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") ...@@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass")
tvm::Array<tvm::Expr> required = args[3]; tvm::Array<tvm::Expr> required = args[3];
tvm::Array<tvm::Expr> disabled = args[4]; tvm::Array<tvm::Expr> disabled = args[4];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required); PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
*ret = SequentialPassNode::make(passes, pass_info, disabled); *ret = Sequential(passes, pass_info, disabled);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SequentialPassNode>([](const SequentialPassNode* node, .set_dispatch<SequentialNode>([](const SequentialNode* node,
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
const PassInfoNode* seq_pn = node->Info().operator->(); const PassInfoNode* seq_pn = node->Info().operator->();
p->stream << "Run SequentialPass pass: " << seq_pn->name p->stream << "Run Sequential pass: " << seq_pn->name
<< " at the optimization level. " << seq_pn->opt_level; << " at the optimization level. " << seq_pn->opt_level;
p->stream << "The passes will be executed are: ["; p->stream << "The passes will be executed are: [";
for (const auto& it : node->passes) { for (const auto& it : node->passes) {
...@@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "]"; p->stream << "]";
}); });
TVM_REGISTER_API("relay._ir_pass.SetContext") TVM_REGISTER_API("relay._transform.SetContext")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
PassContext pass_ctx = args[1]; PassContext pass_ctx = args[1];
...@@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext") ...@@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext")
TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._ir_pass.PassContext") TVM_REGISTER_API("relay._transform.PassContext")
.set_body_typed(PassContextNode::make); .set_body_typed(PassContextNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< "\n"; << "\n";
}); });
} // namespace pass } // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -22,6 +22,7 @@ from tvm import relay ...@@ -22,6 +22,7 @@ from tvm import relay
from tvm.relay import ExprFunctor from tvm.relay import ExprFunctor
from tvm.relay import Function, Call from tvm.relay import Function, Call
from tvm.relay import ir_pass from tvm.relay import ir_pass
from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
...@@ -126,13 +127,13 @@ def test_module_pass(): ...@@ -126,13 +127,13 @@ def test_module_pass():
opt_tester = OptTester(mod) opt_tester = OptTester(mod)
pass_ctx = None pass_ctx = None
@ir_pass.module_pass(opt_level=opt_level, name=pass_name) @_transform.module_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx): def transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
def test_pass_registration(): def test_pass_registration():
mod_pass = transform mod_pass = transform
assert isinstance(mod_pass, ir_pass.ModulePass) assert isinstance(mod_pass, _transform.ModulePass)
pass_info = mod_pass.info pass_info = mod_pass.info
assert pass_info.name == pass_name assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level assert pass_info.opt_level == opt_level
...@@ -140,8 +141,8 @@ def test_module_pass(): ...@@ -140,8 +141,8 @@ def test_module_pass():
def test_pass_registration_no_decorator(): def test_pass_registration_no_decorator():
def direct_transform(expr, ctx): def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
mod_pass = ir_pass.module_pass(direct_transform, opt_level=3) mod_pass = _transform.module_pass(direct_transform, opt_level=3)
assert isinstance(mod_pass, ir_pass.ModulePass) assert isinstance(mod_pass, _transform.ModulePass)
pass_info = mod_pass.info pass_info = mod_pass.info
assert pass_info.name == "direct_transform" assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 3 assert pass_info.opt_level == 3
...@@ -202,7 +203,7 @@ def test_function_pass(): ...@@ -202,7 +203,7 @@ def test_function_pass():
opt_tester = OptTester(mod) opt_tester = OptTester(mod)
pass_ctx = None pass_ctx = None
@ir_pass.function_pass(opt_level=opt_level, name=pass_name) @_transform.function_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx): def transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
...@@ -212,7 +213,7 @@ def test_function_pass(): ...@@ -212,7 +213,7 @@ def test_function_pass():
def test_pass_registration(): def test_pass_registration():
function_pass = transform function_pass = transform
assert isinstance(function_pass, ir_pass.FunctionPass) assert isinstance(function_pass, _transform.FunctionPass)
pass_info = function_pass.info pass_info = function_pass.info
assert pass_info.name == pass_name assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level assert pass_info.opt_level == opt_level
...@@ -220,8 +221,8 @@ def test_function_pass(): ...@@ -220,8 +221,8 @@ def test_function_pass():
def test_pass_registration_no_decorator(): def test_pass_registration_no_decorator():
def direct_transform(expr, ctx): def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
mod_pass = ir_pass.function_pass(direct_transform, opt_level=0) mod_pass = _transform.function_pass(direct_transform, opt_level=0)
assert isinstance(mod_pass, ir_pass.FunctionPass) assert isinstance(mod_pass, _transform.FunctionPass)
pass_info = mod_pass.info pass_info = mod_pass.info
assert pass_info.name == "direct_transform" assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 0 assert pass_info.opt_level == 0
...@@ -294,14 +295,14 @@ def test_sequential_pass(): ...@@ -294,14 +295,14 @@ def test_sequential_pass():
opt_tester = OptTester(mod) opt_tester = OptTester(mod)
pass_ctx = None pass_ctx = None
@ir_pass.module_pass(opt_level=1) @_transform.module_pass(opt_level=1)
def mod_transform(expr, ctx): def mod_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
module_pass = mod_transform module_pass = mod_transform
# Register a function pass. # Register a function pass.
@ir_pass.function_pass(opt_level=1) @_transform.function_pass(opt_level=1)
def func_transform(expr, ctx): def func_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
...@@ -310,25 +311,23 @@ def test_sequential_pass(): ...@@ -310,25 +311,23 @@ def test_sequential_pass():
def test_pass_registration(): def test_pass_registration():
passes = [module_pass, function_pass] passes = [module_pass, function_pass]
opt_level = 2 opt_level = 2
pass_name = "sequential_pass" pass_name = "sequential"
sequential_pass = ir_pass.sequential_pass(passes=passes, sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
opt_level=opt_level) pass_info = sequential.info
assert isinstance(sequential_pass, ir_pass.SequentialPass)
pass_info = sequential_pass.info
assert pass_info.name == pass_name assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level assert pass_info.opt_level == opt_level
def test_no_pass(): def test_no_pass():
passes = [] passes = []
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
mod_func = ret_mod[v_sub] mod_func = ret_mod[v_sub]
check_func(sub, mod_func) check_func(sub, mod_func)
def test_only_module_pass(): def test_only_module_pass():
passes = [module_pass] passes = [module_pass]
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
# Check the subtract function. # Check the subtract function.
sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, sub) check_func(new_sub, sub)
...@@ -341,8 +340,8 @@ def test_sequential_pass(): ...@@ -341,8 +340,8 @@ def test_sequential_pass():
def test_only_function_pass(): def test_only_function_pass():
# Check the subtract function. # Check the subtract function.
passes = [function_pass] passes = [function_pass]
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, get_ref_sub()) check_func(new_sub, get_ref_sub())
...@@ -355,8 +354,8 @@ def test_sequential_pass(): ...@@ -355,8 +354,8 @@ def test_sequential_pass():
# function pass. # function pass.
mod = relay.Module({v_sub: sub, v_log: log}) mod = relay.Module({v_sub: sub, v_log: log})
passes = [module_pass, function_pass] passes = [module_pass, function_pass]
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
# Check the abs function is added. # Check the abs function is added.
abs_var, abs_func = get_var_func() abs_var, abs_func = get_var_func()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment