Commit c93235d7 by Zhi Committed by Tianqi Chen

[relay][pass manager] Open transform namespace (#3226)

parent 3272e6cb
...@@ -20,46 +20,12 @@ ...@@ -20,46 +20,12 @@
/*! /*!
* \file tvm/relay/pass.h * \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++. * \brief The set of Relay passes written in C++.
* */
* This file also implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the Relay.Module -> Relay.Module transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/
#ifndef TVM_RELAY_PASS_H_ #ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_ #define TVM_RELAY_PASS_H_
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
...@@ -72,174 +38,6 @@ ...@@ -72,174 +38,6 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace pass {
/*
* \brief The context of pass.
*/
class PassContext;
/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
}
TVM_DLL static PassContext make();
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;
/*!
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassInfoNode : public RelayNode {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::Expr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}
TVM_DLL static PassInfo make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
class Pass;
/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
/*!
* \brief Set the context information for a pass.
*
* \param pass_ctx The context information for a certain pass.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
/*!
* \brief Execute the optimization pass using a functor.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};
class Pass : public NodeRef {
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
}
using ContainerType = PassNode;
};
/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a sequential pass.
*
* \param passes The optimization passes will be performed.
* \param opt_level The optimization level of the sequential pass.
* \param name The name of the sequential pass.
* \param required The list of the passes that the sequential pass is dependent on.
* \param disabled The disabled passes.
*
* \return The created sequential pass.
*/
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled);
} // namespace pass
/*! /*!
* \brief Infer the type of an expression. * \brief Infer the type of an expression.
* *
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/transform.h
*
* This file implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the Relay.Module -> Relay.Module transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
namespace transform {
/*
* \brief The context of pass.
*/
class PassContext;
/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
}
TVM_DLL static PassContext make();
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassContext, PassContextNode)
/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;
/*!
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassInfoNode : public RelayNode {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::Expr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}
TVM_DLL static PassInfo make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};
TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)
class Pass;
/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
/*!
* \brief Set the context information for a pass.
*
* \param pass_ctx The context information for a certain pass.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;
/*!
* \brief Execute the optimization pass using a functor.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) override {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};
class Pass : public NodeRef {
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
}
using ContainerType = PassNode;
};
class SequentialNode;
class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
* \param passes The passes to apply.
* \param pass_info The pass metadata.
* \param disabled The passes that will not be applied.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
Sequential() = default;
explicit Sequential(tvm::NodePtr<::tvm::Node> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = Sequential;
};
/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
/*
* \brief Create a function pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required);
} // namespace transform
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORM_H_
...@@ -25,6 +25,7 @@ from . import expr_functor ...@@ -25,6 +25,7 @@ from . import expr_functor
from . import module from . import module
from . import adt from . import adt
from . import ir_pass from . import ir_pass
from . import transform
from .build_module import build, build_config, create_executor from .build_module import build, build_config, create_executor
from . import prelude from . import prelude
from . import parser from . import parser
...@@ -97,9 +98,8 @@ Match = adt.Match ...@@ -97,9 +98,8 @@ Match = adt.Match
var = expr.var var = expr.var
const = expr.const const = expr.const
bind = expr.bind bind = expr.bind
module_pass = ir_pass.module_pass module_pass = transform.module_pass
function_pass = ir_pass.function_pass function_pass = transform.function_pass
sequential_pass = ir_pass.sequential_pass
# ExprFunctor # ExprFunctor
ExprFunctor = expr_functor.ExprFunctor ExprFunctor = expr_functor.ExprFunctor
...@@ -114,9 +114,9 @@ save_param_dict = param_dict.save_param_dict ...@@ -114,9 +114,9 @@ save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict load_param_dict = param_dict.load_param_dict
# Pass manager # Pass manager
PassInfo = ir_pass.PassInfo PassInfo = transform.PassInfo
PassContext = ir_pass.PassContext PassContext = transform.PassContext
Pass = ir_pass.Pass Pass = transform.Pass
ModulePass = ir_pass.ModulePass ModulePass = transform.ModulePass
FunctionPass = ir_pass.FunctionPass FunctionPass = transform.FunctionPass
SequentialPass = ir_pass.SequentialPass Sequential = transform.Sequential
...@@ -17,62 +17,8 @@ ...@@ -17,62 +17,8 @@
import tvm import tvm
from . import ir from . import ir
from .base import NodeBase
from .env import Module from .env import Module
class PassContext(NodeBase):
def __init__(self):
...
class PassInfo(NodeBase):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None
class Pass(NodeBase):
def __init__(self):
...
class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class SequentialPass(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list
def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...
def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ... def check_expr(env: Module, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ... def generalize(env: Module, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI exposing the Relay type inference and checking."""
from tvm._ffi.function import _init_api
_init_api("relay._transform", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from .base import NodeBase
class PassContext(NodeBase):
def __init__(self):
...
class PassInfo(NodeBase):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list
def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None
class Pass(NodeBase):
def __init__(self):
...
class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_func = ... # type: Callable
required = ... # type: list
def __init__(self, name, opt_level, pass_func, required):
# type: (str, int, Callable, list) -> None
...
class Sequential(Pass):
name = ... # type: str
opt_level = ... # type: int
passes = ... # type: list
required = ... # type: list
disabled = ... # type: list
def __init__(self, name, opt_level, passes, required, disabled):
# type: (str, int, list, list, list) -> None
...
...@@ -23,11 +23,11 @@ ...@@ -23,11 +23,11 @@
* \brief Relay pass manager implementation. * \brief Relay pass manager implementation.
*/ */
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pass.h> #include <tvm/relay/transform.h>
namespace tvm { namespace tvm {
namespace relay { namespace relay {
namespace pass { namespace transform {
using tvm::IRPrinter; using tvm::IRPrinter;
...@@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode { ...@@ -169,17 +169,15 @@ class FunctionPassNode : public PassNode {
RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass); RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);
class SequentialPass;
/*! /*!
* \brief The SequentialPassNode contains a set of passes that transform Relay * \brief The SequentialNode contains a set of passes that transform Relay
* programs from one AST to another semantically equivalent one. * programs from one AST to another semantically equivalent one.
* *
* One example of this level of pass is that the pass manager needs to correctly * One example of this level of pass is that the pass manager needs to correctly
* perform a host of optimizations with a given optimization level and disabled * perform a host of optimizations with a given optimization level and disabled
* passes. * passes.
*/ */
class SequentialPassNode : public PassNode { class SequentialNode : public PassNode {
public: public:
/* \brief The pass meta data.*/ /* \brief The pass meta data.*/
PassInfo pass_info; PassInfo pass_info;
...@@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode { ...@@ -212,10 +210,6 @@ class SequentialPassNode : public PassNode {
passes.push_back(pass); passes.push_back(pass);
} }
TVM_DLL static SequentialPass make(tvm::Array<Pass> passes,
PassInfo pass_info,
tvm::Array<tvm::Expr> disabled);
/*! /*!
* \brief Resolve the pass dependency. It globs all required passes by * \brief Resolve the pass dependency. It globs all required passes by
* a given pass and executes them. * a given pass and executes them.
...@@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode { ...@@ -251,8 +245,8 @@ class SequentialPassNode : public PassNode {
*/ */
void SetContext(const PassContext& pass_ctx) final; void SetContext(const PassContext& pass_ctx) final;
static constexpr const char* _type_key = "relay.SequentialPass"; static constexpr const char* _type_key = "relay.Sequential";
TVM_DECLARE_NODE_TYPE_INFO(SequentialPassNode, PassNode); TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
private: private:
/*! /*!
...@@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode { ...@@ -261,8 +255,6 @@ class SequentialPassNode : public PassNode {
PassContext pass_ctx_; PassContext pass_ctx_;
}; };
RELAY_DEFINE_NODE_REF(SequentialPass, SequentialPassNode, Pass);
PassInfo PassInfoNode::make(int opt_level, std::string name, PassInfo PassInfoNode::make(int opt_level, std::string name,
tvm::Array<tvm::Expr> required) { tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>(); auto pass_info = make_node<PassInfoNode>();
...@@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { ...@@ -350,20 +342,24 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {
return pval && pval->value != 0; return pval && pval->value != 0;
} }
SequentialPass SequentialPassNode::make(tvm::Array<Pass> passes, Sequential::Sequential(tvm::Array<Pass> passes,
PassInfo pass_info, PassInfo pass_info,
tvm::Array<tvm::Expr> disabled) { tvm::Array<tvm::Expr> disabled) {
auto n = make_node<SequentialPassNode>(); auto n = make_node<SequentialNode>();
n->passes = std::move(passes); n->passes = std::move(passes);
n->pass_info = std::move(pass_info); n->pass_info = std::move(pass_info);
n->disabled = std::move(disabled); n->disabled = std::move(disabled);
return SequentialPass(n); node_ = std::move(n);
}
const SequentialNode* Sequential::operator->() const {
return static_cast<const SequentialNode*>(this->node_.get());
} }
// TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in // TODO(jroesch, zhiics): we currenlty only sequentially execute each pass in
// a SequentialPass without the consideration of their orders. The phase // a Sequential without the consideration of their orders. The phase
// ordering problem needed to be handled in the future. // ordering problem needed to be handled in the future.
Module SequentialPassNode::operator()(const Module& module) const { Module SequentialNode::operator()(const Module& module) const {
Module mod = module; Module mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
...@@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const { ...@@ -373,7 +369,7 @@ Module SequentialPassNode::operator()(const Module& module) const {
return mod; return mod;
} }
void SequentialPassNode::ResolveDependency(const Module& mod) { void SequentialNode::ResolveDependency(const Module& mod) {
// TODO(zhiics) Implement it. // TODO(zhiics) Implement it.
// 1. Consider the required passes for each pass. // 1. Consider the required passes for each pass.
// 2. Only resolve the enabled passes. // 2. Only resolve the enabled passes.
...@@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) { ...@@ -382,7 +378,7 @@ void SequentialPassNode::ResolveDependency(const Module& mod) {
<< "\n"; << "\n";
} }
std::vector<std::string> SequentialPassNode::DisabledPasses() const { std::vector<std::string> SequentialNode::DisabledPasses() const {
std::vector<std::string> ret; std::vector<std::string> ret;
for (const auto& it : disabled) { for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>(); const auto* str = it.as<tvm::ir::StringImm>();
...@@ -392,7 +388,7 @@ std::vector<std::string> SequentialPassNode::DisabledPasses() const { ...@@ -392,7 +388,7 @@ std::vector<std::string> SequentialPassNode::DisabledPasses() const {
return ret; return ret;
} }
void SequentialPassNode::SetContext(const PassContext& pass_ctx) { void SequentialNode::SetContext(const PassContext& pass_ctx) {
pass_ctx_ = pass_ctx; pass_ctx_ = pass_ctx;
} }
...@@ -414,21 +410,12 @@ Pass CreateFunctionPass( ...@@ -414,21 +410,12 @@ Pass CreateFunctionPass(
return FunctionPassNode::make(pass_func, pass_info); return FunctionPassNode::make(pass_func, pass_info);
} }
Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
int opt_level,
const std::string& name,
const tvm::Array<tvm::Expr>& required,
const tvm::Array<tvm::Expr>& disabled) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return SequentialPassNode::make(passes, pass_info, disabled);
}
TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_API("relay._ir_pass.PassInfo") TVM_REGISTER_API("relay._transform.PassInfo")
.set_body_typed(PassInfoNode::make); .set_body_typed(PassInfoNode::make);
TVM_REGISTER_API("relay._ir_pass.Info") TVM_REGISTER_API("relay._transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
*ret = pass->Info(); *ret = pass->Info();
...@@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -450,10 +437,10 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._ir_pass.CreateModulePass") TVM_REGISTER_API("relay._transform.CreateModulePass")
.set_body_typed(CreateModulePass); .set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._ir_pass.RunPass") TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
Module mod = args[1]; Module mod = args[1];
...@@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -475,7 +462,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") TVM_REGISTER_API("relay._transform.CreateFunctionPass")
.set_body_typed(CreateFunctionPass); .set_body_typed(CreateFunctionPass);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -486,9 +473,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< " at the optimization level " << pn->opt_level; << " at the optimization level " << pn->opt_level;
}); });
TVM_REGISTER_NODE_TYPE(SequentialPassNode); TVM_REGISTER_NODE_TYPE(SequentialNode);
TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") TVM_REGISTER_API("relay._transform.Sequential")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
tvm::Array<Pass> passes = args[0]; tvm::Array<Pass> passes = args[0];
int opt_level = args[1]; int opt_level = args[1];
...@@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass") ...@@ -496,14 +483,14 @@ TVM_REGISTER_API("relay._ir_pass.CreateSequentialPass")
tvm::Array<tvm::Expr> required = args[3]; tvm::Array<tvm::Expr> required = args[3];
tvm::Array<tvm::Expr> disabled = args[4]; tvm::Array<tvm::Expr> disabled = args[4];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required); PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
*ret = SequentialPassNode::make(passes, pass_info, disabled); *ret = Sequential(passes, pass_info, disabled);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SequentialPassNode>([](const SequentialPassNode* node, .set_dispatch<SequentialNode>([](const SequentialNode* node,
tvm::IRPrinter* p) { tvm::IRPrinter* p) {
const PassInfoNode* seq_pn = node->Info().operator->(); const PassInfoNode* seq_pn = node->Info().operator->();
p->stream << "Run SequentialPass pass: " << seq_pn->name p->stream << "Run Sequential pass: " << seq_pn->name
<< " at the optimization level. " << seq_pn->opt_level; << " at the optimization level. " << seq_pn->opt_level;
p->stream << "The passes will be executed are: ["; p->stream << "The passes will be executed are: [";
for (const auto& it : node->passes) { for (const auto& it : node->passes) {
...@@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -514,7 +501,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "]"; p->stream << "]";
}); });
TVM_REGISTER_API("relay._ir_pass.SetContext") TVM_REGISTER_API("relay._transform.SetContext")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; Pass pass = args[0];
PassContext pass_ctx = args[1]; PassContext pass_ctx = args[1];
...@@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext") ...@@ -523,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext")
TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._ir_pass.PassContext") TVM_REGISTER_API("relay._transform.PassContext")
.set_body_typed(PassContextNode::make); .set_body_typed(PassContextNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -534,6 +521,6 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< "\n"; << "\n";
}); });
} // namespace pass } // namespace transform
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -22,6 +22,7 @@ from tvm import relay ...@@ -22,6 +22,7 @@ from tvm import relay
from tvm.relay import ExprFunctor from tvm.relay import ExprFunctor
from tvm.relay import Function, Call from tvm.relay import Function, Call
from tvm.relay import ir_pass from tvm.relay import ir_pass
from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
...@@ -126,13 +127,13 @@ def test_module_pass(): ...@@ -126,13 +127,13 @@ def test_module_pass():
opt_tester = OptTester(mod) opt_tester = OptTester(mod)
pass_ctx = None pass_ctx = None
@ir_pass.module_pass(opt_level=opt_level, name=pass_name) @_transform.module_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx): def transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
def test_pass_registration(): def test_pass_registration():
mod_pass = transform mod_pass = transform
assert isinstance(mod_pass, ir_pass.ModulePass) assert isinstance(mod_pass, _transform.ModulePass)
pass_info = mod_pass.info pass_info = mod_pass.info
assert pass_info.name == pass_name assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level assert pass_info.opt_level == opt_level
...@@ -140,8 +141,8 @@ def test_module_pass(): ...@@ -140,8 +141,8 @@ def test_module_pass():
def test_pass_registration_no_decorator(): def test_pass_registration_no_decorator():
def direct_transform(expr, ctx): def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
mod_pass = ir_pass.module_pass(direct_transform, opt_level=3) mod_pass = _transform.module_pass(direct_transform, opt_level=3)
assert isinstance(mod_pass, ir_pass.ModulePass) assert isinstance(mod_pass, _transform.ModulePass)
pass_info = mod_pass.info pass_info = mod_pass.info
assert pass_info.name == "direct_transform" assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 3 assert pass_info.opt_level == 3
...@@ -202,7 +203,7 @@ def test_function_pass(): ...@@ -202,7 +203,7 @@ def test_function_pass():
opt_tester = OptTester(mod) opt_tester = OptTester(mod)
pass_ctx = None pass_ctx = None
@ir_pass.function_pass(opt_level=opt_level, name=pass_name) @_transform.function_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx): def transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
...@@ -212,7 +213,7 @@ def test_function_pass(): ...@@ -212,7 +213,7 @@ def test_function_pass():
def test_pass_registration(): def test_pass_registration():
function_pass = transform function_pass = transform
assert isinstance(function_pass, ir_pass.FunctionPass) assert isinstance(function_pass, _transform.FunctionPass)
pass_info = function_pass.info pass_info = function_pass.info
assert pass_info.name == pass_name assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level assert pass_info.opt_level == opt_level
...@@ -220,8 +221,8 @@ def test_function_pass(): ...@@ -220,8 +221,8 @@ def test_function_pass():
def test_pass_registration_no_decorator(): def test_pass_registration_no_decorator():
def direct_transform(expr, ctx): def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
mod_pass = ir_pass.function_pass(direct_transform, opt_level=0) mod_pass = _transform.function_pass(direct_transform, opt_level=0)
assert isinstance(mod_pass, ir_pass.FunctionPass) assert isinstance(mod_pass, _transform.FunctionPass)
pass_info = mod_pass.info pass_info = mod_pass.info
assert pass_info.name == "direct_transform" assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 0 assert pass_info.opt_level == 0
...@@ -294,14 +295,14 @@ def test_sequential_pass(): ...@@ -294,14 +295,14 @@ def test_sequential_pass():
opt_tester = OptTester(mod) opt_tester = OptTester(mod)
pass_ctx = None pass_ctx = None
@ir_pass.module_pass(opt_level=1) @_transform.module_pass(opt_level=1)
def mod_transform(expr, ctx): def mod_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
module_pass = mod_transform module_pass = mod_transform
# Register a function pass. # Register a function pass.
@ir_pass.function_pass(opt_level=1) @_transform.function_pass(opt_level=1)
def func_transform(expr, ctx): def func_transform(expr, ctx):
return opt_tester.transform(expr, ctx) return opt_tester.transform(expr, ctx)
...@@ -310,25 +311,23 @@ def test_sequential_pass(): ...@@ -310,25 +311,23 @@ def test_sequential_pass():
def test_pass_registration(): def test_pass_registration():
passes = [module_pass, function_pass] passes = [module_pass, function_pass]
opt_level = 2 opt_level = 2
pass_name = "sequential_pass" pass_name = "sequential"
sequential_pass = ir_pass.sequential_pass(passes=passes, sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
opt_level=opt_level) pass_info = sequential.info
assert isinstance(sequential_pass, ir_pass.SequentialPass)
pass_info = sequential_pass.info
assert pass_info.name == pass_name assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level assert pass_info.opt_level == opt_level
def test_no_pass(): def test_no_pass():
passes = [] passes = []
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
mod_func = ret_mod[v_sub] mod_func = ret_mod[v_sub]
check_func(sub, mod_func) check_func(sub, mod_func)
def test_only_module_pass(): def test_only_module_pass():
passes = [module_pass] passes = [module_pass]
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
# Check the subtract function. # Check the subtract function.
sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, sub) check_func(new_sub, sub)
...@@ -341,8 +340,8 @@ def test_sequential_pass(): ...@@ -341,8 +340,8 @@ def test_sequential_pass():
def test_only_function_pass(): def test_only_function_pass():
# Check the subtract function. # Check the subtract function.
passes = [function_pass] passes = [function_pass]
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
check_func(new_sub, get_ref_sub()) check_func(new_sub, get_ref_sub())
...@@ -355,8 +354,8 @@ def test_sequential_pass(): ...@@ -355,8 +354,8 @@ def test_sequential_pass():
# function pass. # function pass.
mod = relay.Module({v_sub: sub, v_log: log}) mod = relay.Module({v_sub: sub, v_log: log})
passes = [module_pass, function_pass] passes = [module_pass, function_pass]
sequential_pass = ir_pass.sequential_pass(opt_level=1, passes=passes) sequential = _transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential_pass(mod) ret_mod = sequential(mod)
# Check the abs function is added. # Check the abs function is added.
abs_var, abs_func = get_var_func() abs_var, abs_func = get_var_func()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment