/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/merge_composite.cc
 * \brief Merges expressions matching patterns into functions marked
 * as 'composite'. This is primarily intended to be used alongside the
 * external codegen infrastructure to support the case where multiple
 * Relay operators map to a single external operator.
 */

#include <tvm/te/operation.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {
namespace merge_composite {

/*!
 * \brief Mutator that rewrites every operator call matching a single pattern
 * into a call to a function carrying the attr::kComposite attribute.
 */
class MergeCompositeWrapper : public ExprMutator {
 public:
  /*!
   * \brief Construct a wrapper for one named pattern.
   * \param pattern_name The name stored in the composite attribute of merged functions.
   * \param pattern The pattern expression; VisitExpr_ downcasts it to a Call.
   */
  explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
      : pattern_name_(pattern_name), pattern_(pattern) {}

  /*!
   * \brief Match a variable of the pattern against a node of the graph.
   * \param pattern The pattern variable.
   * \param root The graph expression found at the variable's position.
   * \param var_map Map from variable name to {free var, graph expression}; updated
   *  the first time a name is seen.
   * \return The free var standing for this pattern variable, or an undefined Expr
   *  if the same name was previously bound to a different graph node.
   */
  Expr ExtractPattern(const Var& pattern, const Expr& root,
                      Map<std::string, Array<Expr>>* var_map) {
    const std::string& name = pattern->name_hint();
    if (var_map->find(name) == var_map->end()) {
      // if we haven't encountered this var yet, make a new free var and associate
      // it with the value at 'root'
      auto free_var = VarNode::make(name, Type());
      var_map->Set(name, Array<Expr>({free_var, root}));
      return std::move(free_var);
    }
    // if we have encountered this var already, reuse the free var that was created,
    // but only if both occurrences map to the same node in the graph
    Array<Expr> vars = (*var_map)[name];
    Expr free_var = vars[0];
    Expr graph_expr = vars[1];
    if (graph_expr != root) {
      return Expr();
    }
    return free_var;
  }

  /*!
   * \brief Match a constant of the pattern against a node of the graph.
   * \param pattern The pattern constant (unused).
   * \param root The graph expression found at the constant's position.
   * \param var_map Unused.
   * \return The graph expression, unchanged.
   *
   * \note NOTE(review): this does not verify that 'root' is itself a ConstantNode.
   */
  Expr ExtractPattern(const Constant& pattern, const Expr& root,
                      Map<std::string, Array<Expr>>* var_map) {
    return root;
  }

  /*!
   * \brief Try and extract a given pattern from a graph as a subgraph.
   * \param pattern The pattern to extract.
   * \param root The graph to extract from.
   * \param var_map A map between free vars in the subgraph and nodes in the graph.
   * \param call_map A cache from already-processed pattern call nodes to their
   *  extracted counterparts.
   * \return The extracted subgraph, or an undefined Expr if the match failed.
   *
   * \note How does this work?
   *
   * A pattern consists of a Relay expression containing only operator call nodes, constants
   * and free variables. The free variables indicate where the pattern can 'attach' in your
   * graph. This function takes the final call node of the pattern and the call node currently
   * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
   * from the graph (referred to as the 'root' node here) to check they're identical. If at any point
   * they differ, an empty expression is returned to signify the extract failed. If a free var is
   * reached in the pattern, the corresponding value in the root is associated with the name of the
   * free var (via the var_map) so that when we construct the composite function, the inputs match
   * up correctly with the rest of the graph. The return value of this function when successful is
   * a new Relay expression ready to be wrapped into a composite function.
   */
  Expr ExtractPattern(const Call& pattern, const Call& root,
                      Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
    // check to make sure both calls are to operators (not functions)
    if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>()) return Expr();
    if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name) return Expr();
    // the arguments are walked in lockstep, so a differing argument count can never
    // match; rejecting it here also keeps root->args[i] in range below
    if (pattern->args.size() != root->args.size()) return Expr();

    Array<Expr> new_args;
    for (size_t i = 0; i < pattern->args.size(); ++i) {
      Expr arg = pattern->args[i];
      Expr root_arg = root->args[i];
      Expr new_arg;
      if (arg->IsInstance<CallNode>()) {
        if (call_map->find(arg) != call_map->end()) {
          // if we've already processed this call node, return the previous result
          // NOTE(review): the cached result is reused without checking that 'root_arg'
          // is the graph node the cache entry was originally extracted from.
          new_arg = (*call_map)[arg];
        } else {
          // fail if the root argument is not also a call node
          if (!root_arg->IsInstance<CallNode>()) {
            return Expr();
          }
          // if it's a call node, recursively call this function
          new_arg = ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root_arg),
                                   var_map, call_map);
          // only remember successful extractions
          if (new_arg.defined()) {
            call_map->Set(arg, new_arg);
          }
        }
      } else if (arg->IsInstance<VarNode>()) {
        // if there's a var in the pattern, it must be a free var
        // so call the function to update the var_map
        new_arg = ExtractPattern(Downcast<Var>(arg), root_arg, var_map);
      } else if (arg->IsInstance<ConstantNode>()) {
        // if there's a constant, simply get the corresponding
        // value of the constant from the root
        new_arg = ExtractPattern(Downcast<Constant>(arg), root_arg, var_map);
      }
      // any other kind of pattern argument leaves new_arg undefined and fails the match
      if (!new_arg.defined()) {
        return Expr();
      }
      new_args.push_back(new_arg);
    }
    return CallNode::make(root->op, new_args, root->attrs);
  }

  /*!
   * \brief Visit a call node and, if it matches the pattern, replace it with a
   * call to a newly created composite function.
   * \param cn The call node being visited.
   * \return The rewritten call, or the (argument-mutated) call if nothing matched.
   */
  Expr VisitExpr_(const CallNode* cn) override {
    Call call = GetRef<Call>(cn);
    if (call->op->IsInstance<FunctionNode>()) {
      Function func = Downcast<Function>(call->op);
      CHECK(func.defined());
      const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
      // don't step into existing composite functions
      if (name_node.defined() && name_node->value != "") {
        // the function body is left alone; only the call's arguments are mutated
        tvm::Array<tvm::relay::Expr> new_args;
        for (const auto& arg : call->args) {
          auto new_e = this->Mutate(arg);
          new_args.push_back(new_e);
        }
        return CallNode::make(call->op, new_args, call->attrs);
      }
    }

    // mutate the children first, then try to match on the resulting call
    Expr expr = ExprMutator::VisitExpr_(cn);
    call = Downcast<Call>(expr);
    if (!call->op->IsInstance<OpNode>()) return std::move(call);

    // only call patterns are supported
    Call pattern = Downcast<Call>(pattern_);
    CHECK(pattern.defined());
    Map<std::string, Array<Expr>> args_map;
    Map<Expr, Expr> call_map;
    auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
    if (extract.defined()) {
      auto free_vars = FreeVars(extract);
      // make the composite function
      auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
      f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
      // find the expressions associated with the free vars using the args_map
      // this tells us which expressions should be given as inputs to the composite function
      Array<Expr> args;
      for (const auto& free_var : free_vars) {
        args.push_back(args_map[free_var->name_hint()][1]);
      }
      auto new_call = CallNode::make(f, args);
      return std::move(new_call);
    }
    return std::move(call);
  }

 private:
  /*! \brief The name of the pattern to match */
  std::string pattern_name_;
  /*! \brief The pattern to match */
  Expr pattern_;
};

/*!
 * \brief Merge every pattern into the expression, one pattern at a time.
 * \param expr The expression to rewrite.
 * \param pattern_names The composite names, parallel to 'patterns'.
 * \param patterns The patterns to merge; they are applied in array order.
 * \return The expression after all patterns have been applied.
 */
Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
                    const Array<Expr>& patterns) {
  CHECK_EQ(pattern_names.size(), patterns.size());
  Expr merged_expr = expr;
  // merge the patterns one-by-one in order
  for (size_t i = 0; i < patterns.size(); i++) {
    std::string pattern_name = pattern_names[i]->value;
    Expr pattern = patterns[i];
    merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
  }
  return merged_expr;
}

}  // namespace merge_composite

namespace transform {

/*!
 * \brief Create the function pass that applies MergeComposite to each function.
 * \param pattern_names The composite names, parallel to 'patterns'.
 * \param patterns The patterns to merge.
 * \return The created pass, registered under the name "MergeComposite".
 */
Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
                    const tvm::Array<Expr>& patterns) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(
            relay::merge_composite::MergeComposite(f, pattern_names, patterns));
      };
  auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
  return func_pass;
}

TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
.set_body_typed(MergeComposite);

}  // namespace transform

}  // namespace relay
}  // namespace tvm