Unverified Commit d7d2a9b3 by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Initialize Unified IR Pass Infra. (#4702)

Move the relay's pass Infra to ir.
Keep FunctionPass in relay as it is local to the dialect.
parent edc3674d
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/transform.h
*
* This file implements a pass manager. The pass manager manages a sequence
* of IRModule -> IRModule transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the IRModule -> IRModule transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/
#ifndef TVM_IR_TRANSFORM_H_
#define TVM_IR_TRANSFORM_H_
#include <tvm/base.h>
#include <tvm/node/container.h>
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <string>
namespace tvm {
namespace transform {
/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
* \sa PassContext
*/
class PassContextNode : public Object {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;
/*! \brief The default optimization level. */
int opt_level{2};
/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
tvm::Array<tvm::PrimExpr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::PrimExpr> disabled_pass;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
}
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \code
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
* \endcode
* \sa PassContextNode
*/
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
*/
const PassContextNode* operator->() const {
CHECK(get() != nullptr);
return static_cast<const PassContextNode*>(get());
}
/*!
* \brief mutable accessor.
* \return mutable access pointer.
*/
PassContextNode* operator->() {
CHECK(get() != nullptr);
return static_cast<PassContextNode*>(get_mutable());
}
/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL static PassContext Create();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
*/
TVM_DLL static PassContext Current();
// accessor.
using ContainerType = PassContextNode;
class Internal;
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
};
/*!
* \brief Meta data that will be used to help optimization and analysis.
* \sa PassInfo
*/
class PassInfoNode : public Object {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::PrimExpr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};
/*
* \brief Managed reference class for PassInfoNode
* \sa PassInfoNode
*/
class PassInfo : public ObjectRef {
public:
/*!
* \brief Constructor
* \param opt_level The optimization level
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
*/
TVM_DLL PassInfo(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required);
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public Object {
public:
virtual ~PassNode() {}
/*!
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current());
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
};
class Pass : public ObjectRef {
public:
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
}
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
};
class SequentialNode;
class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
Sequential() = default;
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = Sequential;
};
/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
} // namespace transform
} // namespace tvm
#endif // TVM_IR_TRANSFORM_H_
......@@ -19,320 +19,31 @@
/*!
* \file tvm/relay/transform.h
*
* This file implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager and modern deep learning
* frameworks that perform tensor->tensor transformations.
*
* The responsibilities of a traditional compiler pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* Similar to LLVM's pass manager, we designed the Relay pass manager to work
* different granularity, i.e. module level, function level, and even sequential
* passe that contains a host of passes.
*
* However, we also extend the functionality of the traditional pass manager
* with the consideration of requirements/convention from deep learning
* frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass
* manager performs the Relay.Module -> Relay.Module transformation. All
* different types of passes, including the sequential-level pass object, are
* essentially pass objects. This design, therefore, effectively provides users
* a consistent and convenient interface, i.e. Pass, to play with. It offers a
* means to ease the development and testing of Relay passes. For example, with
* the pass manager, external users will be able to have custom passes correctly
* scheduled without having to modify a single handcrafted pass order.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
* \brief Relay specific transformation passes.
*/
#ifndef TVM_RELAY_TRANSFORM_H_
#define TVM_RELAY_TRANSFORM_H_
#include <tvm/base.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/error.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace relay {
namespace transform {
/*
* \brief The context of pass.
*/
class PassContext;
/*!
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter;
/*! \brief The default optimization level. */
int opt_level{2};
/*! \brief CPU is the default fallback device for heterogeneous execution. */
int fallback_device{static_cast<int>(kDLCPU)};
/*! \brief The list of required passes. */
tvm::Array<tvm::PrimExpr> required_pass;
/*! \brief The list of disabled passes. */
tvm::Array<tvm::PrimExpr> disabled_pass;
PassContextNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("fallback_device", &fallback_device);
v->Visit("required_pass", &required_pass);
v->Visit("disabled_pass", &disabled_pass);
}
static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, RelayNode);
};
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \code
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
* \endcode
*/
class PassContext : public ObjectRef {
public:
PassContext() {}
explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
*/
const PassContextNode* operator->() const {
CHECK(get() != nullptr);
return static_cast<const PassContextNode*>(get());
}
/*!
* \brief mutable accessor.
* \return mutable access pointer.
*/
PassContextNode* operator->() {
CHECK(get() != nullptr);
return static_cast<PassContextNode*>(get_mutable());
}
/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL static PassContext Create();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
*/
TVM_DLL static PassContext Current();
// accessor.
using ContainerType = PassContextNode;
class Internal;
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class Internal;
friend class tvm::With<PassContext>;
};
/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;
/*!
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassInfoNode : public RelayNode {
public:
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::PrimExpr> required;
PassInfoNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("opt_level", &opt_level);
v->Visit("name", &name);
v->Visit("required", &required);
}
TVM_DLL static PassInfo make(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required);
static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode);
};
class PassInfo : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
class Pass;
/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
virtual ~PassNode() {}
/*!
* \brief Get the pass information/meta data. */
virtual PassInfo Info() const = 0;
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current());
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
virtual IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, RelayNode);
};
class Pass : public ObjectRef {
public:
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
}
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
};
class SequentialNode;
class Sequential : public Pass {
public:
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param pass_info The pass metadata.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
/*!
* \brief The constructor of `Sequential`.
*
* \param passes The passes to apply.
* \param name The name of a sequential pass. It's defaulted to "sequential".
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(tvm::Array<Pass> passes, std::string name = "sequential");
Sequential() = default;
explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = Sequential;
};
/*
* \brief Create a module pass.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the module pass.
* \param name The name of the module pass.
* \param required The list of the passes that the module pass is dependent on.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
using Pass = tvm::transform::Pass;
using PassNode = tvm::transform::PassNode;
using PassInfo = tvm::transform::PassInfo;
using PassInfoNode = tvm::transform::PassInfoNode;
using PassContext = tvm::transform::PassContext;
using PassContextNode = tvm::transform::PassContextNode;
using Sequential = tvm::transform::Sequential;
/*
* \brief Create a function pass.
......
......@@ -18,48 +18,52 @@
*/
/*!
* \file src/relay/pass/pass_manager.cc
* \brief Relay pass manager implementation.
* \file src/ir/transform.cc
* \brief Infrastructure for transformation passes.
*/
#include <dmlc/thread_local.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/node/printer.h>
#include <tvm/ir/transform.h>
// TODO(tqchen): Update to use String container after it is merged.
#include <tvm/ir.h>
#include <algorithm>
#include <stack>
#include <unordered_set>
namespace tvm {
namespace relay {
namespace transform {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::NodePrinter;
struct RelayPassContextThreadLocalEntry {
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
RelayPassContextThreadLocalEntry() {
PassContextThreadLocalEntry() {
default_context = PassContext(make_object<PassContextNode>());
}
};
/*! \brief Thread local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
RelayPassContextThreadLocalStore;
void PassContext::EnterWithScope() {
RelayPassContextThreadLocalEntry* entry =
PassContextThreadLocalEntry* entry =
RelayPassContextThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void PassContext::ExitWithScope() {
RelayPassContextThreadLocalEntry* entry =
PassContextThreadLocalEntry* entry =
RelayPassContextThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
......@@ -67,7 +71,7 @@ void PassContext::ExitWithScope() {
}
PassContext PassContext::Current() {
RelayPassContextThreadLocalEntry* entry =
PassContextThreadLocalEntry* entry =
RelayPassContextThreadLocalStore::Get();
if (!entry->context_stack.empty()) {
return entry->context_stack.top();
......@@ -121,84 +125,16 @@ class ModulePassNode : public PassNode {
*/
PassInfo Info() const override { return pass_info; }
TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};
class ModulePass : public Pass {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
};
class FunctionPass;
/*!
* \brief Function-level passes are used to implement various global
* optimizations for a given Relay module. It fetches one function at a time
* from the function list in the module for optimization.
*
* Note that the scope of passes at this level is a Relay function. Therefore,
* we cannot add or delete a function through these passes as they are not aware
* of the global information.
*/
class FunctionPassNode : public PassNode {
public:
/* \brief The pass meta data.*/
PassInfo pass_info;
/*! \brief The packed pass function sketches the real optimization. For
* instance, we can implement a pass that works on a Relay function as a
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
FunctionPassNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
}
/*!
* \brief Run a function pass on given pass context.
*
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
ModulePass(runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.FunctionPass";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
private:
/*
* \brief Check if a function should be skipped for optimization.
*
* \param func The target function to be checked.
*
* \return Return true if the function will be skipped, otherwise false.
*/
bool SkipFunction(const Function& func) const;
};
class FunctionPass : public Pass {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode);
};
/*!
......@@ -267,23 +203,23 @@ class SequentialNode : public PassNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
PassInfo PassInfoNode::make(int opt_level,
PassInfo::PassInfo(int opt_level,
std::string name,
tvm::Array<tvm::PrimExpr> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
pass_info->required = std::move(required);
return PassInfo(pass_info);
data_ = std::move(pass_info);
}
ModulePass ModulePassNode::make(
ModulePass::ModulePass(
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<ModulePassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
return ModulePass(n);
data_ = std::move(n);
}
// Module -> Module optimizations.
......@@ -300,51 +236,6 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
return updated_mod;
}
FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<FunctionPassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
return FunctionPass(n);
}
// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
DLOG(INFO) << "Executing function pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
// only picks up relay::Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
auto updated_func = SkipFunction(func)
? func
: pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
return updated_mod;
}
bool FunctionPassNode::SkipFunction(const Function& func) const {
ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
const ir::IntImmNode* pval = skip_opt.as<ir::IntImmNode>();
return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
}
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
auto n = make_object<SequentialNode>();
n->passes = std::move(passes);
......@@ -355,7 +246,7 @@ Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
auto n = make_object<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfoNode::make(2, std::move(name), {});
PassInfo pass_info = PassInfo(2, std::move(name), {});
n->pass_info = std::move(pass_info);
data_ = std::move(n);
}
......@@ -433,23 +324,16 @@ Pass CreateModulePass(
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return ModulePassNode::make(pass_func, pass_info);
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
return FunctionPassNode::make(pass_func, pass_info);
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_GLOBAL("relay._transform.PassInfo")
.set_body_typed(PassInfoNode::make);
.set_body_typed([](int opt_level, std::string name, tvm::Array<PrimExpr> required) {
return PassInfo(opt_level, name, required);
});
TVM_REGISTER_GLOBAL("relay._transform.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -474,7 +358,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass")
.set_body_typed(ModulePassNode::make);
.set_body_typed(
[](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) {
return ModulePass(pass_func, pass_info);
});
TVM_REGISTER_GLOBAL("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -491,19 +379,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
<< " at the optimization level " << info->opt_level;
});
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FunctionPassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Function pass: " << info->name
<< " at the optimization level " << info->opt_level;
});
TVM_REGISTER_NODE_TYPE(SequentialNode);
TVM_REGISTER_GLOBAL("relay._transform.Sequential")
......@@ -512,7 +387,7 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential")
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::PrimExpr> required = args[3];
PassInfo pass_info = PassInfoNode::make(opt_level, name, required);
PassInfo pass_info = PassInfo(opt_level, name, required);
*ret = Sequential(passes, pass_info);
});
......@@ -589,5 +464,4 @@ TVM_REGISTER_GLOBAL("relay._transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);
} // namespace transform
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file relay/ir/transform.cc
* \brief Relay specific transformation passes.
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/printer.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
namespace transform {
class FunctionPass;
/*!
* \brief Function-level passes are used to implement various global
* optimizations for a given Relay module. It fetches one function at a time
* from the function list in the module for optimization.
*
* Note that the scope of passes at this level is a Relay function. Therefore,
* we cannot add or delete a function through these passes as they are not aware
* of the global information.
*/
class FunctionPassNode : public PassNode {
public:
/* \brief The pass meta data.*/
PassInfo pass_info;
/*! \brief The packed pass function sketches the real optimization. For
* instance, we can implement a pass that works on a Relay function as a
* `pass_func` and let it run on a given module. The same `pass_func` will
* then be applied on each function in the module.
*/
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func;
FunctionPassNode() = default;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
}
/*!
* \brief Run a function pass on given pass context.
*
* \param mod The module that an optimization pass is applied on.
* \param mod The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.FunctionPass";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
private:
/*
* \brief Check if a function should be skipped for optimization.
*
* \param func The target function to be checked.
*
* \return Return true if the function will be skipped, otherwise false.
*/
bool SkipFunction(const Function& func) const;
};
class FunctionPass : public Pass {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
};
FunctionPass FunctionPassNode::make(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<FunctionPassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
return FunctionPass(n);
}
// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
DLOG(INFO) << "Executing function pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
// only picks up relay::Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
auto updated_func = SkipFunction(func)
? func
: pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
return updated_mod;
}
bool FunctionPassNode::SkipFunction(const Function& func) const {
ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
const ir::IntImmNode* pval = skip_opt.as<ir::IntImmNode>();
return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
}
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPassNode::make(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const FunctionPassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "Run Function pass: " << info->name
<< " at the optimization level " << info->opt_level;
});
} // namespace transform
} // namespace relay
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment