Unverified Commit 4cee98ba by Tianqi Chen Committed by GitHub

[PASS][RELAY] polish pass infra (#3319)

parent ca017a38
Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
...@@ -37,47 +37,6 @@ namespace transform { ...@@ -37,47 +37,6 @@ namespace transform {
using tvm::IRPrinter; using tvm::IRPrinter;
namespace {
// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
// handled because we need to register the pass for Python invocation anyway.
Pass GetPass(const std::string& pass_name) {
if (pass_name == "InferType") {
return InferType();
} else if (pass_name == "AlterOpLayout") {
return AlterOpLayout();
} else if (pass_name == "CanonicalizeOps") {
return CanonicalizeOps();
} else if (pass_name == "CombineParallelConv2d") {
return CombineParallelConv2D();
} else if (pass_name == "DeadCodeElimination") {
return DeadCodeElimination();
} else if (pass_name == "EliminateCommonSubexpr") {
return DeadCodeElimination();
} else if (pass_name == "FoldConstant") {
return FoldConstant();
} else if (pass_name == "BackwardFoldScaleAxis") {
return FoldScaleAxis();
} else if (pass_name == "ForwardFoldScaleAxis") {
return FoldScaleAxis();
} else if (pass_name == "FoldScaleAxis") {
return FoldScaleAxis();
} else if (pass_name == "PartialEvaluate") {
return SimplifyInference();
} else if (pass_name == "SimplifyInference") {
return SimplifyInference();
} else if (pass_name == "ToANormalForm") {
return ToANormalForm();
} else if (pass_name == "ToGraphNormalForm") {
return ToGraphNormalForm();
} else {
LOG(FATAL) << pass_name << " has not been registered yet." << "\n";
return Pass(nullptr);
}
}
} // namespace
struct RelayPassContextThreadLocalEntry { struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */ /*! \brief The default pass context. */
PassContext default_context; PassContext default_context;
...@@ -252,6 +211,7 @@ class SequentialNode : public PassNode { ...@@ -252,6 +211,7 @@ class SequentialNode : public PassNode {
/*! \brief A list of passes that used to compose a sequential pass. */ /*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes; tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info); v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes); v->Visit("passes", &passes);
...@@ -263,22 +223,13 @@ class SequentialNode : public PassNode { ...@@ -263,22 +223,13 @@ class SequentialNode : public PassNode {
PassInfo Info() const { return pass_info; } PassInfo Info() const { return pass_info; }
/*! /*!
* \brief Add a pass to the pass list.
*
* \param pass The candidate pass to be added.
*/
void AddPass(const Pass& pass) {
passes.push_back(pass);
}
/*!
* \brief Check if a pass is enabled. * \brief Check if a pass is enabled.
* *
* \param pass_name The name of an optimization/analysis pass. * \param info The pass information.
* *
* \return true if the pass is enabled. Otherwise, false. * \return true if the pass is enabled. Otherwise, false.
*/ */
bool PassEnabled(const std::string& pass_name) const; bool PassEnabled(const PassInfo& info) const;
/*! /*!
* \brief Resolve the pass dependency. It globs all required passes by * \brief Resolve the pass dependency. It globs all required passes by
...@@ -294,12 +245,6 @@ class SequentialNode : public PassNode { ...@@ -294,12 +245,6 @@ class SequentialNode : public PassNode {
*/ */
void ResolveDependency(const Module& mod); void ResolveDependency(const Module& mod);
std::unordered_set<std::string> DisabledPasses(
const Array<tvm::Expr>& disabled) const;
std::unordered_set<std::string> RequiredPasses(
const Array<tvm::Expr>& required) const;
/*! /*!
* \brief Perform optimizations on a series of passes. The aforementioned * \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could * typical pass manager jobs could be done by it. This function could
...@@ -317,7 +262,8 @@ class SequentialNode : public PassNode { ...@@ -317,7 +262,8 @@ class SequentialNode : public PassNode {
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode); TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
}; };
PassInfo PassInfoNode::make(int opt_level, std::string name, PassInfo PassInfoNode::make(int opt_level,
std::string name,
tvm::Array<tvm::Expr> required) { tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>(); auto pass_info = make_node<PassInfoNode>();
pass_info->opt_level = opt_level; pass_info->opt_level = opt_level;
...@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make( ...@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make(
// Module -> Module optimizations. // Module -> Module optimizations.
Module ModulePassNode::operator()(const Module& mod, Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
PassInfo pass_info = Info(); const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : " << pass_info->name DLOG(INFO) << "Executing module pass : "
<< " with opt level: " << pass_info->opt_level << "\n"; << pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined()); CHECK(mod.defined());
Module updated_mod = mod; Module updated_mod = pass_func(mod, pass_ctx);
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
CHECK(name);
auto pass = GetPass(name->value);
updated_mod = pass(updated_mod, pass_ctx);
}
updated_mod = pass_func(updated_mod, pass_ctx);
CHECK(updated_mod.defined()); CHECK(updated_mod.defined());
return updated_mod; return updated_mod;
} }
...@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make( ...@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make(
} }
// Perform Module -> Module optimizations at the Function level. // Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module FunctionPassNode::operator()(const Module& mod, Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
PassInfo pass_info = Info(); const PassInfo& pass_info = Info();
CHECK(mod.defined()); CHECK(mod.defined());
DLOG(INFO) << "Executing module pass : " << pass_info->name DLOG(INFO) << "Executing module pass : "
<< " with opt level: " << pass_info->opt_level << "\n"; << pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
Module updated_mod = mod; Module updated_mod = mod;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
CHECK(name);
auto pass = GetPass(name->value);
updated_mod = pass(updated_mod, pass_ctx);
}
Module new_mod = ModuleNode::make({}, mod->type_definitions); Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
for (const auto& it : mod->functions) { for (const auto& it : mod->functions) {
...@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod,
: pass_func(it.second, updated_mod, pass_ctx); : pass_func(it.second, updated_mod, pass_ctx);
new_mod->Add(it.first, updated_func); new_mod->Add(it.first, updated_func);
} }
return new_mod; return new_mod;
} }
...@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) { ...@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) {
<< "\n"; << "\n";
} }
std::unordered_set<std::string> SequentialNode::DisabledPasses( // linearly scan the pass array to match pass_name
const Array<tvm::Expr>& disabled) const { inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
std::unordered_set<std::string> ret; const std::string& pass_name) {
for (const auto& it : disabled) { for (auto x : pass_array) {
const auto* str = it.as<tvm::ir::StringImm>(); auto* str_name = x.as<ir::StringImm>();
CHECK(str) << "Disabled pass name must be string."; CHECK(str_name) << "pass name must be str";
ret.emplace(str->value); if (str_name->value == pass_name) return true;
}
return ret;
}
std::unordered_set<std::string> SequentialNode::RequiredPasses(
const Array<tvm::Expr>& required) const {
std::unordered_set<std::string> ret;
for (const auto& it : required) {
const auto* str = it.as<tvm::ir::StringImm>();
CHECK(str) << "Required pass name must be string.";
ret.emplace(str->value);
} }
return ret; return false;
} }
bool SequentialNode::PassEnabled(const std::string& pass_name) const { bool SequentialNode::PassEnabled(const PassInfo& info) const {
PassContext ctx = PassContext::Current(); PassContext ctx = PassContext::Current();
auto required = RequiredPasses(ctx->required_pass); if (PassArrayContains(ctx->disabled_pass, info->name)) {
auto disabled = DisabledPasses(ctx->disabled_pass);
if (disabled.count(pass_name)) {
return false; return false;
} }
if (required.count(pass_name)) { if (PassArrayContains(ctx->required_pass, info->name)) {
return true; return true;
} }
const Pass pass = GetPass(pass_name);
PassInfo info = pass->Info();
return ctx->opt_level >= info->opt_level; return ctx->opt_level >= info->opt_level;
} }
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
CHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
// TODO(zhiics): we currenlty only sequentially execute each pass in // TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase // a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future. // ordering problem needs to be handled in the future.
...@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module, ...@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module,
Module mod = module; Module mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
PassInfo info = pass->Info(); if (!PassEnabled(pass_info)) continue;
const auto& pass_name = info->name; // resolve dependencies
// Execute the pass if it is enabled. for (const auto& it : pass_info->required) {
if (PassEnabled(pass_name)) { const auto* name = it.as<tvm::ir::StringImm>();
mod = pass(mod, pass_ctx); CHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
} }
mod = pass(mod, pass_ctx);
} }
return mod; return mod;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment