/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Utility to make loop nest.
 * \file op_util.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <string>
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace op {

using namespace arith;
using namespace ir;

std::vector<std::vector<Stmt> >
MakeLoopNest(const Stage& stage,
             const std::unordered_map<IterVar, Range>& dom_map,
             size_t begin_iter_pos,
             bool new_loop_var,
             const std::unordered_set<IterVar>& skip_iter,
27
             std::unordered_map<IterVar, Expr>* p_value_map,
28
             bool debug_keep_trivial_loop) {
29 30 31 32 33 34 35 36 37 38 39 40 41 42
  auto leaf_iter_vars = stage->leaf_iter_vars;
  Stmt no_op = Evaluate::make(0);
  // create the loop nest
  std::vector<std::vector<Stmt> > nest;
  nest.resize(leaf_iter_vars.size() + 1);
  std::unordered_map<IterVar, Expr>& value_map = *p_value_map;

  for (size_t i = begin_iter_pos; i < leaf_iter_vars.size(); ++i) {
    auto iv = leaf_iter_vars[i];
    if (skip_iter.count(iv) || iv->iter_type == kOpaque) {
      // skip this iteration.
      value_map[iv] = iv->var;
      continue;
    }
43 44 45 46 47 48 49
    // Bind iv could be another thread.
    IterVar bind_iv = iv;
    if (stage->iter_var_attrs.count(iv)) {
      IterVar bind_thread = stage->iter_var_attrs[iv]->bind_thread;
      if (bind_thread.defined()) bind_iv = bind_thread;
    }

50
    Range dom = dom_map.at(iv);
51

52
    // initialize the offset and loop_level
53
    Var var = bind_iv->var;
54
    if (new_loop_var) {
55
      var = Var(iv->var->name_hint + ".init", bind_iv->var.type());
56 57
    }
    // Mark the iter var in the IR, to remember the point
58
    if (bind_iv->thread_tag.length() == 0) {
59
      ForType for_type = ForType::Serial;
60
      IterVarAttr it_attr;
61
      if (stage->iter_var_attrs.count(iv)) {
62 63 64 65
        it_attr = stage->iter_var_attrs[iv];
      }
      if (it_attr.defined()) {
        switch (it_attr->iter_type) {
66 67 68
          case kUnrolled: for_type = ForType::Unrolled; break;
          case kVectorized: for_type = ForType::Vectorized; break;
          case kParallelized: for_type = ForType::Parallel; break;
69
          case kDataPar: break;
70
          case kTensorized: break;
71
          default: LOG(FATAL) << "Unknown iter type"
72
                              << it_attr->iter_type
73 74
                              << " in the iter_var_attrs";
        }
75 76 77 78 79 80 81
        CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size());
        for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) {
          const std::string& pkey = it_attr->pragma_keys[k].as<StringImm>()->value;
          Expr pvalue = it_attr->pragma_values[k];
          if (!pvalue.defined()) {
            pvalue = make_const(Int(32), 1);
          }
82
          nest[i + 1].emplace_back(
83
              AttrStmt::make(iv, ir::attr::pragma_scope_prefix + pkey, pvalue, no_op));
84
        }
85
      }
86
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
87 88 89 90 91 92 93 94 95
        nest[i + 1].emplace_back(
            LetStmt::make(var, dom->min, no_op));
        value_map[iv] = dom->min;
      } else if (is_zero(dom->min)) {
        nest[i + 1].emplace_back(
            For::make(var, 0, dom->extent,
                      for_type, DeviceAPI::None, no_op));
        value_map[iv] = var;
      } else {
96
        Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.type());
97 98 99 100 101 102 103 104
        nest[i + 1].emplace_back(
            For::make(idx, 0, dom->extent,
                      for_type, DeviceAPI::None, no_op));
        Expr new_value = dom->min + idx;
        value_map[iv] = new_value;
        nest[i + 1].emplace_back(
            LetStmt::make(var, new_value, no_op));
      }
105 106 107 108 109
      if (it_attr.defined() && it_attr->prefetch_data.size() != 0) {
        CHECK(!is_one(dom->extent))
            << "Cannot prefetch on trivial loop with extent=1";
        CHECK_EQ(it_attr->prefetch_data.size(),
                 it_attr->prefetch_offset.size());
110
        for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
111
          nest[i + 1].emplace_back(
112
              AttrStmt::make(it_attr->prefetch_data[j],
113
                             ir::attr::prefetch_scope,
114
                             it_attr->prefetch_offset[j], no_op));
115 116
        }
      }
117 118
    } else if (bind_iv->thread_tag == "vthread" ||
               bind_iv->thread_tag == "cthread") {
119 120 121 122 123 124
      // virtual thread
      // Always restrict threaded IterVar to starts from 0.
      CHECK(is_zero(dom->min));
      CHECK(is_positive_const(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
125
          AttrStmt::make(bind_iv, ir::attr::virtual_thread, dom->extent, no_op));
126
      value_map[iv] = var;
127
    } else if (bind_iv->thread_tag == "pipeline") {
128 129 130 131 132
      // pipeline marker.
      CHECK(is_zero(dom->min));
      CHECK(is_one(dom->extent));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
133
          AttrStmt::make(bind_iv, ir::attr::pipeline_exec_scope, dom->extent, no_op));
134
      value_map[iv] = dom->min;
135 136 137 138 139
    } else {
      // Always restrict threaded IterVar to starts from 0.
      CHECK(is_zero(dom->min));
      // annotate the extent of the IterVar
      nest[i + 1].emplace_back(
140
          AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op));
141
      if (!debug_keep_trivial_loop && is_one(dom->extent)) {
142 143 144 145
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = var;
      }
146 147 148 149 150 151 152 153
    }
    // annotate the extent of the IterVar
    if (!new_loop_var) {
      nest[i + 1].emplace_back(
          AttrStmt::make(iv, attr::loop_scope, iv->var, no_op));
    }
  }
  // message passing to get offset of root iter vars.
154
  schedule::PassUpIndex(stage, dom_map, &value_map);
155 156 157
  return nest;
}

std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
  Stmt no_op = Evaluate::make(0);
  std::vector<Stmt> nest;
  for (const Expr& cond : predicates) {
    nest.emplace_back(IfThenElse::make(cond, no_op));
  }
164 165 166
  return nest;
}

// Rewrites Provide nodes so they target the replacement tensors in vmap.
class ProviderReplacer : public ir::IRMutator {
 public:
  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
      : vmap_(vmap) {}

  Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
    Tensor src = Operation(op->func.node_).output(op->value_index);
    auto hit = vmap_.find(src);
    if (hit == vmap_.end()) {
      // No replacement for this tensor; just visit children.
      return IRMutator::Mutate_(op, s);
    }
    const Tensor& dst = hit->second;
    found = true;
    Stmt replaced = ir::Provide::make(
        dst->op, dst->value_index, op->value, op->args);
    // Recurse into the rewritten node so its children get mutated too.
    return IRMutator::Mutate_(replaced.as<ir::Provide>(), replaced);
  }

  // True once at least one Provide has been rewritten.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

// Replace the target tensors of Provide nodes in stmt according to `replace`.
// Returns the original statement untouched when nothing matched.
Stmt ReplaceProvideTensor(Stmt stmt,
                          const std::unordered_map<Tensor, Tensor>& replace) {
  ProviderReplacer substituter(replace);
  Stmt mutated = substituter.Mutate(stmt);
  if (!substituter.found) return stmt;
  return mutated;
}

// replacer to replace tensors
class TensorReplacer : public ir::IRMutator {
 public:
  explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
      : vmap_(vmap) {}
204

205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
  Expr Mutate_(const ir::Call* op, const Expr& e) {
    if (op->call_type == ir::Call::Halide) {
      Tensor t = Operation(op->func.node_).output(op->value_index);
      auto it = vmap_.find(t);
      if (it != vmap_.end()) {
        Expr ret = ir::Call::make(
            op->type, it->second->op->name, op->args,
            op->call_type, it->second->op, it->second->value_index);
        found = true;
        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
      }
    }
    return IRMutator::Mutate_(op, e);
  }

  // whether it is found.
  bool found{false};

 private:
  const std::unordered_map<Tensor, Tensor>& vmap_;
};

// Replace tensor reads inside a statement according to `replace`;
// returns the original statement unchanged if no replacement applied.
Stmt ReplaceTensor(Stmt stmt,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer substituter(replace);
  Stmt mutated = substituter.Mutate(stmt);
  if (!substituter.found) return stmt;
  return mutated;
}
// Expression overload: replace tensor reads inside an expression;
// returns the original expression unchanged if no replacement applied.
Expr ReplaceTensor(Expr expr,
                   const std::unordered_map<Tensor, Tensor>& replace) {
  TensorReplacer substituter(replace);
  Expr mutated = substituter.Mutate(expr);
  if (!substituter.found) return expr;
  return mutated;
}


// Substitute IterVars in a statement by their bound value expressions.
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  // Lower the IterVar-keyed map to the raw Variable*-keyed map
  // expected by ir::Substitute.
  std::unordered_map<const Variable*, Expr> var_map;
  for (const auto& entry : value_map) {
    var_map[entry.first->var.get()] = entry.second;
  }
  return ir::Substitute(s, var_map);
}

}  // namespace op
}  // namespace tvm