/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Compute Op.
 * \file compute_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include <string>
#include <utility>
#include "compute_op.h"
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {

using namespace ir;

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>([](const ComputeOpNode *op, IRPrinter *p) {
    p->stream << "compute(" << op->name << ", " << op << ")";
});

TVM_REGISTER_NODE_TYPE(ComputeOpNode);

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode *op);

inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
  return (a->combiner.same_as(b->combiner)) &&
         (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) &&
         (a->condition.same_as(b->condition));
}
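
// Note: same_as is reference (pointer) equality, so ReduceEqual matches
// exactly the Reduce nodes that were split out of one multi-output
// reduction: they share the same combiner/source/axis/condition nodes and
// differ only in value_index.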

int ComputeOpNode::num_outputs() const {
  return body.size();
}

Array<IterVar> BaseComputeOpNode::root_iter_vars() const {
  if (reduce_axis.size() == 0) return axis;
  Array<IterVar> ret = axis;
  for (IterVar iv : reduce_axis) {
    ret.push_back(iv);
  }
  return ret;
}

Type ComputeOpNode::output_dtype(size_t idx) const {
  CHECK_LT(idx, num_outputs());
  return body[idx].type();
}

Array<Expr> BaseComputeOpNode::output_shape(size_t idx) const {
  CHECK_LT(idx, num_outputs());
  // for now, all outputs of a BaseComputeOp have the same shape
  Array<Expr> shape;
  for (const auto& ivar : this->axis) {
    const Range& r = ivar->dom;
    shape.push_back(r->extent);
  }
  return shape;
}

Tensor compute(Array<Expr> shape,
               FCompute fcompute,
               std::string name,
               std::string tag,
               Map<std::string, NodeRef> attrs) {
  // Construct one data-parallel iteration axis per shape dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }

  return ComputeOpNode::make(
      name, tag, attrs, axis, {fcompute(args)}).output(0);
}
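
// Illustrative usage (a sketch, not part of this file; A, B, m and n are
// hypothetical): an elementwise sum built with the overload above.
//
//   Tensor C = compute({m, n}, [&](const Array<Var>& i) {
//       return A(i[0], i[1]) + B(i[0], i[1]);
//     }, "C");
//
// Each shape dimension becomes one kDataPar IterVar named ax0, ax1, ...,
// and fcompute receives their loop variables.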

Array<Tensor> compute(Array<Expr> shape,
                      FBatchCompute fcompute,
                      std::string name,
                      std::string tag,
                      Map<std::string, NodeRef> attrs) {
  // Construct one data-parallel iteration axis per shape dimension.
  size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }

  Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args));
  Array<Tensor> outputs;
  for (int idx = 0; idx < op->num_outputs(); ++idx) {
    outputs.push_back(op.output(idx));
  }
  return outputs;
}
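
// Illustrative usage (a sketch; A, B and n are hypothetical): FBatchCompute
// returns one Expr per output, so a single ComputeOpNode can yield several
// tensors that share the same axes.
//
//   Array<Tensor> outs = compute({n}, [&](const Array<Var>& i) {
//       return Array<Expr>{A(i[0]) + B(i[0]), A(i[0]) * B(i[0])};
//     }, "sum_and_prod");
//
// Here outs[k] is op.output(k), and all outputs share one shape.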

Operation ComputeOpNode::make(std::string name,
                              std::string tag,
                              Map<std::string, NodeRef> attrs,
                              Array<IterVar> axis,
                              Array<Expr> body) {
  if (!attrs.defined()) {
    attrs = Map<std::string, NodeRef>();
  }
  auto n = make_node<ComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->axis = std::move(axis);
  n->body = std::move(body);
  if (n->body[0]->is_type<ir::Reduce>()) {
    const ir::Reduce* reduce = n->body[0].as<ir::Reduce>();
    n->reduce_axis = reduce->axis;
  }
  VerifyComputeOp(n.get());
  return Operation(n);
}

// The schedule-related logic.
Array<Tensor> ComputeOpNode::InputTensors() const {
  Array<Tensor> ret;
  std::unordered_set<Tensor> visited;
  for (auto& e : body) {
    ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          Tensor t = Operation(call->func.node_).output(call->value_index);
          if (!visited.count(t)) {
            ret.push_back(t);
            visited.insert(t);
          }
        }
      });
  }
  return ret;
}
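
// For example (illustrative): if the body is A(i) + B(i) + A(i), the
// visitor sees three ir::Call nodes into tensors, but the visited set
// deduplicates them, so InputTensors() returns {A, B}.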

Operation ComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  VerifyComputeOp(this);
  Array<Expr> arr;
  if (this->body[0]->is_type<ir::Reduce>()) {
    // Specially handle reduce so that the replaced op
    // still shares all the components
    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
    if (!new_reduce.same_as(this->body[0])) {
      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
      for (size_t k = 0; k < this->body.size(); ++k) {
        auto n = make_node<ir::Reduce>(*r);
        n->value_index = static_cast<int>(k);
        n->type = r->source[k].type();
        arr.push_back(Expr(n));
      }
    } else {
      arr = this->body;
    }
  } else {
    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
        return op::ReplaceTensor(e, rmap);
      });
  }
  if (!arr.same_as(this->body)) {
    return ComputeOpNode::make(
        this->name, this->tag, this->attrs, this->axis, arr);
  } else {
    return self;
  }
}

void ComputeOpNode::PropBoundToInputs(
    const Operation& self,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  auto fvisit = [&dom_map, out_dom_map](const NodeRef& n) {
    auto *call = n.as<ir::Call>();
    if (call != nullptr && call->func.defined()) {
      Tensor t = Operation(call->func.node_).output(call->value_index);
      if (t->op.defined() && out_dom_map->count(t)) {
        TensorDom& dom = out_dom_map->at(t);
        for (size_t i = 0; i < t.ndim(); ++i) {
          dom.data[i].push_back(EvalSet(call->args[i], dom_map));
        }
      }
    }
  };
  for (auto& e : body) ir::PostOrderVisit(e, fvisit);
}
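
// For example (illustrative): for a body reading A(i + 1, j), the sets
// pushed into dom.data are EvalSet(i + 1, dom_map) and EvalSet(j, dom_map),
// i.e. the consumer's iteration ranges mapped into A's coordinate space.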

void BaseComputeOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
    CHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    CHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}
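
// Illustrative effect (a sketch): if two consumers read output 0 at C[i]
// and C[i + 1], tdom.data[0] holds both IntSets; arith::Union folds them
// into one set and cover_range yields a contiguous Range, falling back to
// the axis' original dom when the union cannot be narrowed.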

Stmt BaseComputeOpNode::BuildRealize(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& body) const {
  CHECK_EQ(stage->op.get(), this);
  HalideIR::Internal::Region bounds;
  for (IterVar iv : this->axis) {
    bounds.push_back(realize_map.at(iv));
  }
  Stmt realize = body;
  for (int i = this->num_outputs(); i > 0; --i) {
    Tensor t = stage->op.output(i-1);
    realize = ir::Realize::make(t->op, t->value_index,
      t->dtype, bounds, const_true(), realize);
    // alignment requirement, only useful for compute
    for (size_t dim = 0; dim < num_schedulable_dims(); ++dim) {
      auto it = stage->iter_var_attrs.find(this->axis[dim]);
      if (it != stage->iter_var_attrs.end()) {
        IterVarAttr attr = (*it).second;
        if (attr->dim_align_factor != 0) {
          Array<Expr> tuple = {static_cast<int>(dim),
                               attr->dim_align_factor,
                               attr->dim_align_offset};
          realize = ir::AttrStmt::make(
              t, ir::attr::buffer_dim_align,
              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
              realize);
        }
      }
    }
  }
  return realize;
}
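
// Sketch of the IR produced above for a two-output op over axis [0, m),
// with a dim_align request on output T0 (names illustrative):
//
//   // attr [T0] buffer_dim_align = tvm_tuple(0, factor, offset)
//   realize T0([0, m]) {
//     realize T1([0, m]) {
//       <producer body>
//     }
//   }
//
// Outputs are wrapped innermost-first, and each output's alignment
// AttrStmt sits immediately outside its own Realize.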

size_t ComputeOpNode::num_schedulable_dims() const {
  return axis.size();
}

// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
                   const Array<Tensor>& tensors,
                   Stmt* init,
                   Stmt* provide) {
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  std::vector<Stmt> inits, provides;

  size_t size = op->body.size();
  const Reduce* reduce = op->body[0].as<Reduce>();
  CHECK(reduce);
  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
  CHECK(combiner);
  Array<Expr> lhs;
  for (size_t i = 0; i < size; ++i) {
    lhs.push_back(tensors[i](args));
  }
  Array<Expr> init_value = combiner->identity_element;
  Array<Expr> update_value = (*combiner)(lhs, reduce->source);
  for (size_t i = 0; i < size; ++i) {
    Tensor t = tensors[i];
    inits.emplace_back(Provide::make(
          t->op, t->value_index, init_value[i], args));
    provides.emplace_back(Provide::make(
          t->op, t->value_index, update_value[i], args));
  }
  *init = Block::make(inits);
  *provide = Block::make(provides);
  if (!is_one(reduce->condition)) {
    *provide = IfThenElse::make(reduce->condition, *provide);
  }
}
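
// Illustrative result (a sketch) for a one-output sum over rv with
// condition c, writing tensor B at args (i):
//
//   *init:    B(i) = 0                        // combiner->identity_element
//   *provide: if (c) B(i) = B(i) + A(i, rv)   // (*combiner)(lhs, source)
//
// For multi-output combiners (e.g. argmax) each Block carries one Provide
// per output, so all outputs are updated together.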

// Normal computation.
Stmt MakeProvide(const ComputeOpNode* op,
                 const Tensor& t) {
  Array<Expr> args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  return Provide::make(t->op, t->value_index, op->body[t->value_index], args);
}

Stmt MakeComputeStmt(const ComputeOpNode* self,
                     const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) {
  // grab the nest structure
  ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
  // Normal loop structure
  n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
  n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates));
  if (self->reduce_axis.size() != 0) {
    // make reduction.
    Stmt init, provide;
    Array<Tensor> source;
    for (size_t i = 0; i < self->body.size(); ++i) {
      source.push_back(stage->op.output(i));
    }
    MakeReduction(self, source, &init, &provide);
    init = MergeNest(n.init_nest, init);
    init = op::Substitute(init, n.init_vmap);
    // common nest
    std::vector<std::vector<Stmt> > common(
        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
    std::vector<std::vector<Stmt> > reduce(
        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end());
    provide = MergeNest(reduce, provide);
    if (debug_keep_trivial_loop) {
      provide = MergeNest(common, provide);
    } else {
      provide = MergeNest(common, Block::make(init, provide));
    }
    // run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return op::Substitute(provide, n.main_vmap);
  } else {
    std::vector<Stmt> provides;
    for (size_t i = 0; i < self->body.size(); ++i) {
      provides.emplace_back(MakeProvide(self, stage->op.output(i)));
    }
    Stmt provide = Block::make(provides);
    provide = MergeNest(n.main_nest, provide);
    // run substitution on the full nest, because the loop condition
    // could depend on outer loops.
    return op::Substitute(provide, n.main_vmap);
  }
}
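
// Schematic of the nest assembled above for a reduction (a sketch):
//
//   for (common data loops)              // shared prefix of main_nest
//     <init nest> B(...) = identity      // skipped if debug_keep_trivial_loop
//     for (remaining reduction loops)
//       if (likely(pred)) B(...) = combine(B(...), source)
//
// The final Substitute pass runs over the whole nest because inner loop
// conditions can reference outer loop variables.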

enum class ComputeType {
  kNormal,
  kCrossThreadReduction,
  kTensorize
};

ComputeType DetectComputeType(const ComputeOpNode* self,
                              const Stage& stage) {
  // Verify correctness of leaf nest.
  int normal_red = 0, thread_red = 0, tensorize = 0;

  for (IterVar iv : stage->leaf_iter_vars) {
    IterVarAttr attr;
    auto it = stage->iter_var_attrs.find(iv);
    if (it != stage->iter_var_attrs.end()) {
      attr = (*it).second;
    }
    if (attr.defined() && attr->iter_type == kTensorized) {
      ++tensorize;
    }
    if (iv->iter_type == kCommReduce) {
      if (attr.defined() && attr->bind_thread.defined()) {
        ++thread_red;
      } else {
        ++normal_red;
      }
    } else {
      CHECK_EQ(thread_red, 0)
          << "Cross thread reduce cannot swap with normal data axis";
    }
  }
  if (tensorize != 0) {
    CHECK(thread_red == 0)
        << "Cannot mix cross thread reduction with Tensorize";
    return ComputeType::kTensorize;
  }
  CHECK(normal_red == 0 || thread_red == 0)
      << "Cannot mix normal reduction with thread reduce";
  if (thread_red != 0) {
    return ComputeType::kCrossThreadReduction;
  } else {
    return ComputeType::kNormal;
  }
}
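
// Example trigger (illustrative, via the Python schedule API): binding a
// reduce axis to a thread index marks it for cross-thread reduction,
//
//   s[C].bind(C.op.reduce_axis[0], tvm.thread_axis("threadIdx.x"))
//
// which makes thread_red > 0 here; if no plain reduce loops remain, the
// result is ComputeType::kCrossThreadReduction.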

// implement the provide utility.
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) const {
  CHECK_EQ(stage->op.operator->(), this);
  ComputeType ctype = DetectComputeType(this, stage);
  if (ctype == ComputeType::kCrossThreadReduction) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map, debug_keep_trivial_loop);
  } else if (ctype == ComputeType::kTensorize) {
    return MakeTensorize(this, stage, dom_map, debug_keep_trivial_loop);
  } else {
    return MakeComputeStmt(this, stage, dom_map, debug_keep_trivial_loop);
  }
}

ComputeLoopNest ComputeLoopNest::make(
    const BaseComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map,
    bool debug_keep_trivial_loop) {
  CHECK_EQ(stage->op.operator->(), self);
  ComputeLoopNest ret;
  // make main loop nest
  ret.main_nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
      debug_keep_trivial_loop);
  ret.main_predicates = schedule::MakeBoundCheck(
      stage, dom_map, ret.main_vmap, false,
      std::unordered_set<IterVar>());
  for (auto& e : ret.main_predicates) {
    e = likely(e);
  }
  if (stage->store_predicate.defined()) {
    ret.main_predicates.push_back(stage->store_predicate);
  }
  if (self->reduce_axis.size() != 0) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : self->reduce_axis) {
      update_state[iv] = 2;
    }
    for (size_t i = 0; i < self->num_schedulable_dims(); ++i) {
      update_state[self->axis[i]] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
    schedule::PassDownBitMaskOr(stage, &update_state);
    auto leaf_iter_vars = stage->leaf_iter_vars;
    // find the first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      ret.init_vmap[iv] = ret.main_vmap.at(iv);
    }
    ret.num_common_loop = begin_loop;
    // skip loops that are related to reduction and are unrelated to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if (flag == 2) skip_iter.insert(kv.first);
    }
    ret.init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
    ret.init_predicates = schedule::MakeBoundCheck(
        stage, dom_map, ret.init_vmap, true, skip_iter);
    for (auto& e : ret.init_predicates) {
      e = likely(e);
    }
  } else {
    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
    ret.num_common_loop = stage->leaf_iter_vars.size();
  }
  // copy elision here.
  return ret;
}
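
// Bitmask convention used in update_state above:
//   2       -> iter var derived only from reduce_axis
//   1       -> iter var derived only from schedulable data axes
//   3 (2|1) -> a fused/split var related to both, past which the
//              initialization cannot be hoisted
// PassDownBitMaskOr ORs these flags down from root vars to leaf vars.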

namespace {
/*!
 * \brief Verify if ComputeOp is valid with respect to Reduce operations.
 *
 *  The following two properties are verified:
 *  (1) All Reduce operations must exist at top level.
 *  (2) For a list of operations, if one is Reduce, then the others
 *      must be Reduce as well; and their inputs should have the
 *      same attribute except value_index.
 */
class ComputeVerifier final : protected ir::IRVisitor {
 public:
  /// Special member functions
  //@{
  explicit ComputeVerifier(const ComputeOpNode* compute)
      : compute_(compute), reduce_(compute->body[0].as<ir::Reduce>()) {}
  virtual ~ComputeVerifier() = default;
  ComputeVerifier(const ComputeVerifier&) = delete;
  ComputeVerifier(ComputeVerifier&&) = delete;
  ComputeVerifier& operator=(const ComputeVerifier&) = delete;
  ComputeVerifier& operator=(ComputeVerifier&&) = delete;
  //@}

  /// Interface to perform compute verification
  void Run() {
    for (const Expr& e : compute_->body) {
      // Check for consistency of top level reductions
      const ir::Reduce* reduce = e.as<ir::Reduce>();
      CHECK((reduce && reduce_) || (!reduce && !reduce_))
          << "All ComputeOp should be consistent "
          << "with being Reduce operation or not.";

      if (reduce && reduce_) {
        CHECK(ReduceEqual(reduce, reduce_))
            << "The Reduce inputs of ComputeOp should "
            << "have the same attribute except value_index";
      }

      level_ = 0;
      ir::IRVisitor::Visit(e);
    }
  }

 protected:
  /// Visitor implementation
  //@{
  void Visit(const NodeRef& n) final {
    ++level_;
    ir::IRVisitor::Visit(n);
    --level_;
  }

  void Visit_(const ir::Reduce* op) final {
    // Check for non top level reductions
    CHECK(0 == level_)
        << "Reductions are only allowed at the top level of compute. "
        << "Please create another tensor for further composition.";
  }
  //@}

 private:
  const ComputeOpNode* compute_{nullptr};  ///< ComputeOpNode to verify
  const ir::Reduce* reduce_{nullptr};      ///< Top level Reduce operation
  int level_{0};                           ///< Level of op being processed
};
}  // namespace

/// Verify if ComputeOp is valid with respect to Reduce operations.
static void VerifyComputeOp(const ComputeOpNode* op) {
  ComputeVerifier v(op);
  v.Run();
}
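
// Examples (illustrative) of what the verifier accepts and rejects:
//
//   // OK: the Reduce is the whole body (visited at level_ == 0)
//   B(i) = sum(A(i, rv), rv)
//
//   // Rejected: the Reduce is nested inside another expression
//   B(i) = 1 + sum(A(i, rv), rv)
//
// Nested reductions must be staged through an intermediate tensor.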

Stmt TransformUpdate(const Stage& stage,
                     const std::unordered_map<IterVar, Range>& dom_map,
                     const ComputeLoopNest& n,
                     Stmt body,
                     Stmt update) {
  Array<Expr> conds;
  std::unordered_set<const Variable*> banned;
  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
    IterVar iv = stage->leaf_iter_vars[i];
    auto iit = stage->iter_var_attrs.find(iv);
    if (iit != stage->iter_var_attrs.end()) {
      const IterVarAttr& attr = (*iit).second;
      if (attr->iter_type == kTensorized) {
        break;
      }
    }
    if (iv->iter_type == kCommReduce) {
      auto vit = dom_map.find(iv);
      CHECK(vit != dom_map.end());
      const Range& vrange = vit->second;
      conds.push_back(likely(iv->var > vrange->min));
      banned.insert(iv->var.get());
    }
  }
  for (const Expr& pred : n.main_predicates) {
    if (ir::ExprUseVar(pred, banned)) {
      LOG(FATAL) << "Tensorize update transform failed, the condition "
                 << pred << " has a conflict with the reset condition";
    }
  }

  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
                          update, body);
}
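
// Shape of the statement returned above (a sketch): the first iteration of
// each reduce loop must run `body` (which carries the reset), while every
// later iteration runs `update`:
//
//   if (likely(rv0 > min0) || likely(rv1 > min1) || ...) {
//     <update>
//   } else {
//     <body>
//   }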
}  // namespace tvm