/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * * \file forward_rewrite.cc * \brief Apply rewriting rules in a forward fashion. */ #include <tvm/relay/analysis.h> #include <tvm/relay/expr_functor.h> #include <tvm/relay/op_attr_types.h> #include <tvm/relay/transform.h> #include "pass_util.h" namespace tvm { namespace relay { // Realizer class that realizes the expression // Note that we can take benefit of its internal memo // so that calling realize repeatively won't hurt perf. class TempRealizer : private MixedModeMutator { public: Expr Realize(Expr expr) { return Mutate(expr); } private: Expr DispatchVisitExpr(const Expr& expr) final { Expr res; if (const auto* temp = expr.as<TempExprNode>()) { res = temp->Realize(); } else { res = MixedModeMutator::DispatchVisitExpr(expr); } return res; } }; class ForwardRewriter : private MixedModeMutator { public: ForwardRewriter(const OpMap<FForwardRewrite>* rewrite_map, std::function<ObjectRef(const Call&)> fcontext, std::function<Expr(const Expr&)> fmulti_ref_trigger) : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} ForwardRewriter(const FForwardRewrite* rewrite_func, std::function<ObjectRef(const Call&)> fcontext, std::function<Expr(const Expr&)> fmulti_ref_trigger) : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} // Transform expression. Expr Rewrite(const Expr& expr) { if (fmulti_ref_trigger_ != nullptr) { ref_counter_ = GetExprRefCount(expr); } return realizer_.Realize(this->VisitExpr(expr)); } private: // The rewrite rule. const OpMap<FForwardRewrite>* rewrite_map_{nullptr}; const FForwardRewrite* rewrite_func_{nullptr}; // The context.const std::function<ObjectRef(const Call&)> fcontext_{nullptr}; // The multiple reference trigger std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr}; // Internal ref counter std::unordered_map<const Object*, size_t> ref_counter_; // internal realizer TempRealizer realizer_; // Visit and allow non-realized version. Expr GetTempExpr(const Expr& expr, const Expr& post) { if (fmulti_ref_trigger_ != nullptr) { Expr ret = post; auto it = ref_counter_.find(expr.get()); CHECK(it != ref_counter_.end()); if (it->second > 1) { ret = fmulti_ref_trigger_(ret); } return ret; } else { return post; } } // Automatic fold TupleGetItem. Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple); if (const auto* ptuple = tuple.as<TupleNode>()) { return ptuple->fields[op->index]; } else { if (tuple.same_as(op->tuple)) { return GetRef<Expr>(op); } else { return TupleGetItem(tuple, op->index); } } } Expr Rewrite_(const TupleNode* op, const Expr& post) final { tvm::Array<Expr> fields; bool all_fields_unchanged = true; const auto* post_node = post.as<TupleNode>(); for (size_t i = 0; i < op->fields.size(); ++i) { auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]); fields.push_back(new_field); all_fields_unchanged &= new_field.same_as(op->fields[i]); } if (all_fields_unchanged) { return GetRef<Expr>(op); } else { return Tuple(fields); } } Expr Rewrite_(const CallNode* call_node, const Expr& post) final { const Call& ref_call = GetRef<Call>(call_node); PackedFunc frewrite; if (rewrite_func_) { frewrite = *rewrite_func_; } else { CHECK(rewrite_map_); frewrite = rewrite_map_->get(call_node->op, nullptr); } const auto* post_node = post.as<CallNode>(); auto new_op = post_node->op; bool unchanged = call_node->op.same_as(new_op); Array<Expr> call_args; for (size_t i = 0; i < call_node->args.size(); ++i) { Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]); if (frewrite == nullptr) { new_arg = realizer_.Realize(new_arg); } unchanged &= new_arg.same_as(call_node->args[i]); call_args.push_back(new_arg); } // try to rewrite. if (frewrite != nullptr) { Expr res = frewrite( ref_call, call_args, fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); if (res.defined()) return res; // abort, use old rule for (size_t i = 0; i < call_args.size(); ++i) { Expr arg = call_args[i]; Expr new_arg = realizer_.Realize(arg); if (!arg.same_as(new_arg)) { call_args.Set(i, new_arg); unchanged = false; } } } if (unchanged) return ref_call; return Call( new_op, call_args, call_node->attrs, call_node->type_args); } }; Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_name, std::function<ObjectRef(const Call&)> fcontext, std::function<Expr(const Expr&)> fmulti_ref_trigger) { auto rewrite_map = Op::GetAttr<FForwardRewrite>(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); } Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function<ObjectRef(const Call&)> fcontext, std::function<Expr(const Expr&)> fmulti_ref_trigger) { return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); } } // namespace relay } // namespace tvm