/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arithmetic/ir_mutator_with_analyzer.cc
 */
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include "ir_mutator_with_analyzer.h"

namespace tvm {
namespace arith {

using namespace ir;

/*!
 * \brief Mutate a For loop after registering its loop variable with the analyzer.
 * \param op The loop node.
 * \param s The statement handle wrapping op.
 * \return The result of the base-class mutation.
 */
Stmt IRMutatorWithAnalyzer::Mutate_(const For* op, const Stmt& s) {
  // The range is built from the loop's own (un-mutated) min and extent.
  const Range loop_range = Range::make_by_min_extent(op->min, op->extent);
  analyzer_->Bind(op->loop_var, loop_range);
  return IRMutator::Mutate_(op, s);
}

/*!
 * \brief Mutate a LetStmt, binding the variable in the analyzer when the
 *        mutated value has no side effect.
 * \param op The let statement node.
 * \param s The statement handle wrapping op.
 * \return s when neither value nor body changed, otherwise a rebuilt LetStmt.
 */
Stmt IRMutatorWithAnalyzer::Mutate_(const LetStmt* op, const Stmt& s) {
  Expr new_value = Mutate(op->value);
  if (!ir::HasSideEffect(new_value)) {
    analyzer_->Bind(op->var, new_value);
  }
  // The let-binding itself is kept: a sub-class may or may not
  // choose to replace it.
  Stmt new_body = Mutate(op->body);

  const bool unchanged =
      new_value.same_as(op->value) && new_body.same_as(op->body);
  if (unchanged) {
    return s;
  }
  return LetStmt::make(op->var, new_value, new_body);
}

/*!
 * \brief Mutate an IfThenElse, visiting each branch under the matching
 *        constraint and collapsing the statement when the mutated condition
 *        is a constant.
 * \param op The if-then-else node.
 * \param s The statement handle wrapping op.
 * \return The selected branch for a constant condition, s when nothing
 *         changed, otherwise a rebuilt IfThenElse.
 */
Stmt IRMutatorWithAnalyzer::Mutate_(const IfThenElse* op, const Stmt& s) {
  Expr cond = Mutate(op->condition);
  Stmt then_branch;
  Stmt else_branch;

  {
    // The braces limit the ConstraintContext to the then-branch mutation.
    With<ConstraintContext> ctx(analyzer_, cond);
    then_branch = Mutate(op->then_case);
  }
  if (op->else_case.defined()) {
    // The else-branch is visited under the simplified negated condition.
    With<ConstraintContext> ctx(
        analyzer_, analyzer_->rewrite_simplify(Not::make(cond)));
    else_branch = Mutate(op->else_case);
  }

  if (is_one(cond)) {
    return then_branch;
  }
  if (is_zero(cond)) {
    // Without an else-branch, fall back to Evaluate(0).
    return else_branch.defined() ? else_branch : Evaluate::make(0);
  }

  const bool unchanged = cond.same_as(op->condition) &&
                         then_branch.same_as(op->then_case) &&
                         else_branch.same_as(op->else_case);
  if (unchanged) {
    return s;
  }
  return IfThenElse::make(cond, then_branch, else_branch);
}

/*!
 * \brief Mutate an AttrStmt; for thread_extent / virtual_thread attributes the
 *        thread variable is first bound to a range in the analyzer.
 * \param op The attribute statement node.
 * \param s The statement handle wrapping op.
 * \return The result of the base-class mutation.
 */
Stmt IRMutatorWithAnalyzer::Mutate_(const AttrStmt* op, const Stmt& s) {
  const bool is_thread_attr = op->attr_key == attr::thread_extent ||
                              op->attr_key == attr::virtual_thread;
  if (is_thread_attr) {
    IterVar thread_iv = Downcast<IterVar>(op->node);
    CHECK_NE(thread_iv->thread_tag.length(), 0U);
    // Range with min 0 and extent taken from the attribute value.
    analyzer_->Bind(thread_iv->var, Range::make_by_min_extent(0, op->value));
  }
  // Both kinds of attribute are mutated by the base class.
  return IRMutator::Mutate_(op, s);
}

/*!
 * \brief Mutate an AssertStmt, visiting its body under the asserted condition.
 * \param op The assert statement node.
 * \param s The statement handle wrapping op.
 * \return s when nothing changed, otherwise a rebuilt AssertStmt.
 */
Stmt IRMutatorWithAnalyzer::Mutate_(const AssertStmt* op, const Stmt& s) {
  Expr cond = Mutate(op->condition);
  Expr msg = Mutate(op->message);

  // The constraint stays in scope until the end of this function.
  With<ConstraintContext> ctx(analyzer_, cond);
  Stmt new_body = Mutate(op->body);

  const bool unchanged = cond.same_as(op->condition) &&
                         msg.same_as(op->message) &&
                         new_body.same_as(op->body);
  if (unchanged) {
    return s;
  }
  return AssertStmt::make(cond, msg, new_body);
}

/*!
 * \brief Mutate a Call; the tvm_if_then_else intrinsic gets per-branch
 *        constraint contexts, every other call goes to the base class.
 * \param op The call node.
 * \param self The expression handle wrapping op.
 * \return The selected branch for a constant condition, self when nothing
 *         changed, a rebuilt call, or the base-class result.
 */
Expr IRMutatorWithAnalyzer::Mutate_(const Call* op, const Expr& self) {
  if (!op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) {
    return IRMutator::Mutate_(op, self);
  }

  // Add condition context to if_then_else.
  Expr cond = this->Mutate(op->args[0]);
  Expr on_true;
  Expr on_false;
  {
    With<ConstraintContext> constraint(analyzer_, cond);
    on_true = this->Mutate(op->args[1]);
  }
  {
    With<ConstraintContext> constraint(
        analyzer_, analyzer_->rewrite_simplify(Not::make(cond)));
    on_false = this->Mutate(op->args[2]);
  }

  if (is_zero(cond)) {
    return on_false;
  }
  if (is_one(cond)) {
    return on_true;
  }

  const bool unchanged = cond.same_as(op->args[0]) &&
                         on_true.same_as(op->args[1]) &&
                         on_false.same_as(op->args[2]);
  if (unchanged) {
    return self;
  }
  return Call::make(op->type, op->name,
                    {cond, on_true, on_false},
                    op->call_type);
}

/*!
 * \brief Mutate a Let expression, binding the variable in the analyzer when
 *        the mutated value has no side effect.
 * \param op The let expression node.
 * \param self The expression handle wrapping op.
 * \return self when neither value nor body changed, otherwise a rebuilt Let.
 */
Expr IRMutatorWithAnalyzer::Mutate_(const Let* op, const Expr& self) {
  Expr new_value = Mutate(op->value);
  if (!ir::HasSideEffect(new_value)) {
    analyzer_->Bind(op->var, new_value);
  }
  // The let-binding itself is kept: a sub-class may or may not
  // choose to replace it.
  Expr new_body = Mutate(op->body);

  const bool unchanged =
      new_value.same_as(op->value) && new_body.same_as(op->body);
  if (unchanged) {
    return self;
  }
  return Let::make(op->var, new_value, new_body);
}

/*!
 * \brief Mutate a Select, visiting each value under the matching constraint
 *        and collapsing the expression when the mutated condition is constant.
 * \param op The select node.
 * \param self The expression handle wrapping op.
 * \return The selected value for a constant condition, self when nothing
 *         changed, otherwise a rebuilt Select.
 */
Expr IRMutatorWithAnalyzer::Mutate_(const Select* op, const Expr& self) {
  Expr cond = this->Mutate(op->condition);
  Expr on_true;
  Expr on_false;
  {
    With<ConstraintContext> constraint(analyzer_, cond);
    on_true = this->Mutate(op->true_value);
  }
  {
    With<ConstraintContext> constraint(
        analyzer_, analyzer_->rewrite_simplify(Not::make(cond)));
    on_false = this->Mutate(op->false_value);
  }

  if (is_zero(cond)) {
    return on_false;
  }
  if (is_one(cond)) {
    return on_true;
  }

  // Normal path: reuse the original node when no operand changed.
  const bool unchanged = cond.same_as(op->condition) &&
                         on_true.same_as(op->true_value) &&
                         on_false.same_as(op->false_value);
  if (unchanged) {
    return self;
  }
  return Select::make(cond, on_true, on_false);
}

/*!
 * \brief Mutate a Reduce after binding every reduction axis to its domain.
 * \param op The reduce node.
 * \param self The expression handle wrapping op.
 * \return The result of the base-class mutation.
 */
Expr IRMutatorWithAnalyzer::Mutate_(const Reduce* op, const Expr& self) {
  // Set up the domain information before simplification.
  for (const IterVar& axis_iv : op->axis) {
    analyzer_->Bind(axis_iv->var, axis_iv->dom);
  }
  // Recurse through the base class.
  return IRMutator::Mutate_(op, self);
}

}  // namespace arith
}  // namespace tvm