Unverified Commit 4cee98ba by Tianqi Chen Committed by GitHub

[PASS][RELAY] polish pass infra (#3319)

parent ca017a38
Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f
Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661
......@@ -37,47 +37,6 @@ namespace transform {
using tvm::IRPrinter;
namespace {
// TODO(zhiics) Maybe we can use PackedFunc here so that parameters can be
// handled because we need to register the pass for Python invocation anyway.
Pass GetPass(const std::string& pass_name) {
if (pass_name == "InferType") {
return InferType();
} else if (pass_name == "AlterOpLayout") {
return AlterOpLayout();
} else if (pass_name == "CanonicalizeOps") {
return CanonicalizeOps();
} else if (pass_name == "CombineParallelConv2d") {
return CombineParallelConv2D();
} else if (pass_name == "DeadCodeElimination") {
return DeadCodeElimination();
} else if (pass_name == "EliminateCommonSubexpr") {
return DeadCodeElimination();
} else if (pass_name == "FoldConstant") {
return FoldConstant();
} else if (pass_name == "BackwardFoldScaleAxis") {
return FoldScaleAxis();
} else if (pass_name == "ForwardFoldScaleAxis") {
return FoldScaleAxis();
} else if (pass_name == "FoldScaleAxis") {
return FoldScaleAxis();
} else if (pass_name == "PartialEvaluate") {
return SimplifyInference();
} else if (pass_name == "SimplifyInference") {
return SimplifyInference();
} else if (pass_name == "ToANormalForm") {
return ToANormalForm();
} else if (pass_name == "ToGraphNormalForm") {
return ToGraphNormalForm();
} else {
LOG(FATAL) << pass_name << " has not been registered yet." << "\n";
return Pass(nullptr);
}
}
} // namespace
struct RelayPassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
......@@ -252,6 +211,7 @@ class SequentialNode : public PassNode {
/*! \brief A list of passes that used to compose a sequential pass. */
tvm::Array<Pass> passes;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("pass_info", &pass_info);
v->Visit("passes", &passes);
......@@ -263,22 +223,13 @@ class SequentialNode : public PassNode {
PassInfo Info() const { return pass_info; }
/*!
* \brief Add a pass to the pass list.
*
* \param pass The candidate pass to be added.
*/
void AddPass(const Pass& pass) {
passes.push_back(pass);
}
/*!
* \brief Check if a pass is enabled.
*
* \param pass_name The name of an optimization/analysis pass.
* \param info The pass information.
*
* \return true if the pass is enabled. Otherwise, false.
*/
bool PassEnabled(const std::string& pass_name) const;
bool PassEnabled(const PassInfo& info) const;
/*!
* \brief Resolve the pass dependency. It globs all required passes by
......@@ -294,12 +245,6 @@ class SequentialNode : public PassNode {
*/
void ResolveDependency(const Module& mod);
std::unordered_set<std::string> DisabledPasses(
const Array<tvm::Expr>& disabled) const;
std::unordered_set<std::string> RequiredPasses(
const Array<tvm::Expr>& required) const;
/*!
* \brief Perform optimizations on a series of passes. The aforementioned
* typical pass manager jobs could be done by it. This function could
......@@ -317,7 +262,8 @@ class SequentialNode : public PassNode {
TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
};
PassInfo PassInfoNode::make(int opt_level, std::string name,
PassInfo PassInfoNode::make(int opt_level,
std::string name,
tvm::Array<tvm::Expr> required) {
auto pass_info = make_node<PassInfoNode>();
pass_info->opt_level = opt_level;
......@@ -338,23 +284,13 @@ ModulePass ModulePassNode::make(
// Module -> Module optimizations.
Module ModulePassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
Module updated_mod = mod;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
CHECK(name);
auto pass = GetPass(name->value);
updated_mod = pass(updated_mod, pass_ctx);
}
updated_mod = pass_func(updated_mod, pass_ctx);
Module updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
return updated_mod;
}
......@@ -369,25 +305,15 @@ FunctionPass FunctionPassNode::make(
}
// Perform Module -> Module optimizations at the Function level.
// TODO(zhiics) Check and handle the required passes.
Module FunctionPassNode::operator()(const Module& mod,
const PassContext& pass_ctx) const {
PassInfo pass_info = Info();
const PassInfo& pass_info = Info();
CHECK(mod.defined());
DLOG(INFO) << "Executing module pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level << "\n";
DLOG(INFO) << "Executing module pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
Module updated_mod = mod;
// Execute the required passes in a DFS way.
// TODO(zhiics) We may need to pass validation to detect the cyclic
// dependency.
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
CHECK(name);
auto pass = GetPass(name->value);
updated_mod = pass(updated_mod, pass_ctx);
}
Module new_mod = ModuleNode::make({}, mod->type_definitions);
// Execute the pass function and return a new module.
for (const auto& it : mod->functions) {
......@@ -396,7 +322,6 @@ Module FunctionPassNode::operator()(const Module& mod,
: pass_func(it.second, updated_mod, pass_ctx);
new_mod->Add(it.first, updated_func);
}
return new_mod;
}
......@@ -436,47 +361,40 @@ void SequentialNode::ResolveDependency(const Module& mod) {
<< "\n";
}
std::unordered_set<std::string> SequentialNode::DisabledPasses(
const Array<tvm::Expr>& disabled) const {
std::unordered_set<std::string> ret;
for (const auto& it : disabled) {
const auto* str = it.as<tvm::ir::StringImm>();
CHECK(str) << "Disabled pass name must be string.";
ret.emplace(str->value);
}
return ret;
}
std::unordered_set<std::string> SequentialNode::RequiredPasses(
const Array<tvm::Expr>& required) const {
std::unordered_set<std::string> ret;
for (const auto& it : required) {
const auto* str = it.as<tvm::ir::StringImm>();
CHECK(str) << "Required pass name must be string.";
ret.emplace(str->value);
// linearly scan the pass array to match pass_name
inline bool PassArrayContains(const Array<tvm::Expr>& pass_array,
const std::string& pass_name) {
for (auto x : pass_array) {
auto* str_name = x.as<ir::StringImm>();
CHECK(str_name) << "pass name must be str";
if (str_name->value == pass_name) return true;
}
return ret;
return false;
}
bool SequentialNode::PassEnabled(const std::string& pass_name) const {
bool SequentialNode::PassEnabled(const PassInfo& info) const {
PassContext ctx = PassContext::Current();
auto required = RequiredPasses(ctx->required_pass);
auto disabled = DisabledPasses(ctx->disabled_pass);
if (disabled.count(pass_name)) {
if (PassArrayContains(ctx->disabled_pass, info->name)) {
return false;
}
if (required.count(pass_name)) {
if (PassArrayContains(ctx->required_pass, info->name)) {
return true;
}
const Pass pass = GetPass(pass_name);
PassInfo info = pass->Info();
return ctx->opt_level >= info->opt_level;
}
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
CHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
......@@ -485,13 +403,15 @@ Module SequentialNode::operator()(const Module& module,
Module mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
PassInfo info = pass->Info();
const auto& pass_name = info->name;
// Execute the pass if it is enabled.
if (PassEnabled(pass_name)) {
mod = pass(mod, pass_ctx);
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
CHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
return mod;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment