/*!
 *  Copyright (c) 2017 by Contributors
 * \file ir_util.cc
 * \brief Helper functions to construct and compose IR nodes.
 */
#include "ir_util.h"

namespace tvm {
namespace ir {

/*!
 * \brief Wrap a list of single-child statements around a body.
 *
 * Each element of \p nest carries a placeholder child slot (`body`,
 * `then_case` for IfThenElse, `rest` for Block) that must be a no-op.
 * Elements are visited from last to first. Each one is copied and its
 * placeholder is replaced by the statement built so far, so `nest[0]`
 * ends up outermost and \p body innermost. The statements in \p nest
 * are not modified.
 *
 * Supported node types: For, LetStmt, AttrStmt, IfThenElse (which must
 * have no else branch), Block, AssertStmt and Allocate. Any other type,
 * or a placeholder that fails is_no_op, aborts via CHECK / LOG(FATAL).
 *
 * \param nest Statements to wrap around \p body, outermost first.
 * \param body The innermost statement.
 * \return The merged statement, or \p body unchanged if \p nest is empty.
 */
Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
  // Walk backwards so the last element wraps `body` first and the first
  // element becomes the outermost statement.
  for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
    const Stmt& s = *ri;
    if (const auto* for_ = s.as<For>()) {
      // Copy the node so the caller's statement is left untouched.
      auto n = make_node<For>(*for_);
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (const auto* let = s.as<LetStmt>()) {
      auto n = make_node<LetStmt>(*let);
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (const auto* attr = s.as<AttrStmt>()) {
      auto n = make_node<AttrStmt>(*attr);
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (const auto* ite = s.as<IfThenElse>()) {
      auto n = make_node<IfThenElse>(*ite);
      CHECK(is_no_op(n->then_case));
      // An else branch would be silently kept beside the nested body,
      // so only guard-style conditionals are accepted.
      CHECK(!n->else_case.defined());
      n->then_case = body;
      body = Stmt(n);
    } else if (const auto* block = s.as<Block>()) {
      // For a Block, `first` is kept and the placeholder is `rest`.
      auto n = make_node<Block>(*block);
      CHECK(is_no_op(n->rest));
      n->rest = body;
      body = Stmt(n);
    } else if (const auto* assert_ = s.as<AssertStmt>()) {
      auto n = make_node<AssertStmt>(*assert_);
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else if (const auto* alloc = s.as<Allocate>()) {
      auto n = make_node<Allocate>(*alloc);
      CHECK(is_no_op(n->body));
      n->body = body;
      body = Stmt(n);
    } else {
      LOG(FATAL) << "not supported nest type";
    }
  }
  return body;
}

/*!
 * \brief Wrap several groups of nest statements around a body.
 *
 * Groups are applied from the last to the first using the single-group
 * MergeNest overload, so the group at index 0 ends up outermost.
 *
 * \param nest Groups of nest statements, outermost group first.
 * \param body The innermost statement.
 * \return The merged statement, or \p body unchanged if \p nest is empty.
 */
Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body) {
  for (size_t idx = nest.size(); idx != 0; --idx) {
    body = MergeNest(nest[idx - 1], body);
  }
  return body;
}

/*!
 * \brief Chain a list of statements into nested Block nodes.
 *
 * The fold is left-nested: {a, b, c} becomes Block(Block(a, b), c).
 * A single-element list returns that element unchanged.
 *
 * \param seq Statements in execution order.
 * \return The combined statement, or `Evaluate::make(0)` if \p seq is empty.
 */
Stmt MergeSeq(const std::vector<Stmt>& seq) {
  if (seq.empty()) {
    return Evaluate::make(0);
  }
  auto it = seq.begin();
  Stmt merged = *it;
  for (++it; it != seq.end(); ++it) {
    merged = Block::make(merged, *it);
  }
  return merged;
}

}  // namespace ir
}  // namespace tvm