graph_hash.cc 7.36 KB
Newer Older
1 2 3 4 5
/*!
 *  Copyright (c) 2017 by Contributors
 * \file graph_hash.cc
 * \brief Graph hashing, GraphKey hash/equality, and deep structural
 *  comparison of two graphs.
 */

#include <dmlc/common.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/ir.h>
#include <tvm/runtime/packed_func.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "node_attr.h"
#include "graph_hash.h"

namespace nnvm {
namespace compiler {

using namespace tvm;
using tvm::ir::IntImm;

size_t HashPlaceHolder(const Tensor& t) {
  size_t key = t->shape.size();
  key = dmlc::HashCombine(key, (t->dtype.code() << 8) | t->dtype.bits());
  for (Expr s : t->shape) {
    if (const IntImm* op = s.as<IntImm>()) {
      key = dmlc::HashCombine(key, op->value);
    }
  }
  return key;
}

bool PlaceHolderEqual(const Tensor& a, const Tensor& b) {
  if (a->shape.size() != b->shape.size()) return false;
  if (a->dtype != b->dtype) return false;
  for (size_t i = 0; i < a->shape.size(); ++i) {
    const IntImm* a_value = a->shape[i].as<IntImm>();
    const IntImm* b_value = b->shape[i].as<IntImm>();
    if (a_value && b_value == nullptr) return false;
    if (b_value && a_value == nullptr) return false;
    if (a_value == nullptr && b_value == nullptr) {
      continue;
    }
    if (a_value->value != b_value->value) return false;
  }
  return true;
}

size_t GraphKeyHash::Hash(const GraphKey& gkey)  {
  if (gkey->cache_hash_key_ != 0) return gkey->cache_hash_key_;
  size_t key = dmlc::HashCombine(GraphHash(gkey->graph), gkey->target);
  key = dmlc::HashCombine(key, gkey->inputs.size());
  for (size_t i = 0; i < gkey->inputs.size(); ++i) {
    key = dmlc::HashCombine(key, HashPlaceHolder(gkey->inputs[i]));
  }
  if (key == 0) key = 1;
  gkey->cache_hash_key_ = key;
  return key;
}

/*!
 * \brief Equality predicate for GraphKey, used together with GraphKeyHash.
 *
 * The cheap checks (target, input count, placeholder signatures) run before
 * the full structural comparison. Attributes of variable nodes are not
 * compared (compare_variable_attr = false).
 * \return true if both keys describe the same graph, target and inputs.
 */
bool GraphKeyEqual::Equal(const GraphKey& a,
                          const GraphKey& b) {
  if (a->target != b->target ||
      a->inputs.size() != b->inputs.size()) {
    return false;
  }
  for (size_t i = 0; i < a->inputs.size(); ++i) {
    if (!PlaceHolderEqual(a->inputs[i], b->inputs[i])) return false;
  }
  // GraphDeepCompare reports a mismatch with a non-empty message.
  return GraphDeepCompare(a->graph, b->graph, false).empty();
}

/*!
 * \brief Construct a GraphKey.
 * \param graph The graph; moved into the key.
 * \param inputs The input placeholders; moved into the key.
 * \param target The target string; moved into the key.
 * \return The new GraphKey handle.
 */
GraphKey GraphKeyNode::make(Graph graph,
                            tvm::Array<Tensor> inputs,
                            std::string target) {
  std::shared_ptr<GraphKeyNode> n = std::make_shared<GraphKeyNode>();
  // All three parameters are by-value sinks, so move each of them instead of
  // copying `inputs` as the other two already avoid doing.
  n->graph = std::move(graph);
  n->inputs = std::move(inputs);
  n->target = std::move(target);
  return GraphKey(std::move(n));
}

85
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
    p->stream << "GraphKeyNode("<< op << ")";
});


/*!
 * \brief Compute a structural hash of a graph.
 *
 * Seeds with the node count, then folds in, for every non-variable node in
 * indexed order, its op name followed by the hashes of its attributes.
 * Variable nodes contribute nothing beyond the node count.
 * \param graph The graph to hash.
 * \return The hash value.
 */
size_t GraphHash(const Graph& graph) {
  const IndexedGraph& idx = graph.indexed_graph();
  size_t key = dmlc::HashCombine(size_t{0}, idx.num_nodes());
  std::hash<std::string> hash_string;
  // Scratch buffer, cleared and reused for every node.
  std::vector<size_t> attr_hashes;
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto* node = idx[nid].source;
    if (node->is_variable()) continue;
    // Hash the op by name instead of by address so the result is
    // deterministic across runs.
    key = dmlc::HashCombine(key, node->op()->name);
    // The attribute dict (an unordered_map) iterates in no fixed order, so
    // hash every (key, value) pair separately and fold them in sorted.
    attr_hashes.clear();
    for (const auto& attr : GetAttrDict(node->attrs)) {
      attr_hashes.push_back(
          dmlc::HashCombine(hash_string(attr.first), attr.second));
    }
    std::sort(attr_hashes.begin(), attr_hashes.end());
    for (const size_t h : attr_hashes) key = dmlc::HashCombine(key, h);
  }
  return key;
}

118 119 120 121
// deep compare the graph structure
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
122
// compare_var_attr
123 124 125
std::string GraphDeepCompare(const Graph& a,
                             const Graph& b,
                             bool compare_variable_attr) {
126 127 128 129
  const IndexedGraph& idxa = a.indexed_graph();
  const IndexedGraph& idxb = b.indexed_graph();
  std::ostringstream err;
  if (idxa.num_nodes() != idxb.num_nodes()) {
130
    err << "Number of nodes mismatch (" <<  idxa.num_nodes() << " v.s " << idxb.num_nodes() << ")";
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
    return err.str();
  }
  if (idxa.num_node_entries() != idxb.num_node_entries()) {
    err << "Number of node entry mismatch";
    return err.str();
  }
  if (idxa.outputs().size() != idxb.outputs().size()) {
    err << "Number of outputs mismatch";
    return err.str();
  }
  for (size_t i = 0; i < idxa.outputs().size(); ++i) {
    if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
        idxa.outputs()[i].index != idxb.outputs()[i].index) {
      err << "Output entry mismatch";
      return err.str();
    }
  }
  if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
    err << "Number of inputs mismatch";
    return err.str();
  }

  for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
    const IndexedGraph::Node& anode = idxa[nid];
    const IndexedGraph::Node& bnode = idxb[nid];
    if (anode.source->op() != bnode.source->op()) {
      err << "Node mismatch ";
      return err.str();
    }
160 161 162 163
    if (anode.source->is_variable()) {
      CHECK(bnode.source->is_variable());
      if (!compare_variable_attr) continue;
    }
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
    AttrDict adict = GetAttrDict(anode.source->attrs);
    AttrDict bdict = GetAttrDict(bnode.source->attrs);

    auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
      for (const auto& kv : adict) {
        auto it = bdict.find(kv.first);
        if (it != bdict.end()) {
          if (it->second != kv.second) {
            err << "Node attr mismatch, op=" << anode.source->attrs.name
                << " attr_key=" << kv.first << " " << it->second
                << " v.s. " << kv.second;
            return false;
          }
        } else {
          err << "One attr_key=" << kv.first << " is missing in another "
               << "op=" << anode.source->attrs.name;
          return false;
        }
      }
      return true;
    };
    if (!fmatch(adict, bdict)) return err.str();
    if (adict.size() != bdict.size()) {
      CHECK(!fmatch(bdict, adict));
      return err.str();
    }
    if (anode.inputs.size() != bnode.inputs.size()) {
      err << "Node input mismatch, op=" << anode.source->attrs.name;
      return err.str();
    }
    if (anode.control_deps.size() != bnode.control_deps.size()) {
      err << "Node control_deps mistach, op=" << anode.source->attrs.name;
      return err.str();
    }
    for (size_t i = 0; i < anode.inputs.size(); ++i) {
      const IndexedGraph::NodeEntry& ae = anode.inputs[i];
      const IndexedGraph::NodeEntry& be = bnode.inputs[i];
      if (ae.node_id != be.node_id ||
          ae.index != be.index ||
          ae.version != be.version) {
        err << "Node input mismatch on, op=" << anode.source->attrs.name;
        return err.str();
      }
    }
    for (size_t i = 0; i < anode.control_deps.size(); ++i) {
      if (anode.control_deps[i] != bnode.control_deps[i]) {
        err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
        return err.str();
      }
    }
  }
  return "";
}

// Packed-function entry point: (graph a, graph b, compare_variable_attr)
// -> the GraphDeepCompare message (empty string when the graphs match).
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
    const Graph lhs = args[0];
    const Graph rhs = args[1];
    const bool compare_variable_attr = args[2];
    *rv = GraphDeepCompare(lhs, rhs, compare_variable_attr);
  });
}  // namespace compiler
}  // namespace nnvm