/*!
 *  Copyright (c) 2017 by Contributors
 * \brief Compute Op.
 * \file compute_op.cc
 */
#include <tvm/operation.h>
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./op_util.h"
#include "../schedule/message_passing.h"

namespace tvm {

using namespace ir;

// Printer hook for ComputeOpNode: the node is written to the printer stream
// as "compute(<name>, <node pointer>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<ComputeOpNode>(
    [](const ComputeOpNode *node, IRPrinter *printer) {
      printer->stream << "compute(" << node->name << ", " << node << ")";
    });

// NOTE(review): presumably makes ComputeOpNode known to the node type
// registry -- confirm against the definition of TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(ComputeOpNode);

/*!
 * \brief Number of outputs of this operation.
 * \return Always 1.
 */
int ComputeOpNode::num_outputs() const {
  const int output_count = 1;
  return output_count;
}

/*!
 * \brief Collect the root iteration variables of this operation.
 * \return The entries of axis followed by the entries of reduce_axis.
 */
Array<IterVar> ComputeOpNode::root_iter_vars() const {
  Array<IterVar> all_vars = axis;
  const size_t num_reduce = reduce_axis.size();
  // Without reduction axes the result is just the axis array.
  for (size_t k = 0; k < num_reduce; ++k) {
    all_vars.push_back(reduce_axis[k]);
  }
  return all_vars;
}

/*!
 * \brief Data type of the i-th output.
 * \param i Output index; checked to be 0.
 * \return The type of the compute body expression.
 */
Type ComputeOpNode::output_dtype(size_t i) const {
  CHECK_EQ(i, 0U);
  const Type body_type = body.type();
  return body_type;
}

/*!
 * \brief Shape of the i-th output.
 * \param i Output index; checked to be 0.
 * \return One expression per entry of axis: the extent of that axis' domain.
 */
Array<Expr> ComputeOpNode::output_shape(size_t i) const {
  CHECK_EQ(i, 0U);
  std::vector<Expr> shape;
  shape.reserve(axis.size());
  // The loop index is deliberately not called `i`, so that it does not
  // shadow the output-index parameter.
  for (size_t dim = 0; dim < axis.size(); ++dim) {
    const Range& r = axis[dim]->dom;
    shape.push_back(r->extent);
  }
  return Array<Expr>(shape);
}

/*!
 * \brief Construct a tensor by computing over a shape.
 *
 * For every entry of shape an IterVar of type kDataPar is created with
 * domain Range(0, shape[i]) and a variable named "ax<i>" whose type is the
 * type of shape[i]. fcompute is called with those variables to obtain the
 * body of the new ComputeOpNode.
 *
 * \param shape The shape of the result.
 * \param fcompute Callback that maps the axis variables to the body.
 * \param name The name given to the operation.
 * \return Output 0 of the newly made compute operation.
 */
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name) {
  const size_t ndim = shape.size();
  std::vector<IterVar> axis;
  std::vector<Var> args;
  axis.reserve(ndim);
  args.reserve(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    std::ostringstream os;
    os << "ax" << i;
    axis.emplace_back(IterVarNode::make(
        Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar));
    args.push_back(axis.back()->var);
  }
  return ComputeOpNode::make(name, axis, fcompute(args)).output(0);
}

/*!
 * \brief Create a compute operation.
 * \param name The name of the operation.
 * \param axis The iteration axes of the output.
 * \param body The expression computed at each point.
 * \return An Operation wrapping the new ComputeOpNode. When body is an
 *  ir::Reduce, reduce_axis is copied from the reduction's axis.
 */
Operation ComputeOpNode::make(std::string name,
                              Array<IterVar> axis,
                              Expr body) {
  std::shared_ptr<ComputeOpNode> node = std::make_shared<ComputeOpNode>();
  node->name = name;
  node->axis = axis;
  node->body = body;
  const bool body_is_reduce = node->body->is_type<ir::Reduce>();
  if (body_is_reduce) {
    const ir::Reduce* reduce = node->body.as<ir::Reduce>();
    node->reduce_axis = reduce->axis;
  }
  return Operation(node);
}

// Schedule-related logic.
/*!
 * \brief List the tensors read by the body.
 * \return Each distinct tensor referenced by an ir::Call with a defined
 *  func, in the order PostOrderVisit first reaches it.
 */
Array<Tensor> ComputeOpNode::InputTensors() const {
  Array<Tensor> inputs;
  std::unordered_set<Tensor> seen;
  auto collect = [&inputs, &seen](const NodeRef& node) {
    const ir::Call* call = node.as<ir::Call>();
    if (call == nullptr || !call->func.defined()) return;
    Tensor t = Operation(call->func.node_).output(call->value_index);
    // insert() reports whether the tensor was new, so each one is kept once.
    if (seen.insert(t).second) {
      inputs.push_back(t);
    }
  };
  ir::PostOrderVisit(body, collect);
  return inputs;
}

/*!
 * \brief Replace input tensors of the body according to rmap.
 * \param self The Operation handle that wraps this node (checked).
 * \param rmap Map from tensors to their replacements.
 * \return self when op::ReplaceTensor leaves the body unchanged, otherwise
 *  a new compute operation with the same name and axis and the new body.
 */
Operation ComputeOpNode::ReplaceInputs(
    const Operation& self,
    const std::unordered_map<Tensor, Tensor>& rmap) const {
  CHECK_EQ(self.operator->(), this);
  Expr new_body = op::ReplaceTensor(this->body, rmap);
  if (new_body.same_as(this->body)) {
    return self;
  }
  return ComputeOpNode::make(name, axis, new_body);
}

/*!
 * \brief Propagate access bounds from this operation to its input tensors.
 *
 * For every ir::Call in the body that refers to a tensor already present
 * in out_dom_map, the set evaluated from each call argument under dom_map
 * is appended to that tensor's per-dimension list.
 *
 * \param self The Operation handle that wraps this node (checked).
 * \param dom_map Known sets of the variables.
 * \param out_dom_map Per-tensor domains to append to; tensors that are not
 *  already keys are ignored.
 */
void ComputeOpNode::PropBoundToInputs(
    const Operation& self,
    const std::unordered_map<const Variable*, IntSet>& dom_map,
    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
  CHECK_EQ(self.operator->(), this);
  auto record_access = [&dom_map, out_dom_map](const NodeRef& node) {
    const ir::Call* call = node.as<ir::Call>();
    if (call == nullptr || !call->func.defined()) return;
    Tensor t = Operation(call->func.node_).output(call->value_index);
    if (!t->op.defined()) return;
    auto entry = out_dom_map->find(t);
    if (entry == out_dom_map->end()) return;
    TensorDom& dom = entry->second;
    const size_t ndim = t.ndim();
    for (size_t d = 0; d < ndim; ++d) {
      dom.data[d].push_back(EvalSet(call->args[d], dom_map));
    }
  };
  ir::PostOrderVisit(body, record_access);
}

void ComputeOpNode::GatherBound(
    const Operation& self,
    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
    std::unordered_map<IterVar, Range>* out_dom_map) const {
  const TensorDom& tdom = tensor_dom.at(self.output(0));
  for (size_t i = 0; i < this->axis.size(); ++i) {
    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
    CHECK(!out_dom_map->count(this->axis[i]));
    (*out_dom_map)[this->axis[i]] = r;
  }
  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
    CHECK(!out_dom_map->count(this->reduce_axis[i]));
    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
  }
}

/*!
 * \brief Wrap realize_body in an ir::Realize for output 0.
 * \param self The Operation handle that wraps this node (checked).
 * \param realize_map Range of every axis; must contain each entry of axis.
 * \param realize_body The statement placed inside the realize.
 * \return The Realize statement, with one bound per axis and a constant
 *  true condition.
 */
Stmt ComputeOpNode::BuildRealize(
    const Operation& self,
    const std::unordered_map<IterVar, Range>& realize_map,
    const Stmt& realize_body) const {
  CHECK_EQ(self.operator->(), this);
  Tensor out = self.output(0);
  Halide::Internal::Region region;
  for (size_t k = 0; k < this->axis.size(); ++k) {
    region.push_back(realize_map.at(this->axis[k]));
  }
  return ir::Realize::make(
      out->op, out->value_index, out->dtype,
      region, const_true(), realize_body);
}

// Build a reduction body.
void MakeReduction(const ComputeOpNode* op,
                   const Tensor& t,
                   Stmt* init,
                   Stmt* provide) {
  Stmt no_op = Evaluate::make(0);
  std::vector<Stmt> nest;
  Array<Expr>  args;
  for (IterVar iv : op->axis) {
    args.push_back(iv->var);
  }
  const Reduce* reduce = op->body.as<Reduce>();
  CHECK(reduce);
ziheng committed
176 177 178 179
  const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>();
  CHECK(combiner);
  Expr init_value = combiner->identity_element;
  Expr update_value = (*combiner)(t(args), reduce->source);
180 181
  *init = Provide::make(t->op, t->value_index, init_value, args);
  *provide = Provide::make(t->op, t->value_index, update_value, args);
182 183 184
  if (!is_one(reduce->condition)) {
    *provide = IfThenElse::make(reduce->condition, *provide);
  }
185 186 187 188 189 190 191 192 193 194 195
}

// Substitute iteration variables in a statement.
// Each IterVar key of value_map is turned into its underlying var, and the
// resulting Var -> Expr map is handed to ir::Substitute.
Stmt Substitute(Stmt s,
                const std::unordered_map<IterVar, Expr>& value_map) {
  Map<Var, Expr> var_map;
  for (auto it = value_map.begin(); it != value_map.end(); ++it) {
    var_map.Set(it->first->var, it->second);
  }
  return ir::Substitute(s, var_map);
}

196 197 198 199
// Cross Thread reduction marker.
bool IsCrossThreadReduction(const ComputeOpNode* self,
                            const Stage& stage) {
  // Verify correctness of leaf nest.
200
  int normal_red = 0, thread_red = 0;
201 202
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
203 204 205 206 207 208 209
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        ++thread_red;
      } else {
        ++normal_red;
      }
210
    } else {
211
      CHECK_EQ(thread_red, 0)
212 213 214
          << "Cross thread reduce cannot swap with normal data axis";
    }
  }
215 216 217
  CHECK(normal_red == 0 || thread_red == 0)
      << "Cannot mix normal reduction with thread reduce";
  return thread_red != 0;
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
}

// Build the statement for a cross thread reduction.
//
// self    : the compute op; its body must be a Reduce (checked).
// stage   : the stage being lowered.
// dom_map : ranges of the iteration variables.
//
// The generated body allocates a one-element buffer "reduce_temp", stores
// the result of a tvm_thread_allreduce intrinsic call into it, then provides
// the loaded value to output 0 of stage->op. The body is substituted with
// the loop value map and merged into the stage's loop nest.
Stmt MakeCrossThreadReduction(
    const ComputeOpNode* self,
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) {
  Array<Expr>  args;
  for (IterVar iv : self->axis) {
    args.push_back(iv->var);
  }
  const Reduce* reduce = self->body.as<Reduce>();
  CHECK(reduce);
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
  auto conds = op::MakeBoundCheck(
      stage, dom_map, false,
      std::unordered_set<IterVar>(), value_map);
  // The reduce condition is combined with every bound check.
  Expr cond = reduce->condition;
  for (Expr v : conds) {
    cond = cond && v;
  }
  Var res_handle("reduce_temp", Handle());
  // Intrinsic arguments: source, condition, then the thread vars.
  Array<Expr> freduce_args;
  freduce_args.push_back(reduce->source);
  freduce_args.push_back(cond);

  // Append the var of every thread that a reduce leaf iter var is bound to.
  for (IterVar iv : stage->leaf_iter_vars) {
    if (iv->iter_type == kCommReduce) {
      auto it = stage->iter_var_attrs.find(iv);
      if (it != stage->iter_var_attrs.end() &&
          (*it).second->bind_thread.defined()) {
        IterVar tv = (*it).second->bind_thread;
        freduce_args.push_back(tv->var);
      }
    }
  }
  // Checks for the thread.
  std::vector<Expr> thread_head_check;
  if (stage->store_predicate.defined()) {
    thread_head_check.emplace_back(stage->store_predicate);
  }
  Type t = reduce->type;
  Expr pred = const_true(t.lanes());
  // Store the all-reduce result at index 0 of the temporary buffer.
  Stmt reduce_body = Store::make(res_handle,
    Call::make(
      reduce->type,
      ir::intrinsic::tvm_thread_allreduce,
      freduce_args, Call::Intrinsic),
     0, pred);
  // Attach the combiner to the store through a reduce_scope attribute.
  reduce_body = AttrStmt::make(
      reduce->combiner,
      attr::reduce_scope,
      make_zero(reduce->type),
      reduce_body);
  // Write the reduced value back, guarded by the store predicate and the
  // bound checks.
  Stmt assign_body = Provide::make(
      stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args);
  assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body);
  assign_body = MergeNest(op::MakeIfNest(conds), assign_body);
  Stmt body = Allocate::make(
      res_handle, reduce->type, {1}, const_true(),
      Block::make(reduce_body, assign_body));
  body = AttrStmt::make(
      res_handle, attr::storage_scope, StringImm::make("local"), body);
  body = Substitute(body, value_map);
  return MergeNest(nest, body);
}

// Build the Provide statement of a non-reduction compute op: the op body is
// stored into tensor t, indexed by the vars of op->axis.
Stmt MakeProvide(const ComputeOpNode* op,
                 const Tensor& t) {
  Array<Expr> indices;
  for (size_t k = 0; k < op->axis.size(); ++k) {
    indices.push_back(op->axis[k]->var);
  }
  return Provide::make(t->op, t->value_index, op->body, indices);
}

295 296 297 298 299
Stmt ComputeOpNode::BuildProvide(
    const Stage& stage,
    const std::unordered_map<IterVar, Range>& dom_map) const {
  CHECK_EQ(stage->op.operator->(), this);

300 301 302 303
  if (IsCrossThreadReduction(this, stage)) {
    // specially handle cross thread reduction.
    return MakeCrossThreadReduction(this, stage, dom_map);
  }
304 305 306 307 308 309 310 311 312 313
  Stmt init, provide;
  if (this->reduce_axis.size() == 0) {
    provide = MakeProvide(this, stage->op.output(0));
  } else {
    MakeReduction(this, stage->op.output(0), &init, &provide);
  }
  // make loop nest
  std::unordered_map<IterVar, Expr> value_map;
  auto nest = op::MakeLoopNest(
      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
314 315 316 317
  auto preds = op::MakeBoundCheck(stage, dom_map, false,
      std::unordered_set<IterVar>(), value_map);
  for (auto& e : preds) e = likely(e);
  nest.push_back(op::MakeIfNest(preds));
318 319 320
  if (stage->store_predicate.defined()) {
    nest.emplace_back(op::MakeIfNest({stage->store_predicate}));
  }
321 322 323 324 325 326 327 328 329 330 331 332 333
  provide = Substitute(provide, value_map);

  if (init.defined()) {
    // try to find the location to insert the initialization.
    // Fuse the initialization and provide loop when possible.
    std::unordered_map<IterVar, int> update_state;
    for (IterVar iv : this->reduce_axis) {
      update_state[iv] = 2;
    }
    for (IterVar iv : this->axis) {
      update_state[iv] = 1;
    }
    // find which iter var is related to reduction and which is related to axis.
334
    schedule::PassDownBitMaskOr(stage, &update_state);
335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355
    auto leaf_iter_vars = stage->leaf_iter_vars;
    std::unordered_map<IterVar, Expr> init_value_map;
    // first first loop that is related to reduction.
    size_t begin_loop = leaf_iter_vars.size();
    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
      auto iv = leaf_iter_vars[i];
      int flag = update_state.at(iv);
      if ((flag & 2) != 0) {
        begin_loop = i; break;
      }
      init_value_map[iv] = value_map.at(iv);
    }
    // skip loops that does not relates to axis.
    std::unordered_set<IterVar> skip_iter;
    for (auto kv : update_state) {
      int flag = kv.second;
      if ((flag & 1) == 0) skip_iter.insert(kv.first);
    }
    auto init_nest = op::MakeLoopNest(
        stage, dom_map, begin_loop, true,
        skip_iter, &init_value_map);
356 357 358
    auto preds = op::MakeBoundCheck(stage, dom_map, true, skip_iter, init_value_map);
    for (auto& e : preds) e = likely(e);
    init_nest.push_back(op::MakeIfNest(preds));
359 360 361 362 363 364 365 366 367 368 369 370
    init = Substitute(init, init_value_map);
    init  = MergeNest(init_nest, init);
    // common nest
    std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1);
    std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end());
    provide = MergeNest(reduce, provide);
    return MergeNest(common, Block::make(init, provide));
  } else {
    return MergeNest(nest, provide);
  }
}
}  // namespace tvm