Unverified Commit 953ca1f6 by Tianqi Chen Committed by GitHub

[C++] Cleanup transform API nits (#3253)

parent a8275bdb
...@@ -76,8 +76,8 @@ namespace transform { ...@@ -76,8 +76,8 @@ namespace transform {
class PassContext; class PassContext;
/*! /*!
* \brief PassContextNode contains the information that a pass can rely on, such as * \brief PassContextNode contains the information that a pass can rely on,
* analysis results. * such as analysis results.
*/ */
class PassContextNode : public RelayNode { class PassContextNode : public RelayNode {
public: public:
...@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode { ...@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode); TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
}; };
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \code
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
* \endcode
*/
class PassContext : public NodeRef { class PassContext : public NodeRef {
public: public:
PassContext() {} PassContext() {}
explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {} explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {}
/*!
/* * \brief const accessor.
* \brief Constructor of a `PassContext` object. * \return const access pointer.
* */
* \param opt_level The optimization level that will be applied. const PassContextNode* operator->() const {
* \param fallback_device The fallback device used for heterogeneous CHECK(node_.get() != nullptr);
* execution. return static_cast<const PassContextNode*>(node_.get());
* \param required_pass The passes that are required for a context to execute }
* other passes. /*!
* \param required_pass The passes that will be disabled during the * \brief mutable accessor.
* optimization under a context. * \return mutable access pointer.
*/
PassContextNode* operator->() {
CHECK(node_.get() != nullptr);
return static_cast<PassContextNode*>(node_.get());
}
/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL static PassContext Create();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
*/ */
TVM_DLL PassContext(int opt_level,
int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass);
// Get the currently used pass context.
TVM_DLL static PassContext Current(); TVM_DLL static PassContext Current();
const PassContextNode* operator->() const; // accessor.
using ContainerType = PassContextNode; using ContainerType = PassContextNode;
class Internal; class Internal;
...@@ -204,25 +223,23 @@ class PassNode : public RelayNode { ...@@ -204,25 +223,23 @@ class PassNode : public RelayNode {
virtual PassInfo Info() const = 0; virtual PassInfo Info() const = 0;
/*! /*!
* \brief Execute the optimization pass using a functor. This functor * \brief Transform mod using the default PassContext in the current scope.
* internally uses a current pass context.
* *
* \param mod The module that an optimization pass runs on. * \param mod The module that an optimization pass runs on.
* *
* \return The updated module. * \return The transformed module.
*/ */
Module operator()(const Module& mod) const { Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current()); return this->operator()(mod, PassContext::Current());
} }
/*! /*!
* \brief Execute the optimization pass using a functor under a given pass context. * \brief Transform mod using a functor under a given pass context.
* *
* \param mod The module that an optimization pass runs on. * \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of * \param pass_ctx The pass context that can provide information for the optimization.
* optimizations.
* *
* \return The updated module. * \return The transformed module.
*/ */
virtual Module operator()(const Module& mod, virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0; const PassContext& pass_ctx) const = 0;
...@@ -235,14 +252,34 @@ class PassNode : public RelayNode { ...@@ -235,14 +252,34 @@ class PassNode : public RelayNode {
class Pass : public NodeRef { class Pass : public NodeRef {
public: public:
Pass() = default; /*!
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {} * \brief Transform mod using the default PassContext in the current scope.
*
PassNode* operator->() const { * \param mod The module that an optimization pass runs on.
return static_cast<PassNode*>(this->node_.get()); *
* \return The transformed module.
*/
Module operator()(const Module& mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
Module operator()(const Module& mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
} }
using ContainerType = PassNode; TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode);
}; };
class SequentialNode; class SequentialNode;
......
...@@ -74,21 +74,6 @@ class OptPassLevel { ...@@ -74,21 +74,6 @@ class OptPassLevel {
} }
}; };
PassContext::PassContext(int opt_level, int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass) {
auto ctx = make_node<PassContextNode>();
ctx->opt_level = opt_level;
ctx->fallback_device = fallback_device;
ctx->required_pass = std::move(required_pass);
ctx->disabled_pass = std::move(disabled_pass);
node_ = std::move(ctx);
}
const PassContextNode* PassContext::operator->() const {
return static_cast<const PassContextNode*>(node_.get());
}
struct RelayPassContextThreadLocalEntry { struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */ /*! \brief The default pass context. */
PassContext default_context; PassContext default_context;
...@@ -129,6 +114,10 @@ PassContext PassContext::Current() { ...@@ -129,6 +114,10 @@ PassContext PassContext::Current() {
} }
} }
PassContext PassContext::Create() {
return PassContext(make_node<PassContextNode>());
}
class ModulePass; class ModulePass;
/*! /*!
...@@ -291,7 +280,7 @@ class SequentialNode : public PassNode { ...@@ -291,7 +280,7 @@ class SequentialNode : public PassNode {
* *
* \return true if the pass is enabled. Otherwise, false. * \return true if the pass is enabled. Otherwise, false.
*/ */
bool pass_enabled(const std::string& pass_name) const; bool PassEnabled(const std::string& pass_name) const;
/*! /*!
* \brief Resolve the pass dependency. It globs all required passes by * \brief Resolve the pass dependency. It globs all required passes by
...@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make( ...@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make(
Module ModulePassNode::operator()(const Module& mod, Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
PassInfo pass_info = Info(); PassInfo pass_info = Info();
LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n"; << " with opt level: " << pass_info->opt_level << "\n";
CHECK(mod.defined()); CHECK(mod.defined());
auto updated_mod = pass_func(mod, pass_ctx); auto updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined()); CHECK(updated_mod.defined());
...@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make( ...@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make(
Module FunctionPassNode::operator()(const Module& mod, Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
PassInfo pass_info = Info(); PassInfo pass_info = Info();
LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined()); CHECK(mod.defined());
Module new_mod = ModuleNode::make({}, mod->type_definitions); Module new_mod = ModuleNode::make({}, mod->type_definitions);
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx); auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
...@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses( ...@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
return ret; return ret;
} }
bool SequentialNode::pass_enabled(const std::string& pass_name) const { bool SequentialNode::PassEnabled(const std::string& pass_name) const {
PassContext ctx = PassContext::Current(); PassContext ctx = PassContext::Current();
const PassContextNode* ctx_node = ctx.operator->(); auto required = RequiredPasses(ctx->required_pass);
auto required = RequiredPasses(ctx_node->required_pass); auto disabled = DisabledPasses(ctx->required_pass);
auto disabled = DisabledPasses(ctx_node->required_pass);
if (disabled.count(pass_name)) { if (disabled.count(pass_name)) {
return false; return false;
...@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const { ...@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
if (required.count(pass_name)) { if (required.count(pass_name)) {
return true; return true;
} }
return ctx_node->opt_level >= opt_pass_level[pass_name]; return ctx->opt_level >= opt_pass_level[pass_name];
} }
// TODO(zhiics): we currenlty only sequentially execute each pass in // TODO(zhiics): we currenlty only sequentially execute each pass in
...@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const { ...@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
// ordering problem needed to be handled in the future. // ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module, Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const auto* ctx_node = pass_ctx.operator->(); int opt_level = pass_ctx->opt_level;
int opt_level = ctx_node->opt_level; auto disabled = DisabledPasses(pass_ctx->disabled_pass);
auto disabled = DisabledPasses(ctx_node->disabled_pass);
Module mod = module; Module mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
PassInfo info = pass->Info(); PassInfo info = pass->Info();
const auto& pass_name = info.operator->()->name; const auto& pass_name = info->name;
const auto& pass_opt_level = info.operator->()->opt_level; const auto& pass_opt_level = info->opt_level;
// Skip the pass if its optimization level is higher that the one of in the // Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled. // pass context or if this pass is disabled.
if (pass_opt_level > opt_level || disabled.count(pass_name)) { if (pass_opt_level > opt_level || disabled.count(pass_name)) {
...@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass") ...@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
TVM_REGISTER_API("relay._transform.RunPass") TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0]; *ret = args[0].operator Pass()(args[1]);
Module mod = args[1];
CHECK(pass.defined())
<< "Running an undefined pass is not allowed."
<< "\n";
const auto* pn = pass.operator->();
*ret = (*pn)(mod);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode); ...@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._transform.PassContext") TVM_REGISTER_API("relay._transform.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
auto pctx = PassContext::Create();
int opt_level = args[0]; int opt_level = args[0];
int fallback_device = args[1]; int fallback_device = args[1];
tvm::Array<tvm::Expr> required = args[2]; tvm::Array<tvm::Expr> required = args[2];
tvm::Array<tvm::Expr> disabled = args[3]; tvm::Array<tvm::Expr> disabled = args[3];
*ret = PassContext(opt_level, fallback_device, required, disabled); pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
*ret = pctx;
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment