/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file pass.cc
 * \brief Pass registry support and the ApplyPasses driver that runs
 *        registered passes on a graph.
 */
#include <nnvm/pass.h>
#include <algorithm>

namespace dmlc {
// enable registry
tqchen committed
30
DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg);
tqchen committed
31 32
}  // namespace dmlc

namespace nnvm {
tqchen committed
34 35 36 37 38 39 40 41 42 43

const PassFunctionReg* FindPassDep(const std::string&attr_name) {
  for (auto* r : dmlc::Registry<PassFunctionReg>::List()) {
    for (auto& s : r->graph_attr_targets) {
      if (s == attr_name) return r;
    }
  }
  return nullptr;
}

44 45
Graph ApplyPasses(Graph g,
                  const std::vector<std::string>& pass) {
tqchen committed
46 47 48 49 50 51 52 53 54 55
  std::vector<const PassFunctionReg*> fpass;
  for (auto& name : pass) {
    auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
    CHECK(reg != nullptr)
        << "Cannot find pass " << name << " in the registry";
    fpass.push_back(reg);
  }

  for (auto r : fpass) {
    for (auto& dep : r->graph_attr_dependency) {
56
      if (g.attrs.count(dep) == 0) {
tqchen committed
57 58 59 60 61 62 63 64 65 66 67
        auto* pass_dep = FindPassDep(dep);
        std::string msg;
        if (pass_dep != nullptr) {
          msg = " The attribute is provided by pass " + pass_dep->name;
        }
        LOG(FATAL) << "Graph attr dependency " << dep
                   << " is required by pass " << r->name
                   << " but is not available "
                   << msg;
      }
    }
68
    g = r->body(std::move(g));
tqchen committed
69
  }
70

tqchen committed
71 72 73
  return g;
}

tqchen committed
74
}  // namespace nnvm