/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Utility to make loop nest.
 * \file op_util.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include "./op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace op {

using namespace arith;
using namespace ir;

std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
             const std::unordered_map<IterVar, Range>& dom_map,
             size_t begin_iter_pos,
             bool new_loop_var,
             const std::unordered_set<IterVar>& skip_iter,
26
             std::unordered_map<IterVar, Expr>* p_value_map,
27
             bool debug_keep_trivial_loop) {
28 29 30 31 32 33 34 35 36 37 38 39 40 41
  auto leaf_iter_vars = stage->leaf_iter_vars;
  Stmt no_op = Evaluate::make(0);
  // create the loop nest
  std::vector<std::vector<Stmt> > nest;
  nest.resize(leaf_iter_vars.size() + 1);
  std::unordered_map<IterVar, Expr>& value_map = *p_value_map;

  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
      // skip this iteration.
      value_map[iv] = iv->var;
      continue;
    }
42 43 44 45 46 47 48
    // Bind iv could be another thread.
    IterVar bind_iv = iv;
    if (stage->iter_var_attrs.count(iv)) {
      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
      if (bind_thread.defined()) bind_iv = bind_thread;
    }

49
    Range dom = dom_map.at(iv);
50

51
    // initialize the offset and loop_level
52
    Var var = bind_iv->var;
53
    if (new_loop_var) {
54
      var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
55 56
    }
    // Mark the iter var in the IR, to remember the point
57
    if (bind_iv->thread_tag.length() == 0) {
58
      ForType for_type = ForType::Serial;
59
      IterVarAttr it_attr;
60
      if (stage->iter_var_attrs.count(iv)) {
61 62 63 64
        it_attr = stage->iter_var_attrs[iv];
      }
      if (it_attr.defined()) {
        switch (it_attr->iter_type) {
65 66 67
          case kUnrolled: for_type = ForType::Unrolled; break;
          case kVectorized: for_type = ForType::Vectorized; break;
          case kParallelized: for_type = ForType::Parallel; break;
68
          case kDataPar: break;
69
          case kTensorized: break;
70
          default: LOG(FATAL) << "Unknown iter type"
71
                              << it_attr->iter_type
72 73
                              << " in the iter_var_attrs";
        }
74 75 76 77 78 79 80
        CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
          const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
          Expr pvalue = it_attr->pragma_values[k];
          if (!pvalue.defined()) {
            pvalue = make_const(Int(32), 1);
          }
81
          nest[i + 1].emplace_back(
82
              AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
83
        }
84
      }
85
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
86 87 88 89 90 91 92 93 94
        nest[i + 1].emplace_back(
            LetStmt::make(var, dom->min, no_op));
        value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
        nest[i + 1].emplace_back(
            For::make(var, 0, dom->extent,
                      for_type, DeviceAPI::None, no_op));
        value_map[iv] = var;
      } else {
95
        Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
96 97 98 99 100 101 102 103
        nest[i + 1].emplace_back(
            For::make(idx, 0, dom->extent,
                      for_type, DeviceAPI::None, no_op));
        Expr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(
            LetStmt::make(var, new_value, no_op));
      }
104 105 106 107 108
      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
        CHECK(!is_one(dom->extent))
            << "Cannot prefetch on trivial loop with extent=1";
        CHECK_EQ(it_attr->prefetch_data.size(),
                 it_attr->prefetch_offset.size());
109
        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
110
          nest[i + 1].emplace_back(
111
              AttrStmt::make(it_attr->prefetch_data[j],
112
                             ir::attr::prefetch_scope,
113
                             it_attr->prefetch_offset[j], no_op));
114 115
        }
      }
116 117
    } else if (bind_iv->thread_tag == "vthread" ||
               bind_iv->thread_tag == "cthread") {
118 119 120 121 122 123
      // virtual thread
      // Always restrict threaded IterVar to starts from 0.
      CHECK(is_zero(dom->min));
      CHECK(is_positive_const(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
124
          AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
125
      value_map[iv] = var;
126
    } else if (bind_iv->thread_tag == "pipeline") {
127 128 129 130 131
      // pipeline marker.
      CHECK(is_zero(dom->min));
      CHECK(is_one(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
132
          AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
133
      value_map[iv] = dom->min;
134 135 136 137 138
    } else {
      // Always restrict threaded IterVar to starts from 0.
      CHECK(is_zero(dom->min));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
139
          AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
140
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
141 142 143 144
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = var;
      }
145 146 147 148 149 150 151 152
    }
    // annotate the extent of the IterVar
    if (!new_loop_var) {
      nest[i + 1].emplace_back(
          AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
    }
  }
  // message passing to get offset of root iter vars.
153
  schedule::PassUpIndex(stage, dom_map, &value_map);
154 155 156
  return nest;
}

std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
  Stmt no_op = Evaluate::make(0);
  std::vector<Stmt> nest;
  for (const Expr& cond : predicates) {
    nest.emplace_back(IfThenElse::make(cond, no_op));
  }
163 164 165
  return nest;
}


// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
 public:
  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
      : vmap_(vmap) {}
172

173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
  Expr Mutate_(const ir::Call* op, const Expr& e) {
    if (op->call_type == ir::Call::Halide) {
      Tensor t = Operation(op->func.node_).output(op->value_index);
      auto it = vmap_.find(t);
      if (it != vmap_.end()) {
        Expr ret = ir::Call::make(
            op->type, it->second->op->name, op->args,
            op->call_type, it->second->op, it->second->value_index);
        found = true;
        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  // whether it is found.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

/*!
 * \brief Replace reads of tensors in a statement according to a map.
 * \param stmt The statement to rewrite.
 * \param replace Map from tensor to its replacement.
 * \return The rewritten statement, or stmt itself when nothing was replaced.
 */
Stmt ReplaceTensor(Stmt stmt,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer replacer(replace);
  Stmt mutated = replacer.Mutate(stmt);
  // Hand back the untouched input when no tensor matched.
  if (!replacer.found) {
    return stmt;
  }
  return mutated;
}
/*!
 * \brief Replace reads of tensors in an expression according to a map.
 * \param expr The expression to rewrite.
 * \param replace Map from tensor to its replacement.
 * \return The rewritten expression, or expr itself when nothing was replaced.
 */
Expr ReplaceTensor(Expr expr,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer replacer(replace);
  Expr mutated = replacer.Mutate(expr);
  // Hand back the untouched input when no tensor matched.
  if (!replacer.found) {
    return expr;
  }
  return mutated;
}


/*!
 * \brief Substitute iteration variables in a statement.
 * \param s The statement to rewrite.
 * \param value_map Map from IterVar to the expression that replaces its var.
 * \return The result of ir::Substitute on s with the translated map.
 */
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  // ir::Substitute is keyed by the underlying Variable node, so translate
  // each IterVar key to its var first.
  std::unordered_map<const Variable*, Expr> var_subst;
  for (const auto& entry : value_map) {
    const Variable* key = entry.first->var.get();
    const Expr& replacement = entry.second;
    var_subst[key] = replacement;
  }
  return ir::Substitute(s, var_subst);
}

}  // namespace op
}  // namespace tvm