/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir_visitor.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <unordered_set>

namespace tvm {
namespace ir {

/*!
 * \brief Visitor that applies a user callback to every node it reaches.
 *
 * Each distinct node pointer is processed at most once. The callback is
 * invoked only after the base-class visit of that node has returned.
 */
class IRApplyVisit : public IRVisitor {
 public:
  /*! \param callback Function invoked once per distinct visited node. */
  explicit IRApplyVisit(std::function<void(const NodeRef&)> callback)
      : callback_(callback) {}

  /*!
   * \brief Visit \p node unless its pointer has already been recorded.
   * \param node The node to visit.
   */
  void Visit(const NodeRef& node) final {
    // insert() reports through `.second` whether the pointer was newly
    // added; a pointer that is already present means there is nothing to do.
    if (!seen_.insert(node.get()).second) {
      return;
    }
    IRVisitor::Visit(node);
    callback_(node);
  }

 private:
  /*! \brief Callback applied to each node after its base-class visit. */
  std::function<void(const NodeRef&)> callback_;
  /*! \brief Raw pointers of the nodes that were already entered. */
  std::unordered_set<const Node*> seen_;
};

/*!
 * \brief Run \p fvisit over \p node using an IRApplyVisit instance.
 * \param node Root node handed to the visitor.
 * \param fvisit Callback forwarded to IRApplyVisit.
 */
void PostOrderVisit(const NodeRef& node,
                    std::function<void(const NodeRef&)> fvisit) {
  IRApplyVisit applier(fvisit);
  applier.Visit(node);
}

/*! \brief Accessor for the function-local static dispatch table. */
IRVisitor::FVisit& IRVisitor::vtable() {  // NOLINT(*)
  static FVisit table;
  return table;
}

/*! \brief Visit every expression of \p exprs in index order. */
inline void VisitArray(const Array<Expr>& exprs, IRVisitor* visitor) {
  for (size_t idx = 0; idx < exprs.size(); ++idx) {
    visitor->Visit(exprs[idx]);
  }
}

/*! \brief Visit min and then extent of the domain of each IterVar in \p rdom. */
inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* visitor) {
  for (size_t idx = 0; idx < rdom.size(); ++idx) {
    const Range dom = rdom[idx]->dom;
    visitor->Visit(dom->min);
    visitor->Visit(dom->extent);
  }
}

// Variable: nothing is visited.
void IRVisitor::Visit_(const Variable* op) {}

// LetStmt: value, then body.
void IRVisitor::Visit_(const LetStmt* op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

// AttrStmt: value, then body.
void IRVisitor::Visit_(const AttrStmt* op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

// For: min, extent, then body.
void IRVisitor::Visit_(const For* op) {
  this->Visit(op->min);
  this->Visit(op->extent);
  this->Visit(op->body);
}

// Allocate: all extents, body, condition, then new_expr when it is defined.
void IRVisitor::Visit_(const Allocate* op) {
  VisitArray(op->extents, this);
  this->Visit(op->body);
  this->Visit(op->condition);
  if (op->new_expr.defined()) {
    this->Visit(op->new_expr);
  }
}

// Load: index, then predicate.
void IRVisitor::Visit_(const Load* op) {
  this->Visit(op->index);
  this->Visit(op->predicate);
}

// Store: value, index, then predicate.
void IRVisitor::Visit_(const Store* op) {
  this->Visit(op->value);
  this->Visit(op->index);
  this->Visit(op->predicate);
}

// IfThenElse: condition, then_case, then else_case when it is defined.
void IRVisitor::Visit_(const IfThenElse* op) {
  this->Visit(op->condition);
  this->Visit(op->then_case);
  if (op->else_case.defined()) {
    this->Visit(op->else_case);
  }
}

// Let: value, then body.
void IRVisitor::Visit_(const Let* op) {
  this->Visit(op->value);
  this->Visit(op->body);
}

// Free: nothing is visited.
void IRVisitor::Visit_(const Free* op) {}

// Call: only the argument list is visited.
void IRVisitor::Visit_(const Call* op) {
  VisitArray(op->args, this);
}

// Binary operators: operand a, then operand b.
#define DEFINE_BINOP_VISIT_(OP)                     \
  void IRVisitor::Visit_(const OP* op) {            \
    this->Visit(op->a);                             \
    this->Visit(op->b);                             \
  }

DEFINE_BINOP_VISIT_(Add)
DEFINE_BINOP_VISIT_(Sub)
DEFINE_BINOP_VISIT_(Mul)
DEFINE_BINOP_VISIT_(Div)
DEFINE_BINOP_VISIT_(Mod)
DEFINE_BINOP_VISIT_(Min)
DEFINE_BINOP_VISIT_(Max)
DEFINE_BINOP_VISIT_(EQ)
DEFINE_BINOP_VISIT_(NE)
DEFINE_BINOP_VISIT_(LT)
DEFINE_BINOP_VISIT_(LE)
DEFINE_BINOP_VISIT_(GT)
DEFINE_BINOP_VISIT_(GE)
DEFINE_BINOP_VISIT_(And)
DEFINE_BINOP_VISIT_(Or)

// Reduce: axis domains, source expressions, then condition.
void IRVisitor::Visit_(const Reduce* op) {
  VisitRDom(op->axis, this);
  VisitArray(op->source, this);
  this->Visit(op->condition);
}

// Cast: the wrapped value.
void IRVisitor::Visit_(const Cast* op) {
  this->Visit(op->value);
}

// Not: the single operand.
void IRVisitor::Visit_(const Not* op) {
  this->Visit(op->a);
}

// Select: condition, true_value, then false_value.
void IRVisitor::Visit_(const Select* op) {
  this->Visit(op->condition);
  this->Visit(op->true_value);
  this->Visit(op->false_value);
}

// Ramp: base, then stride.
void IRVisitor::Visit_(const Ramp* op) {
  this->Visit(op->base);
  this->Visit(op->stride);
}

// Broadcast: the wrapped value.
void IRVisitor::Visit_(const Broadcast* op) {
  this->Visit(op->value);
}

// AssertStmt: condition, message, then body.
void IRVisitor::Visit_(const AssertStmt* op) {
  this->Visit(op->condition);
  this->Visit(op->message);
  this->Visit(op->body);
}

// ProducerConsumer: only the body.
void IRVisitor::Visit_(const ProducerConsumer* op) {
  this->Visit(op->body);
}

// Provide: arguments, then value.
void IRVisitor::Visit_(const Provide* op) {
  VisitArray(op->args, this);
  this->Visit(op->value);
}

// Realize: min/extent of every bound, body, then condition.
void IRVisitor::Visit_(const Realize* op) {
  for (size_t idx = 0; idx < op->bounds.size(); ++idx) {
    const Range bound = op->bounds[idx];
    this->Visit(bound->min);
    this->Visit(bound->extent);
  }
  this->Visit(op->body);
  this->Visit(op->condition);
}

// Prefetch: min/extent of every bound.
void IRVisitor::Visit_(const Prefetch* op) {
  for (size_t idx = 0; idx < op->bounds.size(); ++idx) {
    const Range bound = op->bounds[idx];
    this->Visit(bound->min);
    this->Visit(bound->extent);
  }
}

// Block: first, then rest.
void IRVisitor::Visit_(const Block* op) {
  this->Visit(op->first);
  this->Visit(op->rest);
}

// Evaluate: the wrapped value.
void IRVisitor::Visit_(const Evaluate* op) {
  this->Visit(op->value);
}

// Node kinds for which the default visit does nothing.
#define DEFINE_OP_NO_VISIT_(OP)                     \
  void IRVisitor::Visit_(const OP* op) {}

DEFINE_OP_NO_VISIT_(IntImm)
DEFINE_OP_NO_VISIT_(UIntImm)
DEFINE_OP_NO_VISIT_(FloatImm)
DEFINE_OP_NO_VISIT_(StringImm)

// Registers a dispatch entry for OP that forwards to the visitor's
// Visit_(const OP*) overload.
#define DISPATCH_TO_VISIT(OP)                       \
  set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
      v->Visit_(op);                                \
    })

TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Free)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Add)
.DISPATCH_TO_VISIT(Sub)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(Min)
.DISPATCH_TO_VISIT(Max)
.DISPATCH_TO_VISIT(EQ)
.DISPATCH_TO_VISIT(NE)
.DISPATCH_TO_VISIT(LT)
.DISPATCH_TO_VISIT(LE)
.DISPATCH_TO_VISIT(GT)
.DISPATCH_TO_VISIT(GE)
.DISPATCH_TO_VISIT(And)
.DISPATCH_TO_VISIT(Or)
.DISPATCH_TO_VISIT(Reduce)
.DISPATCH_TO_VISIT(Cast)
.DISPATCH_TO_VISIT(Not)
.DISPATCH_TO_VISIT(Select)
.DISPATCH_TO_VISIT(Ramp)
.DISPATCH_TO_VISIT(Broadcast)
.DISPATCH_TO_VISIT(AssertStmt)
.DISPATCH_TO_VISIT(ProducerConsumer)
.DISPATCH_TO_VISIT(Provide)
.DISPATCH_TO_VISIT(Realize)
.DISPATCH_TO_VISIT(Block)
.DISPATCH_TO_VISIT(Evaluate)
.DISPATCH_TO_VISIT(IntImm)
.DISPATCH_TO_VISIT(UIntImm)
.DISPATCH_TO_VISIT(FloatImm)
.DISPATCH_TO_VISIT(StringImm)
.DISPATCH_TO_VISIT(Prefetch);

}  // namespace ir
}  // namespace tvm