/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

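/*!
 * \file tvm/arith/ir_mutator_with_analyzer.cc
 * \brief Mutator that feeds loop ranges, let-bindings, thread extents,
 *        reduction axes and branch conditions into an arith::Analyzer
 *        while it rewrites the IR.
 */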
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include "ir_mutator_with_analyzer.h"

namespace tvm {
namespace arith {
using namespace tir;
// For loop: make the loop variable's range known to the analyzer before
// the loop (min, extent, body) is mutated by the base class.
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const ForNode* op) {
  // Range is described by (min, extent) taken straight from the loop node.
  // NOTE: no scope guard is used, so the binding is not undone here once
  // the loop has been visited.
  analyzer_->Bind(op->loop_var,
                  Range::make_by_min_extent(op->min, op->extent));
  return StmtExprMutator::VisitStmt_(op);
}

// Let statement: visit the bound value, record var == value in the
// analyzer when the value has no side effects, then visit the body.
// Returns the original node when neither value nor body changed.
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const LetStmtNode* op) {
  PrimExpr value = this->VisitExpr(op->value);
  // Only side-effect-free values are recorded; presumably an impure value
  // cannot safely be treated as equal to the variable.
  if (!tir::HasSideEffect(value)) {
    analyzer_->Bind(op->var, value);
  }
  // We keep the let-binding here, as a sub-class may or may not choose to
  // replace it.
  Stmt body = this->VisitStmt(op->body);
  if (value.same_as(op->value) && body.same_as(op->body)) {
    // Nothing changed: reuse the existing node.
    return GetRef<Stmt>(op);
  }
  auto n = this->CopyOnWrite(op);
  n->value = std::move(value);
  n->body = std::move(body);
  return Stmt(n);
}

// If/then/else: each branch is visited under the constraint that selects
// it. The whole statement is folded away when the (mutated) condition is
// the constant 1 or 0.
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const IfThenElseNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition);
  Stmt then_case, else_case;
  {
    // The then-branch is only taken when `condition` holds.
    With<ConstraintContext> ctx(analyzer_, condition);
    then_case = this->VisitStmt(op->then_case);
  }
  if (op->else_case.defined()) {
    // The else-branch is only taken when `condition` is false; the negation
    // is passed through rewrite_simplify before being used as a constraint.
    With<ConstraintContext> ctx(analyzer_,
                                analyzer_->rewrite_simplify(NotNode::make(condition)));
    else_case = this->VisitStmt(op->else_case);
  }
  if (is_one(condition)) return then_case;
  if (is_zero(condition)) {
    // Always false: keep the else-branch, or emit a no-op Evaluate(0) when
    // there is none.
    return else_case.defined() ? else_case : EvaluateNode::make(0);
  }
  if (condition.same_as(op->condition) &&
      then_case.same_as(op->then_case) &&
      else_case.same_as(op->else_case)) {
    return GetRef<Stmt>(op);
  }
  auto n = this->CopyOnWrite(op);
  n->condition = std::move(condition);
  n->then_case = std::move(then_case);
  n->else_case = std::move(else_case);
  return Stmt(n);
}

// Attribute statement: for thread_extent / virtual_thread annotations,
// bind the thread variable to [0, value) before delegating. All other
// attributes are passed straight to the base mutator.
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::thread_extent ||
      op->attr_key == tir::attr::virtual_thread) {
    // For these keys the annotated node is expected to be an IterVar
    // carrying a non-empty thread tag.
    IterVar iv = Downcast<IterVar>(op->node);
    CHECK_NE(iv->thread_tag.length(), 0U);
    analyzer_->Bind(iv->var,
                    Range::make_by_min_extent(0, op->value));
  }
  return StmtExprMutator::VisitStmt_(op);
}

// Assert statement: the body is visited assuming the asserted condition
// holds, because the body is only reached when the assertion passes.
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AssertStmtNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition);
  PrimExpr message = this->VisitExpr(op->message);
  // The constraint lives until the end of this function, so it covers the
  // visit of the body.
  With<ConstraintContext> ctx(analyzer_, condition);
  Stmt body = this->VisitStmt(op->body);
  if (condition.same_as(op->condition) &&
      message.same_as(op->message) &&
      body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  }
  auto n = this->CopyOnWrite(op);
  n->condition = std::move(condition);
  n->message = std::move(message);
  n->body = std::move(body);
  return Stmt(n);
}

128
PrimExpr IRMutatorWithAnalyzer::
129
VisitExpr_(const CallNode* op) {
130
  // add condition context to if_then_else
131
  if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) {
132 133
    PrimExpr cond = this->VisitExpr(op->args[0]);
    PrimExpr true_value, false_value;
134 135
    {
      With<ConstraintContext> constraint(analyzer_, cond);
136
      true_value = this->VisitExpr(op->args[1]);
137 138 139
    }
    {
      With<ConstraintContext> constraint(analyzer_,
140
                                         analyzer_->rewrite_simplify(NotNode::make(cond)));
141
      false_value = this->VisitExpr(op->args[2]);
142 143 144 145 146 147 148 149 150 151
    }
    if (is_zero(cond)) {
      return false_value;
    }
    if (is_one(cond)) {
      return true_value;
    }
    if (cond.same_as(op->args[0]) &&
        true_value.same_as(op->args[1]) &&
        false_value.same_as(op->args[2])) {
152
      return GetRef<PrimExpr>(op);
153
    } else {
154
      return CallNode::make(op->dtype, op->name,
155 156 157 158
                        {cond, true_value, false_value},
                        op->call_type);
    }
  }
159
  return StmtExprMutator::VisitExpr_(op);
160 161
}

162
PrimExpr IRMutatorWithAnalyzer::
163
VisitExpr_(const LetNode* op) {
164
  PrimExpr value = this->VisitExpr(op->value);
165
  if (!tir::HasSideEffect(value)) {
166 167
    analyzer_->Bind(op->var, value);
  }
168 169
  // We keep the let-binding here
  // as sub-class may or maynot choose to replace it.
170
  PrimExpr body = this->VisitExpr(op->body);
171 172
  if (value.same_as(op->value) &&
      body.same_as(op->body)) {
173
    return GetRef<PrimExpr>(op);
174
  } else {
175
    return LetNode::make(op->var, value, body);
176 177 178
  }
}

179
PrimExpr IRMutatorWithAnalyzer::
180
VisitExpr_(const SelectNode* op) {
181 182
  PrimExpr cond = this->VisitExpr(op->condition);
  PrimExpr true_value, false_value;
183 184
  {
    With<ConstraintContext> constraint(analyzer_, cond);
185
    true_value = VisitExpr(op->true_value);
186 187 188
  }
  {
    With<ConstraintContext> constraint(analyzer_,
189
                                       analyzer_->rewrite_simplify(NotNode::make(cond)));
190
    false_value = VisitExpr(op->false_value);
191 192 193 194 195 196 197 198 199 200 201
  }
  if (is_zero(cond)) {
    return false_value;
  }
  if (is_one(cond)) {
    return true_value;
  }
  // normal path
  if (cond.same_as(op->condition) &&
      true_value.same_as(op->true_value) &&
      false_value.same_as(op->false_value)) {
202
    return GetRef<PrimExpr>(op);
203
  } else {
204
    return SelectNode::make(cond, true_value, false_value);
205 206 207
  }
}

208
PrimExpr IRMutatorWithAnalyzer::
209
VisitExpr_(const ReduceNode* op) {
210 211 212 213 214
  // Setup the domain information before simplification.
  for (const IterVar& iv : op->axis) {
    analyzer_->Bind(iv->var, iv->dom);
  }
  // Recursively call simplification when necessary.
215
  return StmtExprMutator::VisitExpr_(op);
216 217 218 219
}

}  // namespace arith
}  // namespace tvm