/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file graph_deep_compare.cc
 * \brief Deep compare two graph structure
 */
#include <dmlc/common.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/ir.h>
#include <tvm/runtime/packed_func.h>
#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "node_attr.h"
#include "graph_hash.h"

namespace nnvm {
namespace compiler {

using namespace tvm;
using tvm::ir::IntImm;

/*!
 * \brief Hash a placeholder tensor by its rank, dtype and constant dimensions.
 *
 *  Only shape entries that are IntImm constants contribute their value;
 *  any other shape expression is skipped, which keeps the hash consistent
 *  with PlaceHolderEqual (it treats two non-constant dimensions as equal).
 *
 * \param t The placeholder tensor.
 * \return The hash value.
 */
size_t HashPlaceHolder(const Tensor& t) {
  size_t key = t->shape.size();
  // pack the dtype code and bit width into one value
  key = dmlc::HashCombine(key, (t->dtype.code() << 8) | t->dtype.bits());
  for (Expr s : t->shape) {
    if (const IntImm* op = s.as<IntImm>()) {
      key = dmlc::HashCombine(key, op->value);
    }
  }
  return key;
}

/*!
 * \brief Check whether two placeholder tensors have the same signature.
 *
 *  The tensors match when rank and dtype agree and, for every dimension,
 *  either both sides are IntImm constants with the same value or both
 *  sides are non-constant expressions.
 *
 * \param a The first tensor.
 * \param b The second tensor.
 * \return true when the two placeholders are considered equal.
 */
bool PlaceHolderEqual(const Tensor& a, const Tensor& b) {
  if (a->shape.size() != b->shape.size()) return false;
  if (a->dtype != b->dtype) return false;
  for (size_t i = 0; i < a->shape.size(); ++i) {
    const IntImm* a_value = a->shape[i].as<IntImm>();
    const IntImm* b_value = b->shape[i].as<IntImm>();
    // a constant dimension never matches a non-constant one
    if (a_value && b_value == nullptr) return false;
    if (b_value && a_value == nullptr) return false;
    // both non-constant: the expressions themselves are not compared
    if (a_value == nullptr && b_value == nullptr) {
      continue;
    }
    if (a_value->value != b_value->value) return false;
  }
  return true;
}

/*!
 * \brief Compute (and cache) the hash of a graph key.
 *
 *  The hash combines the graph hash, the target string and the hash of
 *  every input placeholder. The result is stored in cache_hash_key_;
 *  0 is used as the "not computed yet" marker, so a computed hash of 0
 *  is remapped to 1.
 *
 * \param gkey The graph key.
 * \return The hash value, never 0.
 */
size_t GraphKeyHash::Hash(const GraphKey& gkey) {
  if (gkey->cache_hash_key_ != 0) return gkey->cache_hash_key_;
  size_t key = dmlc::HashCombine(GraphHash(gkey->graph), gkey->target);
  key = dmlc::HashCombine(key, gkey->inputs.size());
  for (size_t i = 0; i < gkey->inputs.size(); ++i) {
    key = dmlc::HashCombine(key, HashPlaceHolder(gkey->inputs[i]));
  }
  if (key == 0) key = 1;
  gkey->cache_hash_key_ = key;
  return key;
}

/*!
 * \brief Check whether two graph keys are equal.
 *
 *  Compares the target, then the input placeholders, and only then runs
 *  the structural graph comparison (variable attributes are ignored).
 *
 * \param a The first key.
 * \param b The second key.
 * \return true when both keys describe the same compilation request.
 */
bool GraphKeyEqual::Equal(const GraphKey& a,
                          const GraphKey& b) {
  if (a->target != b->target) return false;
  if (a->inputs.size() != b->inputs.size()) return false;
  for (size_t i = 0; i < a->inputs.size(); ++i) {
    if (!PlaceHolderEqual(a->inputs[i], b->inputs[i])) return false;
  }
  // GraphDeepCompare returns an empty string when the graphs match
  if (GraphDeepCompare(a->graph, b->graph, false).length() != 0) return false;
  return true;
}

/*!
 * \brief Create a graph key node from its components.
 * \param graph The graph.
 * \param inputs The input placeholders.
 * \param target The target string.
 * \return The constructed GraphKey.
 */
GraphKey GraphKeyNode::make(Graph graph,
                            tvm::Array<Tensor> inputs,
                            std::string target) {
  auto n = tvm::make_node<GraphKeyNode>();
  n->graph = std::move(graph);
  // inputs is a by-value parameter, hand it over instead of copying again
  n->inputs = std::move(inputs);
  n->target = std::move(target);
  return GraphKey(n);
}

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
    p->stream << "GraphKeyNode("<< op << ")";
});


/*!
 * \brief Run graph hash.
 *
 *  Hashes the number of nodes followed by, for every non-variable node in
 *  indexed order, the operator name and its attribute key/value pairs.
 *  Variable nodes only contribute through the node count.
 *
 * \param graph The graph to hash.
 * \return The hash value.
 */
size_t GraphHash(const Graph& graph) {
  const IndexedGraph& idx = graph.indexed_graph();
  size_t key = 0;
  // Combine a linearized sequence of ops in subgraph
  key = dmlc::HashCombine(key, idx.num_nodes());
  std::hash<std::string> str_hash;
  std::vector<size_t> hash_temp;
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const IndexedGraph::Node& inode = idx[nid];
    // Use name instead of op address so it is deterministic across runs
    if (inode.source->is_variable()) continue;
    key = dmlc::HashCombine(key, inode.source->op()->name);
    hash_temp.clear();
    for (const auto& kv : GetAttrDict(inode.source->attrs)) {
      hash_temp.push_back(dmlc::HashCombine(str_hash(kv.first), kv.second));
    }
    // to make sure it is deterministic
    // since unordered_map is not deterministic
    std::sort(hash_temp.begin(), hash_temp.end());
    for (size_t value : hash_temp) {
      key = dmlc::HashCombine(key, value);
    }
  }
  return key;
}

/*!
 * \brief Deep compare the structure of two graphs.
 *
 *  Graph attributes are not considered and names of intermediate nodes
 *  are not matched. Nodes are compared pairwise in indexed order: operator,
 *  attribute dictionary, inputs (node id, index, version) and control
 *  dependencies.
 *
 * \param a The first graph.
 * \param b The second graph.
 * \param compare_variable_attr When false, variable nodes are only checked
 *        for being variables on both sides and the rest of the per-node
 *        comparison is skipped for them.
 * \return An empty string when the graphs match, otherwise a non-empty
 *         error message describing the first mismatch found.
 */
std::string GraphDeepCompare(const Graph& a,
                             const Graph& b,
                             bool compare_variable_attr) {
  const IndexedGraph& idxa = a.indexed_graph();
  const IndexedGraph& idxb = b.indexed_graph();
  std::ostringstream err;
  if (idxa.num_nodes() != idxb.num_nodes()) {
    err << "Number of nodes mismatch (" <<  idxa.num_nodes() << " v.s " << idxb.num_nodes() << ")";
    return err.str();
  }
  if (idxa.num_node_entries() != idxb.num_node_entries()) {
    err << "Number of node entry mismatch";
    return err.str();
  }
  if (idxa.outputs().size() != idxb.outputs().size()) {
    err << "Number of outputs mismatch";
    return err.str();
  }
  for (size_t i = 0; i < idxa.outputs().size(); ++i) {
    if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
        idxa.outputs()[i].index != idxb.outputs()[i].index) {
      err << "Output entry mismatch";
      return err.str();
    }
  }
  if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
    err << "Number of inputs mismatch";
    return err.str();
  }

  for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
    const IndexedGraph::Node& anode = idxa[nid];
    const IndexedGraph::Node& bnode = idxb[nid];
    if (anode.source->op() != bnode.source->op()) {
      err << "Node mismatch ";
      return err.str();
    }
    if (anode.source->is_variable()) {
      CHECK(bnode.source->is_variable());
      if (!compare_variable_attr) continue;
    }
    AttrDict adict = GetAttrDict(anode.source->attrs);
    AttrDict bdict = GetAttrDict(bnode.source->attrs);

    // Returns true when every key of adict exists in bdict with the same
    // value; otherwise writes the reason into err and returns false.
    auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
      for (const auto& kv : adict) {
        auto it = bdict.find(kv.first);
        if (it != bdict.end()) {
          if (it->second != kv.second) {
            err << "Node attr mismatch, op=" << anode.source->attrs.name
                << " attr_key=" << kv.first << " " << it->second
                << " v.s. " << kv.second;
            return false;
          }
        } else {
          err << "One attr_key=" << kv.first << " is missing in another "
              << "op=" << anode.source->attrs.name;
          return false;
        }
      }
      return true;
    };

    if (!fmatch(adict, bdict)) return err.str();
    if (adict.size() != bdict.size()) {
      // adict is contained in bdict but the sizes differ, so the reverse
      // match must fail; the call is kept for its side effect of filling err.
      CHECK(!fmatch(bdict, adict));
      return err.str();
    }
    if (anode.inputs.size() != bnode.inputs.size()) {
      err << "Node input mismatch, op=" << anode.source->attrs.name;
      return err.str();
    }
    if (anode.control_deps.size() != bnode.control_deps.size()) {
      err << "Node control_deps mistach, op=" << anode.source->attrs.name;
      return err.str();
    }
    for (size_t i = 0; i < anode.inputs.size(); ++i) {
      const IndexedGraph::NodeEntry& ae = anode.inputs[i];
      const IndexedGraph::NodeEntry& be = bnode.inputs[i];
      if (ae.node_id != be.node_id ||
          ae.index != be.index ||
          ae.version != be.version) {
        err << "Node input mismatch on, op=" << anode.source->attrs.name;
        return err.str();
      }
    }
    for (size_t i = 0; i < anode.control_deps.size(); ++i) {
      if (anode.control_deps[i] != bnode.control_deps[i]) {
        err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
        return err.str();
      }
    }
  }
  return "";
}

TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body_typed(GraphDeepCompare);
}  // namespace compiler
}  // namespace nnvm