pass.cc 1.38 KB
Newer Older
tqchen committed
1 2 3 4 5
/*!
 *  Copyright (c) 2016 by Contributors
 * \file pass.cc
 * \brief Support for pass registry.
 */
tqchen committed
6
#include <nnvm/pass.h>
tqchen committed
7 8 9 10
#include <algorithm>

namespace dmlc {
// enable registry
tqchen committed
11
DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg);
tqchen committed
12 13
}  // namespace dmlc

tqchen committed
14
namespace nnvm {
tqchen committed
15 16 17 18 19 20 21 22 23 24

const PassFunctionReg* FindPassDep(const std::string&attr_name) {
  for (auto* r : dmlc::Registry<PassFunctionReg>::List()) {
    for (auto& s : r->graph_attr_targets) {
      if (s == attr_name) return r;
    }
  }
  return nullptr;
}

25 26
Graph ApplyPasses(Graph g,
                  const std::vector<std::string>& pass) {
tqchen committed
27 28 29 30 31 32 33 34 35 36
  std::vector<const PassFunctionReg*> fpass;
  for (auto& name : pass) {
    auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
    CHECK(reg != nullptr)
        << "Cannot find pass " << name << " in the registry";
    fpass.push_back(reg);
  }

  for (auto r : fpass) {
    for (auto& dep : r->graph_attr_dependency) {
37
      if (g.attrs.count(dep) == 0) {
tqchen committed
38 39 40 41 42 43 44 45 46 47 48
        auto* pass_dep = FindPassDep(dep);
        std::string msg;
        if (pass_dep != nullptr) {
          msg = " The attribute is provided by pass " + pass_dep->name;
        }
        LOG(FATAL) << "Graph attr dependency " << dep
                   << " is required by pass " << r->name
                   << " but is not available "
                   << msg;
      }
    }
49
    g = r->body(std::move(g));
tqchen committed
50
  }
51

tqchen committed
52 53 54
  return g;
}

tqchen committed
55
}  // namespace nnvm