/*!
 *  Copyright (c) 2016 by Contributors
 * \file bound.cc
 * \brief The bound inference logic.
 */
#include <tvm/ir_visitor.h>
#include <tvm/schedule_pass.h>
#include <tvm/operation.h>
#include <tvm/ir_pass.h>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "./graph.h"
#include "./message_passing.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace schedule {

using runtime::StorageRank;
using runtime::StorageScope;
using runtime::ThreadScope;

/*! \brief The graph context used during bound inference. */
struct GraphContext {
  /*! \brief The feed graph: tensor -> operations that read it. */
  FeedGraph feed_graph;
  /*! \brief Attachment path of each operation. */
  AttachPath attach_path;
  /*! \brief The bind map: IterVar -> thread IterVar it is bound to. */
  std::unordered_map<IterVar, IterVar> bind_map;
  /*! \brief map from op to stage */
  std::unordered_map<const Node*, Stage> op2stage_;
};

/*!
 * \brief Get the thread tag that governs iv.
 *  If iv is bound to a thread IterVar, that IterVar's tag is used;
 *  otherwise iv's own tag is used.
 * \return A reference into either the bound IterVar or iv itself; it stays
 *  valid as long as both iv and bind_map are alive.
 */
static const std::string& ResolveThreadTag(
    const IterVar& iv,
    const std::unordered_map<IterVar, IterVar>& bind_map) {
  auto it = bind_map.find(iv);
  return it != bind_map.end() ? it->second->thread_tag : iv->thread_tag;
}

/*!
 * \brief Whether tag names a real thread index.
 *  An empty tag (ordinary loop) and "pipeline" are not thread indices.
 */
static bool IsThreadIndexTag(const std::string& tag) {
  return !tag.empty() && tag != "pipeline";
}

/*!
 * \brief Decide whether the loop over iv must be relaxed to its full range
 *  when computing the region of a producer that a consumer reads.
 * \param iv A loop variable of the consumer.
 * \param found_attach Whether the producer's attach point has already been
 *  reached while walking the consumer's loops from the inside out.
 * \param bind_map Map from IterVar to the thread IterVar it is bound to.
 * \param scope Storage scope of the producer's buffer.
 * \return true if iv must be relaxed.
 */
bool NeedRelax(const IterVar& iv,
               bool found_attach,
               const std::unordered_map<IterVar, IterVar>& bind_map,
               const runtime::StorageScope& scope) {
  const std::string& tag = ResolveThreadTag(iv, bind_map);
  if (!IsThreadIndexTag(tag)) {
    // Ordinary (or pipeline) loop: relax only until the attach point is hit.
    return !found_attach;
  }
  ThreadScope ts = ThreadScope::make(tag);
  // When there is warp memory,
  // threadIdx.x must be set to be warp index, so it is always relaxed.
  if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) {
    return true;
  }
  // Relax thread indices whose rank is at least the storage rank.
  return static_cast<int>(scope.rank) <= ts.rank;
}

/*!
 * \brief Infer the storage scope of a stage's output, if not given.
 *  An explicitly set stage->scope wins. Otherwise the highest thread rank
 *  among the loops on the stage's attach path (-1 if there are none) is
 *  passed to runtime::DefaultStorageRank.
 * \param stage The stage to inspect.
 * \param ctx The graph context (attach paths and bind map are used).
 * \return The storage scope.
 */
StorageScope InferStorageScope(
    const Stage& stage, const GraphContext& ctx) {
  if (stage->scope.length() != 0) {
    return StorageScope::make(stage->scope);
  }
  int max_rank = -1;
  for (const IterVar& iv : ctx.attach_path.at(stage->op)) {
    const std::string& tag = ResolveThreadTag(iv, ctx.bind_map);
    if (IsThreadIndexTag(tag)) {
      max_rank = std::max(max_rank, ThreadScope::make(tag).rank);
    }
  }
  StorageScope s;
  s.rank = runtime::DefaultStorageRank(max_rank);
  return s;
}

/*!
 * \brief Infer the ranges of one stage's root iteration variables.
 *
 *  Output stages and placeholders take their declared domains directly.
 *  Any other stage is bounded by what its consumers read. For every
 *  consumer:
 *   - its loop ranges (read from rmap) are turned into an up_state;
 *   - loops selected by NeedRelax are relaxed;
 *   - the result is passed up to the consumer's root iteration variables;
 *   - it is then propagated to the consumer's inputs via PropBoundToInputs.
 *  GatherBound then receives the accumulated tensor domains together with
 *  rmap. Presumably it writes this stage's root ranges there; confirm in
 *  the operation implementations.
 *
 * \param stage The stage whose root bounds are inferred.
 * \param ctx The graph context built by InferBound.
 * \param rmap In/out map IterVar -> Range. It must already hold the ranges
 *  of every consumer's leaf and attach-path iteration variables (checked).
 */
void InferRootBound(const Stage& stage,
                    const GraphContext& ctx,
                    std::unordered_map<IterVar, Range>* rmap) {
  CHECK_NE(stage->attach_type, kInline)
      << "call schedule.normalize before scheduleops";
  if (stage->attach_type == kInlinedAlready) return;
  if (stage->is_output) {
    // verify correctness.
    CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot)
        << "Output must be attached at root";
  }
  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
    for (auto iv : stage->op->root_iter_vars()) {
      CHECK(iv->dom.defined());
      CHECK(!rmap->count(iv));
      (*rmap)[iv] = iv->dom;
    }
    return;
  }
  // The tensor domain.
  std::unordered_map<Tensor, TensorDom> tmap;
  // The consumers of the op.
  std::unordered_set<Operation> consumers;
  for (int i = 0; i < stage->op->num_outputs(); ++i) {
    Tensor t = stage->op.output(i);
    tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
    auto it = ctx.feed_graph.find(t);
    if (it != ctx.feed_graph.end()) {
      for (const Operation& op : it->second) {
        consumers.insert(op);
      }
    } else {
      LOG(INFO) << "not in feed graph consumer = " << stage->op;
    }
  }
  // storage scope.
  runtime::StorageScope scope = InferStorageScope(stage, ctx);
  // Bound prop by other consumers.
  // - Compute bound by relaxation rules: NeedRelax
  // - For normal index, use relative location of loop nest.
  // - For thread index, use the thread scope.
  //
  // Reaching stage_attach[0] while walking a consumer's loops marks the
  // producer's compute_at location (see the CHECK message below).
  Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
  for (const Operation& op : consumers) {
    std::unordered_map<const Variable*, IntSet> relax_set;
    std::unordered_map<IterVar, IntSet> up_state;
    bool found_attach = false;
    CHECK(ctx.op2stage_.count(op.get()));
    const Stage& op_stage = ctx.op2stage_.at(op.get());
    // Consumer nest, walked from the innermost leaf loop outwards.
    for (size_t i = op_stage->leaf_iter_vars.size(); i != 0; --i) {
      IterVar iv = op_stage->leaf_iter_vars[i - 1];
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      auto it = rmap->find(iv);
      CHECK(it != rmap->end());
      const Range& vrange = it->second;
      if (is_one(vrange->extent)) {
        up_state[iv] = IntSet::single_point(vrange->min);
      } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        CHECK(is_zero(vrange->min))
            << "InferBound requires every leaf iter var's min equals 0, "
            << "call schedule.normalize to achieve this.";
        // Keep the loop symbolic: use the bound thread var if there is one.
        auto bound = ctx.bind_map.find(iv);
        up_state[iv] = IntSet::single_point(
            bound != ctx.bind_map.end() ? bound->second->var : iv->var);
      } else {
        up_state[iv] = IntSet::range(vrange);
      }
    }
    // Consumer's attach nest
    for (const IterVar& iv : ctx.attach_path.at(op)) {
      if (stage_attach.size() != 0 && iv == stage_attach[0]) {
        found_attach = true;
      }
      const Range& vrange = rmap->at(iv);
      CHECK(is_zero(vrange->min))
          << "InferBound requires every leaf iter var's min equals 0, "
          << "call schedule.normalize to achieve this.";
      if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
        relax_set[iv->var.get()] = IntSet::range(vrange);
        // Also relax the thread var iv is bound to, if any.
        auto bound = ctx.bind_map.find(iv);
        if (bound != ctx.bind_map.end()) {
          relax_set[bound->second->var.get()] = IntSet::range(vrange);
        }
      }
    }
    CHECK(found_attach || stage_attach.size() == 0)
        << "Invalid Schedule, cannot find the producer " << stage->op
        << " along the loop nest specified by compute_at of consumer " << op;
    // Get the domain of the consumer
    PassUpDomain(op_stage, *rmap, &up_state);
    // Relax if needed.
    std::unordered_map<const Variable*, IntSet> dom_map;
    for (auto iv : op->root_iter_vars()) {
      Range r;
      auto st = up_state.find(iv);
      if (st != up_state.end()) {
        r = st->second.cover_range(iv->dom);
      } else {
        r = iv->dom;
      }
      if (!relax_set.empty()) {
        dom_map[iv->var.get()] = EvalSet(r, relax_set);
      } else {
        dom_map[iv->var.get()] = IntSet::range(r);
      }
    }
    op->PropBoundToInputs(op, dom_map, &tmap);
  }
  stage->op->GatherBound(stage->op, tmap, rmap);
}

/*!
 * \brief Infer the range of every iteration variable in the schedule.
 *
 *  Builds the graph context:
 *   - the feed graph of the outputs;
 *   - the thread bind map, where an IterVar bound twice is a CHECK failure;
 *   - op -> stage;
 *   - the attach paths.
 *  It then runs InferRootBound and PassDownDomain on every stage, in
 *  reverse order of sch->stages. InferRootBound requires each consumer's
 *  loop ranges to be known already. This therefore relies on consumers
 *  appearing after their producers in sch->stages (TODO confirm the
 *  ordering guarantee of Schedule). Environment threads take their declared
 *  domains. Every min/extent is simplified before returning.
 *
 * \param sch The schedule.
 * \return Map from IterVar to its inferred Range.
 */
Map<IterVar, Range> InferBound(const Schedule& sch) {
  // Prepare context
  GraphContext ctx;
  Array<Operation> roots;
  for (const Operation& op : sch->outputs) {
    roots.push_back(sch->stage_map[op]->op);
  }
  ctx.feed_graph = CreateFeedGraph(CreateReadGraph(roots));

  for (const Stage& stage : sch->stages) {
    for (const auto& kv : stage->iter_var_attrs) {
      if (kv.second->bind_thread.defined()) {
        CHECK(!ctx.bind_map.count(kv.first));
        ctx.bind_map[kv.first] = kv.second->bind_thread;
      }
    }
    ctx.op2stage_[stage->op.get()] = stage;
  }
  ctx.attach_path = CreateAttachPath(sch);
  // Run inference.
  std::unordered_map<IterVar, Range> ret;
  for (size_t i = sch->stages.size(); i != 0; --i) {
    const Stage& stage = sch->stages[i - 1];
    InferRootBound(stage, ctx, &ret);
    // pass down to get bound of all iter vars.
    PassDownDomain(stage, &ret);
    for (IterVar iv : stage->env_threads) {
      CHECK(iv->dom.defined());
      ret[iv] = iv->dom;
    }
  }
  // Simplify in place; no key is added or removed, so iteration stays valid.
  for (auto& p : ret) {
    p.second = Range::make_by_min_extent(ir::Simplify(p.second->min),
                                         ir::Simplify(p.second->extent));
  }
  return Map<IterVar, Range>(ret.begin(), ret.end());
}

}  // namespace schedule
}  // namespace tvm