/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file bound.cc
 * \brief The bound inference logic.
 */
#include <tvm/ir_visitor.h>
25
#include <tvm/schedule_pass.h>
26
#include <tvm/operation.h>
27
#include <tvm/ir_pass.h>
28 29
#include <unordered_map>
#include <unordered_set>
30 31
#include "graph.h"
#include "message_passing.h"
32
#include "../runtime/thread_storage_scope.h"
33 34

namespace tvm {
35
namespace schedule {
36

37
using runtime::StorageRank;
38
using runtime::StorageScope;
39
using runtime::ThreadScope;
40

41 42 43 44
/*! \brief The graph context used during bound inference. */
struct GraphContext {
  /*! \brief The feed graph */
  FeedGraph feed_graph;
45 46 47 48 49 50
  /*! \brief Attachment path */
  AttachPath attach_path;
  /*! \brief The bind map */
  std::unordered_map<IterVar, IterVar> bind_map;
  /*! \brief map from op to stage */
  std::unordered_map<const Node*, Stage> op2stage_;
51 52
};

53 54 55 56 57 58 59 60 61
/*!
 * \brief Decide whether an iteration variable has to be relaxed, i.e.
 *  replaced by its whole range rather than kept as a single point.
 *
 * \param iv The iteration variable under consideration.
 * \param found_attach Whether the attach point has been seen already.
 * \param bind_map Thread bindings; when iv is bound, the thread tag of
 *  the bound IterVar is used instead of iv's own tag.
 * \param scope The storage scope of the producer.
 * \return true if iv needs to be relaxed.
 */
bool NeedRelax(const IterVar& iv,
               bool found_attach,
               const std::unordered_map<IterVar, IterVar>& bind_map,
               const runtime::StorageScope& scope) {
  const auto bound = bind_map.find(iv);
  const std::string& tag = (
      bound == bind_map.end() ? iv->thread_tag : bound->second->thread_tag);

  // Untagged and "pipeline" variables are decided purely by their
  // position relative to the attach point.
  const bool is_plain_loop = (tag.length() == 0 || tag == "pipeline");
  if (is_plain_loop) {
    return !found_attach;
  }

  const ThreadScope ts = ThreadScope::make(tag);
  // When there is warp memory
  // threadIdx.x must be set to be warp index.
  const bool is_warp_index = (scope.rank == StorageRank::kWarp &&
                              ts.rank == 1 &&
                              ts.dim_index == 0);
  // Otherwise compare the storage rank against the thread rank.
  return is_warp_index || static_cast<int>(scope.rank) <= ts.rank;
}
74

75 76 77 78 79 80
/*!
 * \brief Determine the storage scope of a stage.
 *
 *  An explicitly given stage->scope is used as is. Otherwise the
 *  largest thread-scope rank among the thread-tagged iter vars on the
 *  stage's attach path is turned into a rank by
 *  runtime::DefaultStorageRank (-1 is passed when no such iter var exists).
 *
 * \param stage The stage to query.
 * \param ctx The graph context providing attach paths and thread bindings.
 * \return The explicit or inferred storage scope.
 */
StorageScope InferStorageScope(
    const Stage& stage, const GraphContext& ctx) {
  const bool has_explicit_scope = (stage->scope.length() != 0);
  if (has_explicit_scope) {
    return StorageScope::make(stage->scope);
  }

  int max_rank = -1;
  for (IterVar iv : ctx.attach_path.at(stage->op)) {
    const auto bound = ctx.bind_map.find(iv);
    const std::string& tag = (
        bound == ctx.bind_map.end() ? iv->thread_tag : bound->second->thread_tag);
    // Only thread-tagged iter vars contribute to the rank.
    if (tag == "pipeline" || tag.length() == 0) {
      continue;
    }
    const int rank = ThreadScope::make(tag).rank;
    if (rank > max_rank) {
      max_rank = rank;
    }
  }

  StorageScope s;
  s.rank = runtime::DefaultStorageRank(max_rank);
  return s;
}

95

96
void InferRootBound(const Stage& stage,
97
                    const GraphContext& ctx,
98
                    std::unordered_map<IterVar, Range>* rmap) {
99 100 101
  CHECK_NE(stage->attach_type, kInline)
      << "call schedule.normalize before scheduleops";
  if (stage->attach_type == kInlinedAlready) return;
102 103 104 105 106
  if (stage->is_output) {
    // verify correctness.
    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
          << "Output must be attached at root";
  }
107
  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
108
    for (auto iv :  stage->op->root_iter_vars()) {
109 110 111
      CHECK(iv->dom.defined());
      CHECK(!rmap->count(iv));
      (*rmap)[iv] = iv->dom;
tqchen committed
112
    }
113
    return;
tqchen committed
114
  }
115 116
  // The tensor domain.
  std::unordered_map<Tensor, TensorDom> tmap;
117
  // The consumers of the op.
118 119 120
  std::unordered_set<Operation> consumers;
  for (int i = 0; i < stage->op->num_outputs(); ++i) {
    Tensor t = stage->op.output(i);
121
    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
122 123
    auto it = ctx.feed_graph.find(t);
    if (it != ctx.feed_graph.end()) {
124
      for (const Operation& op : it->second) {
125
        consumers.insert(op);
126
      }
127 128
    } else {
      LOG(INFO) << "not in feed graph consumer = " << stage->op;
129 130
    }
  }
131 132 133 134 135 136 137 138 139 140 141
  // storage scope.
  runtime::StorageScope scope = InferStorageScope(stage, ctx);
  // Bound prop by other consumers.
  // - Compute bound by relaxation rules: NeedRelax
  //   - For normal index, use relative location of loop nest./
  //   - For thread index, use the thread scope.
  //
  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
  // The parent set.
  for (const Operation& op : consumers) {
    std::unordered_map<const Variable*, IntSet> relax_set;
142
    std::unordered_map<IterVar, IntSet> up_state;
143 144 145 146 147 148 149 150 151
    bool found_attach = false;
    CHECK(ctx.op2stage_.count(op.get()));
    const Stage& op_stage = ctx.op2stage_.at(op.get());
    // Consumer nest
    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
      IterVar iv = op_stage->leaf_iter_vars[i - 1];
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
152 153
      auto it = rmap->find(iv);
      CHECK(it != rmap->end());
154
      const Range& vrange = it->second;
155 156
      if (is_one(vrange->extent)) {
        up_state[iv] = IntSet::single_point(vrange->min);
157 158 159 160
      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        CHECK(is_zero(vrange->min))
            << "InferBound requires every leaf iter var's min equals 0, "
            << " call schedule.normalize to achieve this. ";
161 162 163 164 165
        if (ctx.bind_map.count(iv)) {
          up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
        } else {
          up_state[iv] = IntSet::single_point(iv->var);
        }
166
      } else {
167
        up_state[iv] = IntSet::range(vrange);
168
      }
169 170 171 172 173 174 175 176 177 178 179 180
    }
    // Consumer's attach nest
    for (IterVar iv : ctx.attach_path.at(op)) {
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      Range vrange = rmap->at(iv);
      CHECK(is_zero(vrange->min))
          << "InferBound requires every leaf iter var's min equals 0, "
          << "call schedule.normalize to achieve this.";
      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        relax_set[iv->var.get()] = IntSet::range(vrange);
181 182 183
        if (ctx.bind_map.count(iv)) {
          relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
        }
184 185
      }
    }
186 187 188 189 190 191
    CHECK(found_attach || stage_attach.size() == 0)
        << "Invalid Schedule, cannot find the producer " << stage->op
        << " along the loop nest specified by compute_at of consumer " << op;
    // Get the domain of the consumer
    PassUpDomain(op_stage, *rmap, &up_state);
    // Relax if needed.
192
    std::unordered_map<const Variable*, IntSet> dom_map;
193
    arith::Analyzer analyzer;
194
    for (auto iv : op->root_iter_vars()) {
195 196 197 198 199 200
      Range r;
      if (up_state.count(iv)) {
        r = up_state.at(iv).cover_range(iv->dom);
      } else {
        r = iv->dom;
      }
201 202 203 204 205
      if (relax_set.size() != 0) {
        dom_map[iv->var.get()] = EvalSet(r, relax_set);
      } else {
        dom_map[iv->var.get()] = IntSet::range(r);
      }
206
      analyzer.Bind(iv->var, r);
207
    }
208
    op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
209
  }
210
  stage->op->GatherBound(stage->op, tmap, rmap);
211
}
212

213
/*!
 * \brief Infer the range of the iteration variables of a schedule.
 *
 *  Builds the graph context from the schedule, then visits the stages
 *  from last to first: root bounds are inferred first, then passed down
 *  to the remaining iter vars of the stage. All resulting ranges are
 *  simplified before being returned.
 *
 * \param sch The schedule to analyze.
 * \return Map from IterVar to its inferred, simplified Range.
 */
Map<IterVar, Range> InferBound(const Schedule& sch) {
  GraphContext ctx;
  arith::Analyzer analyzer;

  // Feed graph, rooted at the operations of the output stages.
  Array<Operation> roots;
  for (Operation op : sch->outputs) {
    roots.push_back(sch->stage_map[op]->op);
  }
  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));

  // Thread bindings and the op -> stage lookup.
  for (Stage stage : sch->stages) {
    for (auto kv : stage->iter_var_attrs) {
      if (!kv.second->bind_thread.defined()) {
        continue;
      }
      CHECK(!ctx.bind_map.count(kv.first));
      ctx.bind_map[kv.first] = kv.second->bind_thread;
    }
    ctx.op2stage_[stage->op.get()] = stage;
  }
  ctx.attach_path = CreateAttachPath(sch);

  // Run inference, last stage first.
  std::unordered_map<IterVar, Range> ret;
  for (size_t i = sch->stages.size(); i-- > 0;) {
    const Stage stage = sch->stages[i];
    InferRootBound(stage, ctx, &ret);

    // Make the root bounds known to the analyzer.
    for (auto iv : stage->op->root_iter_vars()) {
      auto it = ret.find(iv);
      if (it == ret.end()) {
        continue;
      }
      analyzer.Bind(iv->var, it->second);
    }

    // Pass down to get the bound of all iter vars of the stage.
    PassDownDomain(stage, &ret, &analyzer);
    for (IterVar iv : stage->env_threads) {
      CHECK(iv->dom.defined());
      ret[iv] = iv->dom;
    }
  }

  // Simplify every range in place.
  for (auto& p : ret) {
    p.second = Range::make_by_min_extent(
        analyzer.Simplify(p.second->min),
        analyzer.Simplify(p.second->extent));
  }
  return Map<IterVar, Range>(ret.begin(), ret.end());
}

}  // namespace schedule
264
}  // namespace tvm