/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/tvm/relay/expr_functor.cc
 * \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator uses memoization and self return in order to amortize
 * the cost of using functional updates.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "type_functor.h"

namespace tvm {
namespace relay {

/*!
 * \brief Memoized dispatch: the first time an expression is seen it is
 *        mutated through ExprFunctor and the result cached in memo_;
 *        later visits of the same expression return the cached result.
 */
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto it = this->memo_.find(expr);
  if (it != this->memo_.end()) {
    return it->second;
  } else {
    Expr new_expr = ExprFunctor::VisitExpr(expr);
    memo_[expr] = new_expr;
    return new_expr;
  }
}

Expr ExprMutator::VisitExpr_(const VarNode* op) {
  if (op->type_annotation.defined()) {
    auto type = this->VisitType(op->type_annotation);
    if (!op->type_annotation.same_as(type)) {
      // Only the annotation changed: rebuild reusing the original vid.
      return VarNode::make(op->vid, type);
    }
  }
  // default case return self.
  return GetRef<Expr>(op);
}

Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
  return GetRef<Expr>(op);
}

Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
  return GetRef<Expr>(op);
}

Expr ExprMutator::VisitExpr_(const OpNode* op) {
  return GetRef<Expr>(op);
}

Expr ExprMutator::VisitExpr_(const TupleNode* op) {
  tvm::Array<Expr> fields;
  bool all_fields_unchanged = true;
  for (auto field : op->fields) {
    auto new_field = this->Mutate(field);
    fields.push_back(new_field);
    all_fields_unchanged &= new_field.same_as(field);
  }

  if (all_fields_unchanged) {
    return GetRef<Expr>(op);
  } else {
    return TupleNode::make(fields);
  }
}

Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
  tvm::Array<TypeVar> ty_params;
  bool all_ty_params_unchanged = true;

  for (auto ty_param : op->type_params) {
    TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
    ty_params.push_back(new_ty_param);
    all_ty_params_unchanged &= new_ty_param.same_as(ty_param);
  }

  tvm::Array<Var> params;
  bool all_params_unchanged = true;
  for (auto param : op->params) {
    Var new_param = Downcast<Var>(this->Mutate(param));
    params.push_back(new_param);
    all_params_unchanged &= param.same_as(new_param);
  }

  auto ret_type = this->VisitType(op->ret_type);
  auto body = this->Mutate(op->body);

  if (all_ty_params_unchanged &&
      all_params_unchanged &&
      ret_type.same_as(op->ret_type) &&
      body.same_as(op->body)) {
    return GetRef<Expr>(op);
  } else {
    return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
  }
}

Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
  auto new_op = this->Mutate(call_node->op);
  bool unchanged = call_node->op.same_as(new_op);

  tvm::Array<Type> ty_args;
  for (auto ty_arg : call_node->type_args) {
    auto new_ty_arg = this->VisitType(ty_arg);
    ty_args.push_back(new_ty_arg);
    unchanged &= new_ty_arg.same_as(ty_arg);
  }

  tvm::Array<Expr> call_args;
  for (auto arg : call_node->args) {
    auto new_arg = this->Mutate(arg);
    call_args.push_back(new_arg);
    unchanged &= new_arg.same_as(arg);
  }

  if (unchanged) {
    return GetRef<Expr>(call_node);
  } else {
    return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
  }
}

Expr ExprMutator::VisitExpr_(const LetNode* op) {
  Var var = Downcast<Var>(this->Mutate(op->var));
  auto value = this->Mutate(op->value);
  auto body = this->Mutate(op->body);

  if (var.same_as(op->var) &&
      value.same_as(op->value) &&
      body.same_as(op->body)) {
    return GetRef<Expr>(op);
  } else {
    return LetNode::make(var, value, body);
  }
}

Expr ExprMutator::VisitExpr_(const IfNode* op) {
  auto guard = this->Mutate(op->cond);
  auto true_b = this->Mutate(op->true_branch);
  auto false_b = this->Mutate(op->false_branch);
  if (op->cond.same_as(guard) &&
      op->true_branch.same_as(true_b) &&
      op->false_branch.same_as(false_b)) {
    return GetRef<Expr>(op);
  } else {
    return IfNode::make(guard, true_b, false_b);
  }
}

Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
  auto t = this->Mutate(g->tuple);
  // same_as (reference identity) for consistency with the other visitors.
  if (g->tuple.same_as(t)) {
    return GetRef<Expr>(g);
  } else {
    return TupleGetItemNode::make(t, g->index);
  }
}

Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
  Expr value = this->Mutate(op->value);
  if (value.same_as(op->value)) {
    return GetRef<Expr>(op);
  } else {
    return RefCreateNode::make(value);
  }
}

Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
  Expr ref = this->Mutate(op->ref);
  if (ref.same_as(op->ref)) {
    return GetRef<Expr>(op);
  } else {
    return RefReadNode::make(ref);
  }
}

Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
  Expr ref = this->Mutate(op->ref);
  Expr value = this->Mutate(op->value);
  if (ref.same_as(op->ref) && value.same_as(op->value)) {
    return GetRef<Expr>(op);
  } else {
    return RefWriteNode::make(ref, value);
  }
}

Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
  return GetRef<Expr>(c);
}

/*!
 * \brief Mutate a match expression.
 *
 * Clauses are visited before the scrutinee (the same order as before).
 * When neither any clause nor the scrutinee changed, the original node is
 * returned, matching the "self return" behaviour of the other visitors.
 */
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
  tvm::Array<Clause> clauses;
  bool unchanged = true;
  for (const Clause& p : m->clauses) {
    Clause new_clause = VisitClause(p);
    unchanged &= new_clause.same_as(p);
    clauses.push_back(new_clause);
  }
  Expr data = VisitExpr(m->data);
  unchanged &= data.same_as(m->data);
  if (unchanged) {
    return GetRef<Expr>(m);
  }
  return MatchNode::make(data, clauses, m->complete);
}

/*!
 * \brief Mutate one match clause: the pattern first, then the right-hand side.
 * \return The original clause if both parts are unchanged, otherwise a new one.
 */
Clause ExprMutator::VisitClause(const Clause& c) {
  Pattern p = VisitPattern(c->lhs);
  Expr rhs = VisitExpr(c->rhs);
  if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
    return c;
  }
  return ClauseNode::make(p, rhs);
}

Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }

Type ExprMutator::VisitType(const Type& t) { return t; }

/*!
 * \brief Dispatch with a per-node visit count.
 *
 * The children of a node are traversed only on its first visit; later
 * visits just bump visit_counter_. The counter entry is inserted after the
 * first traversal finishes.
 */
void ExprVisitor::VisitExpr(const Expr& expr) {
  auto it = visit_counter_.find(expr.get());
  if (it != visit_counter_.end()) {
    ++it->second;
  } else {
    using TParent = ExprFunctor<void(const Expr&)>;
    TParent::VisitExpr(expr);
    visit_counter_.insert({expr.get(), 1});
  }
}

void ExprVisitor::VisitExpr_(const VarNode* op) {
  if (op->type_annotation.defined()) {
    this->VisitType(op->type_annotation);
  }
}

void ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
}

void ExprVisitor::VisitExpr_(const ConstantNode* op) {
}

void ExprVisitor::VisitExpr_(const TupleNode* op) {
  for (auto field : op->fields) {
    this->VisitExpr(field);
  }
}

void ExprVisitor::VisitExpr_(const FunctionNode* op) {
  for (auto param : op->params) {
    this->VisitExpr(param);
  }

  this->VisitExpr(op->body);
}

void ExprVisitor::VisitExpr_(const CallNode* op) {
  this->VisitExpr(op->op);

  for (auto ty_arg : op->type_args) {
    this->VisitType(ty_arg);
  }

  for (auto arg : op->args) {
    this->VisitExpr(arg);
  }
}

void ExprVisitor::VisitExpr_(const LetNode* op) {
  this->VisitExpr(op->value);
  this->VisitExpr(op->var);
  this->VisitExpr(op->body);
}

void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

void ExprVisitor::VisitExpr_(const OpNode* op) { return; }

void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
  this->VisitExpr(op->tuple);
}

void ExprVisitor::VisitExpr_(const RefCreateNode* op) {
  this->VisitExpr(op->value);
}

void ExprVisitor::VisitExpr_(const RefReadNode* op) {
  this->VisitExpr(op->ref);
}

void ExprVisitor::VisitExpr_(const RefWriteNode* op) {
  this->VisitExpr(op->ref);
  this->VisitExpr(op->value);
}

void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
  for (const Type& t : op->inputs) {
    this->VisitType(t);
  }
  this->VisitType(op->belong_to);
}

void ExprVisitor::VisitExpr_(const MatchNode* op) {
  this->VisitExpr(op->data);
  for (const Clause& c : op->clauses) {
    this->VisitClause(c);
  }
}

void ExprVisitor::VisitClause(const Clause& op) {
  this->VisitPattern(op->lhs);
  this->VisitExpr(op->rhs);
}

void ExprVisitor::VisitPattern(const Pattern& p) { return; }

void ExprVisitor::VisitType(const Type& t) { return; }

// visitor to implement apply
// Calls f_ on every distinct node exactly once, after its children have been
// visited (post-order); visited_ deduplicates shared sub-expressions.
class ExprApplyVisit : public ExprVisitor {
 public:
  explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}

  void VisitExpr(const Expr& e) final {
    if (visited_.count(e.get()) != 0) return;
    visited_.insert(e.get());
    ExprVisitor::VisitExpr(e);
    f_(e);
  }

 private:
  std::function<void(const Expr&)> f_;
  std::unordered_set<const Node*> visited_;
};

void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
  ExprApplyVisit(fvisit).VisitExpr(e);
}

TVM_REGISTER_API("relay._analysis.post_order_visit")
.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
    PostOrderVisit(expr, [f](const Expr& n) {
        f(n);
      });
  });

// Implement bind.
class ExprBinder : public ExprMutator, PatternMutator { public: explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) { } Expr VisitExpr_(const LetNode* op) final { CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let"; return ExprMutator::VisitExpr_(op); } Expr VisitExpr_(const FunctionNode* op) final { for (Var param : op->params) { CHECK(!args_map_.count(param)) << "Cannnot bind an internal function parameter"; } return ExprMutator::VisitExpr_(op); } Expr VisitExpr_(const VarNode* op) final { auto id = GetRef<Var>(op); auto it = args_map_.find(id); if (it != args_map_.end()) { return (*it).second; } else { return std::move(id); } } Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); return ClauseNode::make(pat, VisitExpr(c->rhs)); } Var VisitVar(const Var& v) final { CHECK(!args_map_.count(v)) << "Cannnot bind an internal pattern variable"; return v; } private: const tvm::Map<Var, Expr>& args_map_; }; Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) { if (const FunctionNode* func = expr.as<FunctionNode>()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array<Var> new_params; for (Var param : func->params) { if (!args_map.count(param)) { new_params.push_back(param); } } if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { return expr; } auto ret = FunctionNode::make(new_params, new_body, func->ret_type, func->type_params, func->attrs); std::unordered_set<Var, NodeHash, NodeEqual> set; for (const auto& v : FreeVars(expr)) { set.insert(v); } for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); } } ret = FunctionNode::make(new_params, new_body, func->ret_type, func->type_params, func->attrs); CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { return 
ExprBinder(args_map).VisitExpr(expr); } } TVM_REGISTER_API("relay._expr.Bind") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef input = args[0]; if (input->IsInstance<ExprNode>()) { *ret = Bind(Downcast<Expr>(input), args[1]); } else { CHECK(input->IsInstance<TypeNode>()); *ret = Bind(Downcast<Type>(input), args[1]); } }); } // namespace relay } // namespace tvm