/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file bound.cc
 * \brief The bound inference logic.
 */
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include <unordered_set>
#include "graph.h"
#include "message_passing.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace schedule {

using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;

/*! \brief The graph context used during bound inference. */
struct GraphContext {
  /*! \brief The feed graph */
  FeedGraph feed_graph;
  /*! \brief Attachment path */
  AttachPath attach_path;
  /*! \brief The bind map */
  std::unordered_map<IterVar, IterVar> bind_map;
  /*! \brief map from op to stage */
  std::unordered_map<const Node*, Stage> op2stage_;
};

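/*!
 * \brief Whether the range of iter var iv must be relaxed to cover its
 *  full extent when inferring the bound of a producer.
 *
 *  Plain loops inner than the attach point are relaxed; for thread
 *  iter vars, relaxation depends on whether the producer's storage
 *  scope is shared across that thread rank.
 * \param iv The leaf iter var of the consumer's loop nest.
 * \param found_attach Whether the attach point has already been passed
 *  while scanning the loop nest from the innermost loop outwards.
 * \param bind_map The map from iter vars to the threads they are bound to.
 * \param scope The storage scope of the producer.
 */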
bool NeedRelax(const IterVar& iv,
               bool found_attach,
               const std::unordered_map<IterVar, IterVar>& bind_map,
               const runtime::StorageScope& scope) {
  auto it = bind_map.find(iv);
  const std::string& tag = (
      it != bind_map.end() ? it->second->thread_tag : iv->thread_tag);
  if (tag.length() == 0 || tag == "pipeline") {
    return !found_attach;
  }
  ThreadScope ts = ThreadScope::make(tag);

  // When there is warp memory,
  // threadIdx.x must be set to be the warp index.
  if (scope.rank == StorageRank::kWarp &&
      ts.rank == 1 &&
      ts.dim_index == 0) {
    return true;
  }
  return static_cast<int>(scope.rank) <= ts.rank;
}

// Infer the storage scope of a stage, if it is not explicitly set.
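// The default is derived from the deepest thread rank on the stage's
// attach path: e.g. a stage attached under a blockIdx loop defaults to
// shared, under a threadIdx loop to local, and with no thread at all
// to global.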
StorageScope InferStorageScope(
    const Stage& stage, const GraphContext& ctx) {
  if (stage->scope.length() != 0) {
    return StorageScope::make(stage->scope);
  }
  int max_rank = -1;
  for (IterVar iv : ctx.attach_path.at(stage->op)) {
    auto it = ctx.bind_map.find(iv);
    const std::string& tag = (
        it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag);
    if (tag != "pipeline" && tag.length() != 0) {
      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
    }
  }
  StorageScope s;
  s.rank = runtime::DefaultStorageRank(max_rank);
  return s;
}

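/*!
 * \brief Infer the ranges of the root iter vars of stage->op.
 *
 *  Collects the region each consumer demands from the op's outputs,
 *  relaxing over loops and threads as dictated by the attach point and
 *  the storage scope, then gathers the union of those regions back
 *  onto the root iter vars.
 * \param stage The stage whose root bounds are inferred.
 * \param ctx The graph context.
 * \param rmap The range map; the bounds of all consumer stages must
 *  already be present, and the result is stored back into it.
 */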
void InferRootBound(const Stage& stage,
                    const GraphContext& ctx,
                    std::unordered_map<IterVar, Range>* rmap) {
  CHECK_NE(stage->attach_type, kInline)
      << "call schedule.normalize before scheduleops";
  if (stage->attach_type == kInlinedAlready) return;
  if (stage->is_output) {
    // verify correctness.
    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
          << "Output must be attached at root";
  }
  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
    for (auto iv : stage->op->root_iter_vars()) {
      CHECK(iv->dom.defined());
      CHECK(!rmap->count(iv));
      (*rmap)[iv] = iv->dom;
    }
    return;
  }
  // The demanded domains of the op's output tensors, filled by consumers.
  std::unordered_map<Tensor, TensorDom> tmap;
  // The consumers of the op.
  std::unordered_set<Operation> consumers;
  for (int i = 0; i < stage->op->num_outputs(); ++i) {
    Tensor t = stage->op.output(i);
    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
    auto it = ctx.feed_graph.find(t);
    if (it != ctx.feed_graph.end()) {
      for (const Operation& op : it->second) {
        consumers.insert(op);
      }
    } else {
      LOG(INFO) << "not in feed graph, consumer = " << stage->op;
    }
  }
  // storage scope.
  runtime::StorageScope scope = InferStorageScope(stage, ctx);
  // Bound propagation from other consumers.
  // - Compute bounds by the relaxation rules in NeedRelax:
  //   - For a normal index, use the relative location in the loop nest.
  //   - For a thread index, use the thread scope.
  //
  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
  // Propagate the demanded region from each consumer op.
  for (const Operation& op : consumers) {
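    // Ranges over which loop vars (and the threads they are bound to)
    // must be relaxed for this consumer.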
    std::unordered_map<const Variable*, IntSet> relax_set;
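    // The region each leaf iter var demands, propagated upwards to the
    // consumer's root iter vars by PassUpDomain below.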
    std::unordered_map<IterVar, IntSet> up_state;
    bool found_attach = false;
    CHECK(ctx.op2stage_.count(op.get()));
    const Stage& op_stage = ctx.op2stage_.at(op.get());
    // Consumer nest
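    // Walk the consumer's loop nest from the innermost loop outwards,
    // recording the region of each loop var this consumer demands.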
    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
      IterVar iv = op_stage->leaf_iter_vars[i - 1];
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      auto it = rmap->find(iv);
      CHECK(it != rmap->end());
      const Range& vrange = it->second;
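      // Three cases per loop var: an extent-one loop pins it to a
      // single point; an unrelaxed loop stays a symbolic point at the
      // loop variable itself (or the thread it binds to); otherwise
      // the loop is relaxed to its full range.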
      if (is_one(vrange->extent)) {
        up_state[iv] = IntSet::single_point(vrange->min);
      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        CHECK(is_zero(vrange->min))
            << "InferBound requires every leaf iter var's min equals 0, "
            << " call schedule.normalize to achieve this. ";
        if (ctx.bind_map.count(iv)) {
          up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
        } else {
          up_state[iv] = IntSet::single_point(iv->var);
        }
      } else {
        up_state[iv] = IntSet::range(vrange);
      }
    }
    // Consumer's attach nest
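    // Loops the consumer itself is nested inside via compute_at also
    // constrain where the producer must be valid.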
    for (IterVar iv : ctx.attach_path.at(op)) {
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      Range vrange = rmap->at(iv);
      CHECK(is_zero(vrange->min))
          << "InferBound requires every leaf iter var's min equals 0, "
          << "call schedule.normalize to achieve this.";
      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        relax_set[iv->var.get()] = IntSet::range(vrange);
        if (ctx.bind_map.count(iv)) {
          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
        }
      }
    }
    CHECK(found_attach || stage_attach.size() == 0)
        << "Invalid Schedule, cannot find the producer " << stage->op
        << " along the loop nest specified by compute_at of consumer " << op;
    // Get the domain of the consumer
    PassUpDomain(op_stage, *rmap, &up_state);
    // Relax if needed.
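    // The relaxed domain of each root iter var of this consumer.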
    std::unordered_map<const Variable*, IntSet> dom_map;
    for (auto iv : op->root_iter_vars()) {
      Range r;
      if (up_state.count(iv)) {
        r = up_state.at(iv).cover_range(iv->dom);
      } else {
        r = iv->dom;
      }
      if (relax_set.size() != 0) {
        dom_map[iv->var.get()] = EvalSet(r, relax_set);
      } else {
        dom_map[iv->var.get()] = IntSet::range(r);
      }
    }
    op->PropBoundToInputs(op, dom_map, &tmap);
  }
  stage->op->GatherBound(stage->op, tmap, rmap);
}

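/*!
 * \brief Infer a Range for every IterVar in the schedule.
 *
 *  A sketch of the typical pipeline, assuming an output tensor `out`
 *  built elsewhere (exact signatures are declared in
 *  tvm/schedule_pass.h):
 *
 *    Schedule sch = create_schedule({out->op}).normalize();
 *    Map<IterVar, Range> bounds = InferBound(sch);
 *    Stmt stmt = ScheduleOps(sch, bounds);
 */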
Map<IterVar, Range> InferBound(const Schedule& sch) {
  // Prepare context
  GraphContext ctx;
  Array<Operation> roots;
  for (Operation op : sch->outputs) {
    roots.push_back(sch->stage_map[op]->op);
  }
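  // The feed graph maps each tensor to the operations that read it.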
  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));

  for (Stage stage : sch->stages) {
    for (auto kv : stage->iter_var_attrs) {
      if (kv.second->bind_thread.defined()) {
        CHECK(!ctx.bind_map.count(kv.first));
        ctx.bind_map[kv.first] = kv.second->bind_thread;
      }
    }
    ctx.op2stage_[stage->op.get()] = stage;
  }
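  // For each op, the loop nest it is attached inside via compute_at.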
  ctx.attach_path = CreateAttachPath(sch);
  // Run inference.
  std::unordered_map<IterVar, Range> ret;
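  // Visit stages in reverse order so every consumer is processed
  // before its producers.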
  for (size_t i = sch->stages.size(); i != 0; --i) {
    const Stage& stage = sch->stages[i - 1];
    InferRootBound(stage, ctx, &ret);
    // Pass down to get the bounds of all iter vars.
    PassDownDomain(stage, &ret);
    for (IterVar iv : stage->env_threads) {
      CHECK(iv->dom.defined());
      ret[iv] = iv->dom;
    }
  }
  for (auto& p : ret) {
    ret[p.first] = Range::make_by_min_extent(ir::Simplify(p.second->min),
                                             ir::Simplify(p.second->extent));
  }
  return Map<IterVar, Range>(ret.begin(), ret.end());
}

}  // namespace schedule
}  // namespace tvm