/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file graph.cc
 * \brief Utilities to get information about schedule graph.
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include "graph.h"

namespace tvm {
namespace schedule {
35 36 37 38 39 40 41 42 43 44 45 46
// key to specific tensor dimension.
struct TensorDimKey {
  FunctionRef f;
  int value_index;
  int dim;
  TensorDimKey() {}
  TensorDimKey(const ir::Call* op, int dim)
      : f(op->func), value_index(op->value_index), dim(dim) {
  }
  TensorDimKey(const Tensor& t, int dim)
      : f(t->op), value_index(t->value_index), dim(dim) {
  }
47 48 49
  TensorDimKey(const Tensor& t, size_t dim)
      : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
  }
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
  inline bool operator==(const TensorDimKey& other) const {
    return f == other.f &&
        value_index == other.value_index &&
        dim == other.dim;
  }
  inline bool operator!=(const TensorDimKey& other) const {
    return !operator==(other);
  }
};
}  // namespace schedule
}  // namespace tvm

namespace std {
// Hash support so TensorDimKey can be stored in unordered containers.
template <>
struct hash<::tvm::schedule::TensorDimKey> {
  std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
    // Pack (value_index, dim) into one word, then mix it into the
    // function hash with the hash_combine recipe.
    size_t seed = k.f.hash();
    size_t packed = static_cast<size_t>(k.value_index) << 16UL |
        static_cast<size_t>(k.dim);
    seed ^= packed + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};
}  // namespace std


namespace tvm {
namespace schedule {
78 79 80

// construct a read graph that gives readers of each operation
// that the root depend on
81
ReadGraph CreateReadGraph(const Array<Operation>& roots) {
82
  ReadGraph rmap;
83 84 85 86 87 88 89
  std::vector<Operation> stack;
  std::unordered_set<const Node*> visited;
  // initialize the roots
  for (Operation op : roots) {
    stack.push_back(op);
    visited.insert(op.get());
  }
90

91
  while (!stack.empty()) {
92
    Operation op = stack.back();
93
    stack.pop_back();
94
    Array<Tensor> deps = op->InputTensors();
95 96 97 98 99 100 101
    rmap.Set(op, deps);
    for (Tensor t : deps) {
      if (t->op.defined() && visited.count(t->op.get()) == 0) {
        visited.insert(t->op.get());
        stack.push_back(t->op);
      }
    }
102 103 104 105
  }
  return rmap;
}

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
bool GetSubGraphByPostDFS_(
    const Operation& op,
    const std::unordered_set<const Node*>& boundary,
    bool include_bounary,
    std::unordered_map<const Node*, bool>* visited,
    Array<Operation>* result) {
  if (visited->count(op.get())) {
    return visited->at(op.get());
  }
  if (boundary.count(op.get())) {
    (*visited)[op.get()] = true;
    if (include_bounary) {
      result->push_back(op);
    }
    return true;
  }
  // mark to avoid loop
  // Not necessary for DAG.
  (*visited)[op.get()] = false;
  // check if we can reach boundary.
  bool reach_boundary = false;
  for (Tensor t : op->InputTensors()) {
    if (GetSubGraphByPostDFS_(t->op, boundary,
                              include_bounary,
                              visited, result)) {
      reach_boundary = true;
    }
  }
  (*visited)[op.get()] = reach_boundary;
  if (reach_boundary) {
    result->push_back(op);
  }
  return reach_boundary;
}

// Collect the operations lying between `outputs` and `inputs`,
// optionally including the input boundary operations themselves.
Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs) {
  // The boundary is the set of operations producing the input tensors.
  std::unordered_set<const Node*> boundary;
  for (Tensor in : inputs) {
    boundary.insert(in->op.get());
  }
  Array<Operation> result;
  std::unordered_map<const Node*, bool> visited;
  for (Tensor out : outputs) {
    GetSubGraphByPostDFS_(out->op, boundary, include_inputs,
                          &visited, &result);
  }
  return result;
}


160
void PostDFSOrder(const Operation& op,
161 162 163
                  const ReadGraph& g,
                  std::unordered_set<Operation>* visited,
                  Array<Operation>* post_order) {
164
  if (visited->count(op)) return;
165 166
  visited->insert(op);
  for (const auto& t : g.at(op)) {
167
    PostDFSOrder(t->op, g, visited, post_order);
168 169 170 171
  }
  post_order->push_back(op);
}

172
Array<Operation> PostDFSOrder(
173 174
    const Array<Operation>& roots,
    const ReadGraph& g) {
175
  std::unordered_set<Operation> visited;
176
  Array<Operation> post_order;
177 178 179
  for (Operation op : roots) {
    PostDFSOrder(op, g, &visited, &post_order);
  }
180 181 182
  return post_order;
}

183 184 185 186 187 188 189 190 191 192 193 194 195
FeedGraph CreateFeedGraph(const ReadGraph& g) {
  FeedGraph fg;
  for (auto kv : g) {
    for (Tensor t : kv.second) {
      fg[t].push_back(kv.first);
    }
  }
  return fg;
}

// For every stage in the schedule, compute the path of iteration
// variables it is nested under via compute_at / scan attachments.
AttachPath CreateAttachPath(Schedule sch) {
  AttachPath ret;
  for (Stage stage : sch->stages) {
    // Detect cycles in the attach chain.
    std::unordered_set<const Node*> visited;
    Array<IterVar> path;
    // Walk up the chain of attach targets starting from this stage.
    for (Stage s = stage; s.defined();) {
      CHECK(!visited.count(s.get()))
          << "Find loop in compute_at attach group";
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      bool start_attach;
      IterVar attach_ivar;
      if (spec->attach_type == kScope) {
        // compute_at: collect only the loop vars at/outside attach_ivar.
        attach_ivar = spec->attach_ivar;
        s = spec->attach_stage;
        start_attach = false;
        CHECK(attach_ivar.defined());
      } else if (spec->attach_type == kScanUpdate) {
        // Scan update: every leaf iter var of the target stage is on the path.
        s = spec->attach_stage;
        start_attach = true;
      } else {
        // Not attached anywhere; chain ends here.
        break;
      }
      CHECK(s.defined());
      // Scan the target's leaf iter vars innermost-first, and begin
      // collecting once the attach point is reached.
      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
        }
        if (start_attach) path.push_back(iv);
      }
      CHECK(start_attach)
          << "Invalid Schedule: cannot find attach point " << attach_ivar
          << " in the schedule of " << s->op;
    }
    // Record the path only for the first stage seen for this op.
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, path);
    }
  }
  return ret;
}

// graph of push reach relation of tensor dimensions
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;

// Build the push-reach graph over tensor dimensions for `ops`:
// an edge key -> dkey records that the value of `key`'s dimension
// flows into `dkey`'s dimension.
ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  // Restrict analysis to calls targeting ops inside `ops`.
  std::unordered_set<const Node*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        // Dimensions start at 1: dim 0 is presumably the scan (time)
        // axis, consistent with spatial_axis_ usage elsewhere — confirm.
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each axis variable to the output dimension it indexes.
      std::unordered_map<const Node*, TensorDimKey> vmap;
      const auto& axis = compute_op->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      // For every Call argument in the body, record which output dims
      // (via the axis variables appearing in the index expression)
      // reach which input tensor dims.
      auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          // Ignore calls to ops outside the analyzed set.
          if (!bset.count(call->func.get())) return;
          for (size_t i = 0; i < call->args.size(); ++i) {
            TensorDimKey dkey(call, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) {
              const Variable *v = node.as<Variable>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            ir::PostOrderVisit(call->args[i], fpush);
          }
        }
      };
      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}

291 292 293 294
Array<Operation> ScanGetBody(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  // Get the body.
  Array<Tensor> inputs;
295
  for (Tensor t : scan->state_placeholder) {
296
    inputs.push_back(t);
297
  }
298 299 300 301
  for (Tensor t : scan->inputs) {
    inputs.push_back(t);
  }
  return GetSubGraph(scan->update, inputs, false);
302 303
}

// For each spatial axis of the scan, decide whether the scan is a
// fix point along that axis: update[i] at spatial index x depends only
// on state[i] at the same spatial index x.  Returns a map from each
// spatial IterVar to 1 (proven fix point) or 0 (not proven).
Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  Array<Operation> body = ScanGetBody(scan_op);

  // exact_reach maps a tensor dimension to the single spatial axis node
  // known to index it exactly; fail_set collects axes whose exactness
  // has been disproven.
  std::unordered_map<TensorDimKey, const Node*> exact_reach;
  std::unordered_set<const Node*> fail_set;

  // Seed: each state placeholder dimension (dims >= 1; dim 0 is
  // presumably the scan axis — confirm) is exactly indexed by its
  // corresponding spatial axis.
  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // merge exact reach: propagate src's known axis to dst; if dst
  // already has a different axis, both axes fail.
  auto f_merge_key = [&exact_reach, &fail_set](
      const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      if (dit->second != sit->second) {
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // prop exact reach back through the body operations.
  for (size_t i = 0; i < body.size(); ++i) {
    const Operation& op = body[i];
    if (const auto* scan_op = op.as<ScanOpNode>()) {
      const auto& update = scan_op->update;
      const auto& init = scan_op->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
        }
      }
    } else if (const auto* compute_op = op.as<ComputeOpNode>()) {
      // Map each axis variable to the dim keys of every output it indexes.
      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
      const auto& axis = compute_op->axis;
      for (size_t i = 0; i < axis.size(); ++i) {
        std::vector<TensorDimKey> keys;
        for (int j = 0; j < op->num_outputs(); ++j) {
          keys.emplace_back(op.output(j), i);
        }
        vmap[axis[i]->var.get()] = std::move(keys);
      }
      // A call argument that is exactly one axis variable propagates
      // exactness; any more complex index disproves the source's axis.
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
          const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          for (size_t i = 0; i < call->args.size(); ++i) {
            auto it = vmap.find(call->args[i].get());
            TensorDimKey src(call, static_cast<int>(i));
            if (it != vmap.end()) {
              const std::vector<TensorDimKey>& keys = it->second;
              for (const auto& key : keys) {
                f_merge_key(key, src);
              }
            } else {
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      for (auto& e : compute_op->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  ReachGraph reach;
  Map<IterVar, Expr> ret;
  // All state placeholder dimensions, used below as DFS stop markers.
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      if (fail_set.count(sp_iv.get()) ||
          !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        ret.Set(sp_iv, make_const(Int(32), 0));
      } else {
        // now we proved exact match, need to prove no interference with other graph.
        if (reach.size() == 0) reach = GetReachGraph(body);
        // do a DFS: the proof fails if the update dim reaches any state
        // placeholder dim other than its own `target`.
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey k = stack.back();
          // Reaching a foreign placeholder dim aborts with a non-empty stack.
          if (k != target && place_holder_ref.count(k)) break;
          stack.pop_back();
          if (!reach.count(k)) {
            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
          }

          for (TensorDimKey kk : reach.at(k)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        if (!stack.empty()) {
          // failed the prove.
          ret.Set(sp_iv, make_const(Int(32), 0));
        } else {
          ret.Set(sp_iv, make_const(Int(32), 1));
        }
      }
    }
  }
  return ret;
}

}  // namespace schedule
}  // namespace tvm