Unverified Commit 953ca1f6 by Tianqi Chen Committed by GitHub

[C++] Cleanup transform API nits (#3253)

parent a8275bdb
......@@ -76,8 +76,8 @@ namespace transform {
class PassContext;
/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
* \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results.
*/
class PassContextNode : public RelayNode {
public:
......@@ -110,32 +110,51 @@ class PassContextNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};
/*!
* \brief PassContext that is used to configure the pass behavior.
*
* \code
*
* auto new_ctx = PassContext::Create();
* ctx->opt_level = 2;
* ctx->fallback_device = kDLCPU;
* With<PassContext> scope(ctx);
* // pass context in effect.
*
* \endcode
*/
class PassContext : public NodeRef {
public:
PassContext() {}
explicit PassContext(tvm::NodePtr<Node> n) : NodeRef(n) {}
/*
* \brief Constructor of a `PassContext` object.
*
* \param opt_level The optimization level that will be applied.
* \param fallback_device The fallback device used for heterogeneous
* execution.
* \param required_pass The passes that are required for a context to execute
* other passes.
* \param required_pass The passes that will be disabled during the
* optimization under a context.
explicit PassContext(NodePtr<::tvm::Node> n) : NodeRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
*/
const PassContextNode* operator->() const {
CHECK(node_.get() != nullptr);
return static_cast<const PassContextNode*>(node_.get());
}
/*!
* \brief mutable accessor.
* \return mutable access pointer.
*/
PassContextNode* operator->() {
CHECK(node_.get() != nullptr);
return static_cast<PassContextNode*>(node_.get());
}
/*!
* \brief Construct a PassContext containing the default configurations.
* \return The new PassContext.
*/
TVM_DLL static PassContext Create();
/*!
* \brief Get the default pass context in the current scope.
* \return The pass context.
*/
TVM_DLL PassContext(int opt_level,
int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass);
// Get the currently used pass context.
TVM_DLL static PassContext Current();
const PassContextNode* operator->() const;
// accessor.
using ContainerType = PassContextNode;
class Internal;
......@@ -204,25 +223,23 @@ class PassNode : public RelayNode {
virtual PassInfo Info() const = 0;
/*!
* \brief Execute the optimization pass using a functor. This functor
* internally uses a current pass context.
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
* \return The transformed module.
*/
Module operator()(const Module& mod) const {
return this->operator()(mod, PassContext::Current());
}
/*!
* \brief Execute the optimization pass using a functor under a given pass context.
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that will be used to help the execution of
* optimizations.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The updated module.
* \return The transformed module.
*/
virtual Module operator()(const Module& mod,
const PassContext& pass_ctx) const = 0;
......@@ -235,14 +252,34 @@ class PassNode : public RelayNode {
class Pass : public NodeRef {
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}
PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
Module operator()(const Module& mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
}
/*!
* \brief Transform mod using a functor under a given pass context.
*
* \param mod The module that an optimization pass runs on.
* \param pass_ctx The pass context that can provide information for the optimization.
*
* \return The transformed module.
*/
Module operator()(const Module& mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
}
using ContainerType = PassNode;
TVM_DEFINE_NODE_REF_METHODS(Pass, NodeRef, PassNode);
};
class SequentialNode;
......
......@@ -74,21 +74,6 @@ class OptPassLevel {
}
};
PassContext::PassContext(int opt_level, int fallback_device,
tvm::Array<tvm::Expr> required_pass,
tvm::Array<tvm::Expr> disabled_pass) {
auto ctx = make_node<PassContextNode>();
ctx->opt_level = opt_level;
ctx->fallback_device = fallback_device;
ctx->required_pass = std::move(required_pass);
ctx->disabled_pass = std::move(disabled_pass);
node_ = std::move(ctx);
}
const PassContextNode* PassContext::operator->() const {
return static_cast<const PassContextNode*>(node_.get());
}
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
......@@ -129,6 +114,10 @@ PassContext PassContext::Current() {
}
}
PassContext PassContext::Create() {
return PassContext(make_node<PassContextNode>());
}
class ModulePass;
/*!
......@@ -291,7 +280,7 @@ class SequentialNode : public PassNode {
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool pass_enabled(const std::string& pass_name) const;
bool PassEnabled(const std::string& pass_name) const;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
......@@ -353,9 +342,8 @@ ModulePass ModulePassNode::make(
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing module pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
CHECK(mod.defined());
auto updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
......@@ -376,11 +364,10 @@ FunctionPass FunctionPassNode::make(
Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
LOG(INFO) << "Executing function pass : " << pass_info.operator->()->name
<< " with opt level: " << pass_info.operator->()->opt_level << "\n";
CHECK(mod.defined());
Module new_mod = ModuleNode::make({}, mod->type_definitions);
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, mod, pass_ctx);
......@@ -448,12 +435,11 @@ std::unordered_set<std::string> SequentialNode::RequiredPasses(
return ret;
}
bool SequentialNode::pass_enabled(const std::string& pass_name) const {
bool SequentialNode::PassEnabled(const std::string& pass_name) const {
PassContext ctx = PassContext::Current();
const PassContextNode* ctx_node = ctx.operator->();
auto required = RequiredPasses(ctx_node->required_pass);
auto disabled = DisabledPasses(ctx_node->required_pass);
auto required = RequiredPasses(ctx->required_pass);
auto disabled = DisabledPasses(ctx->required_pass);
if (disabled.count(pass_name)) {
return false;
......@@ -462,7 +448,7 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
if (required.count(pass_name)) {
return true;
}
return ctx_node->opt_level >= opt_pass_level[pass_name];
return ctx->opt_level >= opt_pass_level[pass_name];
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
......@@ -470,15 +456,14 @@ bool SequentialNode::pass_enabled(const std::string& pass_name) const {
// ordering problem needed to be handled in the future.
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
const auto* ctx_node = pass_ctx.operator->();
int opt_level = ctx_node->opt_level;
auto disabled = DisabledPasses(ctx_node->disabled_pass);
int opt_level = pass_ctx->opt_level;
auto disabled = DisabledPasses(pass_ctx->disabled_pass);
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
PassInfo info = pass->Info();
const auto& pass_name = info.operator->()->name;
const auto& pass_opt_level = info.operator->()->opt_level;
const auto& pass_name = info->name;
const auto& pass_opt_level = info->opt_level;
// Skip the pass if its optimization level is higher that the one of in the
// pass context or if this pass is disabled.
if (pass_opt_level > opt_level || disabled.count(pass_name)) {
......@@ -540,14 +525,7 @@ TVM_REGISTER_API("relay._transform.CreateModulePass")
TVM_REGISTER_API("relay._transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
Module mod = args[1];
CHECK(pass.defined())
<< "Running an undefined pass is not allowed."
<< "\n";
const auto* pn = pass.operator->();
*ret = (*pn)(mod);
*ret = args[0].operator Pass()(args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......@@ -602,11 +580,16 @@ TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._transform.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto pctx = PassContext::Create();
int opt_level = args[0];
int fallback_device = args[1];
tvm::Array<tvm::Expr> required = args[2];
tvm::Array<tvm::Expr> disabled = args[3];
*ret = PassContext(opt_level, fallback_device, required, disabled);
pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled);
*ret = pctx;
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment