/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/pass/pass_manager.cc
 * \brief Relay pass manager implementation.
 */
#include <dmlc/thread_local.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/device_api.h>

#include <algorithm>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {
namespace transform {

using tvm::IRPrinter;

/*!
 * \brief Per-thread bookkeeping for pass contexts: a default context plus the
 *  stack of contexts that have been entered and not yet exited.
 */
struct RelayPassContextThreadLocalEntry {
  /*! \brief The default pass context, returned by Current() when the stack is empty. */
  PassContext default_context;

  /*! \brief The entered pass contexts; the top one is the current context. */
  std::stack<PassContext> context_stack;

  RelayPassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief Thread local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<RelayPassContextThreadLocalEntry>
    RelayPassContextThreadLocalStore;

/*! \brief Push this context onto the context stack of the thread-local entry. */
void PassContext::EnterWithScope() {
  RelayPassContextThreadLocalEntry* entry =
      RelayPassContextThreadLocalStore::Get();
  entry->context_stack.push(*this);
}

/*!
 * \brief Pop this context from the context stack of the thread-local entry.
 *
 * Checks that the stack is not empty and that this context is the one on top,
 * i.e. that scopes are exited in the reverse order they were entered.
 */
void PassContext::ExitWithScope() {
  RelayPassContextThreadLocalEntry* entry =
      RelayPassContextThreadLocalStore::Get();
  CHECK(!entry->context_stack.empty());
  CHECK(entry->context_stack.top().same_as(*this));
  entry->context_stack.pop();
}

/*!
 * \brief Get the pass context currently in effect.
 *
 * \return The top of the context stack, or the default context when no
 *  context has been entered.
 */
PassContext PassContext::Current() {
  RelayPassContextThreadLocalEntry* entry =
      RelayPassContextThreadLocalStore::Get();
  if (!entry->context_stack.empty()) {
    return entry->context_stack.top();
  } else {
    return entry->default_context;
  }
}

/*! \brief Create a pass context backed by a freshly made PassContextNode. */
PassContext PassContext::Create() {
  return PassContext(make_node<PassContextNode>());
}

class ModulePass;

/*!
 * \brief Module-level passes are designed to implement global
 * analysis/optimizations, i.e. interprocedural optimizations (IPO), etc. Passes
 * at this level have the full control of a given Relay program including
 * addition and deletion of functions.
 */
class ModulePassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*! \brief The pass function sketches the real optimization. For example,
   * we may need to perform dead code elimination on the module level. We could
   * implement the algorithm in the `pass_func` and let it run on a module. It
   * will then remove the dead code including the unused functions in the module.
   */
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;

  ModulePassNode() = default;

  /*! \brief Expose `pass_info` to the attribute visitor. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pass_info", &pass_info);
  }

  /*!
   * \brief Run a module pass on given pass context.
   *
   * \param mod The module that an optimization pass is applied on.
   * \param pass_ctx The context that an optimization pass executes on.
   *
   * \return Return the updated module.
   */
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const override { return pass_info; }

  /*!
   * \brief Create a module pass.
   *
   * \param pass_func The function that performs the optimization on a module.
   * \param pass_info The pass meta data.
   *
   * \return The created module pass.
   */
  TVM_DLL static ModulePass make(
      runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
      PassInfo pass_info);

  static constexpr const char* _type_key = "relay.ModulePass";
  TVM_DECLARE_NODE_TYPE_INFO(ModulePassNode, PassNode);
};

RELAY_DEFINE_NODE_REF(ModulePass, ModulePassNode, Pass);

class FunctionPass;

/*!
 * \brief Function-level passes are used to implement various global
 * optimizations for a given Relay module. It fetches one function at a time
 * from the function list in the module for optimization.
 *
 * Note that the scope of passes at this level is a Relay function. Therefore,
 * we cannot add or delete a function through these passes as they are not aware
 * of the global information.
 */
class FunctionPassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*! \brief The packed pass function sketches the real optimization. For
   * instance, we can implement a pass that works on a Relay function as a
   * `pass_func` and let it run on a given module. The same `pass_func` will
   * then be applied on each function in the module.
   */
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;

  FunctionPassNode() = default;

  /*! \brief Expose `pass_info` to the attribute visitor. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pass_info", &pass_info);
  }

  /*!
   * \brief Run a function pass on given pass context.
   *
   * \param mod The module that an optimization pass is applied on.
   * \param pass_ctx The context that an optimization pass executes on.
   *
   * \return Return the updated module.
   */
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const override { return pass_info; }

  /*!
   * \brief Create a function pass.
   *
   * \param pass_func The function that performs the optimization on a function.
   * \param pass_info The pass meta data.
   *
   * \return The created function pass.
   */
  TVM_DLL static FunctionPass make(
      runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
      PassInfo pass_info);

  static constexpr const char* _type_key = "relay.FunctionPass";
  TVM_DECLARE_NODE_TYPE_INFO(FunctionPassNode, PassNode);

 private:
  /*!
   * \brief Check if a function should be skipped for optimization.
   *
   * \param func The target function to be checked.
   *
   * \return Return true if the function will be skipped, otherwise false.
   */
  bool SkipFunction(const Function& func) const;
};

RELAY_DEFINE_NODE_REF(FunctionPass, FunctionPassNode, Pass);

/*!
 * \brief The SequentialNode contains a set of passes that transform Relay
 * programs from one AST to another semantically equivalent one.
 *
 * One example of this level of pass is that the pass manager needs to correctly
 * perform a host of optimizations with a given optimization level and disabled
 * passes.
 */
class SequentialNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*! \brief A list of passes that used to compose a sequential pass. */
  tvm::Array<Pass> passes;

  /*! \brief Expose `pass_info` and `passes` to the attribute visitor. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("pass_info", &pass_info);
    v->Visit("passes", &passes);
  }

  /*!
   * \brief Get the pass information/meta data.
   */
  PassInfo Info() const override { return pass_info; }

  /*!
   * \brief Check if a pass is enabled.
   *
   * \param info The pass information.
   *
   * \return true if the pass is enabled. Otherwise, false.
   */
  bool PassEnabled(const PassInfo& info) const;

  /*!
   * \brief Resolve the pass dependency. It is meant to glob all passes
   *  required by a given pass and execute them.
   *
   * \note Not implemented yet: the definition in this file only reports a
   *  fatal error.
   *
   * \param mod The module that an optimization pass runs on.
   *
   * TODO(zhiics) Build a dependency graph among the passes using provided
   * metadata, i.e. required_passes. Likely, we can have a data structure, i.e.
   * PassInfo, to store the relevant information including the parent passes.
   */
  void ResolveDependency(const Module& mod);

  /*!
   * \brief Perform optimizations on a series of passes. The aforementioned
   *        typical pass manager jobs could be done by it. This function could
   *        be overloaded to focus on different metrics, i.e. performance,
   *        memory footprint, etc.
   *
   * \param mod The module that these passes are applied on.
   * \param pass_ctx The context that these passes execute on.
   *
   * \return Return the updated module.
   */
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;

  static constexpr const char* _type_key = "relay.Sequential";
  TVM_DECLARE_NODE_TYPE_INFO(SequentialNode, PassNode);
};

/*!
 * \brief Create the meta data of a pass.
 *
 * \param opt_level The optimization level stored in the pass info.
 * \param name The name of the pass.
 * \param required The passes required by this pass.
 *
 * \return The created pass info.
 */
PassInfo PassInfoNode::make(int opt_level,
                            std::string name,
                            tvm::Array<tvm::Expr> required) {
  auto pass_info = make_node<PassInfoNode>();
  pass_info->opt_level = opt_level;
  pass_info->name = std::move(name);
  pass_info->required = std::move(required);
  return PassInfo(pass_info);
}

ModulePass ModulePassNode::make(
    runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
    PassInfo pass_info) {
  auto n = make_node<ModulePassNode>();
  n->pass_func = std::move(pass_func);
  n->pass_info = std::move(pass_info);
  return ModulePass(n);
}

// Module -> Module optimizations.
Module ModulePassNode::operator()(const Module& mod,
                                  const PassContext& pass_ctx) const {
  const PassInfo& pass_info = Info();
  DLOG(INFO) << "Executing module pass : "
             << pass_info->name
             << " with opt level: "
             << pass_info->opt_level;

  // Both the input module and the module produced by `pass_func` must be defined.
  CHECK(mod.defined());
  Module updated_mod = pass_func(mod, pass_ctx);
  CHECK(updated_mod.defined());
  return updated_mod;
}

FunctionPass FunctionPassNode::make(
    runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
    PassInfo pass_info) {
  auto n = make_node<FunctionPassNode>();
  n->pass_func = std::move(pass_func);
  n->pass_info = std::move(pass_info);
  return FunctionPass(n);
}

// Perform Module -> Module optimizations at the Function level.
Module FunctionPassNode::operator()(const Module& mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); DLOG(INFO) << "Executing function pass : " << pass_info->name << " with opt level: " << pass_info->opt_level; // Execute the pass function and return a new module. Module updated_mod = ModuleNode::make(mod->functions, mod->type_definitions, mod->Imports()); std::vector<std::pair<GlobalVar, Function> > updates; for (const auto& it : updated_mod->functions) { auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, updated_mod, pass_ctx); updates.push_back({it.first, updated_func}); } for (const auto& pair : updates) { updated_mod->Add(pair.first, pair.second, true); } return updated_mod; } // TODO(zhiics) Create an enum attribute for FunctionNode // enum Attribute {kPrimitive, kSkipOptimization} bool FunctionPassNode::SkipFunction(const Function& func) const { NodeRef res = FunctionGetAttr(func, "SkipOptimization"); const ir::IntImm* pval = res.as<ir::IntImm>(); return pval && pval->value != 0; } Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) { auto n = make_node<SequentialNode>(); n->passes = std::move(passes); n->pass_info = std::move(pass_info); data_ = std::move(n); } Sequential::Sequential(tvm::Array<Pass> passes, std::string name) { auto n = make_node<SequentialNode>(); n->passes = std::move(passes); PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); n->pass_info = std::move(pass_info); data_ = std::move(n); } const SequentialNode* Sequential::operator->() const { return static_cast<const SequentialNode*>(get()); } void SequentialNode::ResolveDependency(const Module& mod) { // TODO(zhiics) Implement it. // 1. Consider the required passes for each pass. // 2. Only resolve the enabled passes. // 3. Build a dependency graph. Probably we need to update the pass list. LOG(FATAL) << "Pass dependency has not been resolved yet." 
<< "\n"; } // linearly scan the pass array to match pass_name inline bool PassArrayContains(const Array<tvm::Expr>& pass_array, const std::string& pass_name) { for (auto x : pass_array) { auto* str_name = x.as<ir::StringImm>(); CHECK(str_name) << "pass name must be str"; if (str_name->value == pass_name) return true; } return false; } bool SequentialNode::PassEnabled(const PassInfo& info) const { PassContext ctx = PassContext::Current(); if (PassArrayContains(ctx->disabled_pass, info->name)) { return false; } if (PassArrayContains(ctx->required_pass, info->name)) { return true; } return ctx->opt_level >= info->opt_level; } Pass GetPass(const std::string& pass_name) { using tvm::runtime::Registry; std::string fpass_name = "relay._transform." + pass_name; const auto* f = Registry::Get(fpass_name); CHECK(f != nullptr) << "Cannot find " << fpass_name << "to create the pass " << pass_name; return (*f)(); } // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. 
Module SequentialNode::operator()(const Module& module, const PassContext& pass_ctx) const { Module mod = module; for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { const auto* name = it.as<tvm::ir::StringImm>(); CHECK(name); mod = GetPass(name->value)(mod, pass_ctx); } mod = pass(mod, pass_ctx); } return mod; } Pass CreateModulePass( const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array<tvm::Expr>& required) { PassInfo pass_info = PassInfoNode::make(opt_level, name, required); return ModulePassNode::make(pass_func, pass_info); } Pass CreateFunctionPass( const runtime::TypedPackedFunc<Function(Function, Module, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array<tvm::Expr>& required) { PassInfo pass_info = PassInfoNode::make(opt_level, name, required); return FunctionPassNode::make(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_API("relay._transform.PassInfo") .set_body_typed(PassInfoNode::make); TVM_REGISTER_API("relay._transform.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<PassInfoNode>([](const ObjectRef& ref, tvm::IRPrinter* p) { auto* node = static_cast<const PassInfoNode*>(ref.get()); p->stream << "The meta data of the pass: "; p->stream << "pass name: " << node->name; p->stream << "opt_level: " << node->opt_level; p->stream << "required passes: [" << "\n"; for (const auto& it : node->required) { const auto* str = it.as<tvm::ir::StringImm>(); p->stream << str->value << ", "; } p->stream << "]\n"; }); TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_API("relay._transform.MakeModulePass") 
.set_body_typed(ModulePassNode::make); TVM_REGISTER_API("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; Module mod = args[1]; *ret = pass(mod); }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<ModulePassNode>([](const ObjectRef& ref, IRPrinter* p) { auto* node = static_cast<const ModulePassNode*>(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Module pass: " << info->name << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_API("relay._transform.MakeFunctionPass") .set_body_typed(FunctionPassNode::make); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, IRPrinter* p) { auto* node = static_cast<const FunctionPassNode*>(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Function pass: " << info->name << " at the optimization level " << info->opt_level; }); TVM_REGISTER_NODE_TYPE(SequentialNode); TVM_REGISTER_API("relay._transform.Sequential") .set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array<Pass> passes = args[0]; int opt_level = args[1]; std::string name = args[2]; tvm::Array<tvm::Expr> required = args[3]; PassInfo pass_info = PassInfoNode::make(opt_level, name, required); *ret = Sequential(passes, pass_info); }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<SequentialNode>([](const ObjectRef& ref, IRPrinter* p) { auto* node = static_cast<const SequentialNode*>(ref.get()); const PassInfo info = node->Info(); p->stream << "Run Sequential pass: " << info->name << " at the optimization level " << info->opt_level << ". 
"; p->stream << "The passes will be executed are: ["; for (const auto& it : node->passes) { const PassInfo pass_info = it->Info(); p->stream << pass_info->name << " "; } p->stream << "]"; }); TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_API("relay._transform.PassContext") .set_body([](TVMArgs args, TVMRetValue* ret) { auto pctx = PassContext::Create(); int opt_level = args[0]; int fallback_device = args[1]; tvm::Array<tvm::Expr> required = args[2]; tvm::Array<tvm::Expr> disabled = args[3]; pctx->opt_level = opt_level; pctx->fallback_device = fallback_device; pctx->required_pass = std::move(required); pctx->disabled_pass = std::move(disabled); *ret = pctx; }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<PassContextNode>([](const ObjectRef& ref, IRPrinter* p) { auto* node = static_cast<const PassContextNode*>(ref.get()); p->stream << "Pass context information: " << "\n"; p->stream << "\topt_level: " << node->opt_level << "\n"; p->stream << "\tfallback device: " << runtime::DeviceName(node->fallback_device) << "\n"; p->stream << "\trequired passes: [" << node->opt_level; for (const auto& it : node->required_pass) { p->stream << it << " "; } p->stream << "]\n"; p->stream << "\tdisabled passes: [" << node->opt_level; for (const auto& it : node->disabled_pass) { p->stream << it << " "; } p->stream << "]"; }); class PassContext::Internal { public: static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); } static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; TVM_REGISTER_API("relay._transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); TVM_REGISTER_API("relay._transform.EnterPassContext") .set_body_typed(PassContext::Internal::EnterScope); TVM_REGISTER_API("relay._transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); } // namespace transform } // namespace relay } // namespace tvm