/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * Copyright (c) 2017 by Contributors * \file graph_hash.h * \brief The graph hashing function. */ #ifndef NNVM_COMPILER_GRAPH_HASH_H_ #define NNVM_COMPILER_GRAPH_HASH_H_ #include <dmlc/common.h> #include <nnvm/graph.h> #include <tvm/operation.h> #include <string> #include <utility> namespace nnvm { namespace compiler { class GraphKey; /*! \brief Key to a graph compiler cache */ struct GraphKeyNode : public tvm::Node { /*! \brief The graph structure */ Graph graph; /* \brief The inputs to the function */ tvm::Array<Tensor> inputs; /*! \brief The target */ std::string target; // Cached internal hash key, invisible to the user. // The graph hash key is ensured always not to be 0 mutable size_t cache_hash_key_{0}; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("inputs", &inputs); v->Visit("target", &target); } static GraphKey make(Graph graph, tvm::Array<Tensor> inputs, std::string target); static constexpr const char* _type_key = "GraphKey"; TVM_DECLARE_NODE_TYPE_INFO(GraphKeyNode, tvm::Node); }; TVM_DEFINE_NODE_REF(GraphKey, GraphKeyNode); /*! \brief Hashing function for graph key */ struct GraphKeyHash { size_t operator()(const GraphKey& gkey) const { return Hash(gkey); } static size_t Hash(const GraphKey& gkey); }; /*! \brief function for graph key */ struct GraphKeyEqual { bool operator()(const GraphKey& a, const GraphKey& b) const { return Equal(a, b); } static bool Equal(const GraphKey& a, const GraphKey& b); }; /*! * \brief Create a hash code for a given graph. * \return The hash code of the graph. */ size_t GraphHash(const Graph& graph); /*! * \brief Compare two graphs * return empty string if they are equal * otherwise return error message * \param a The first graph. * \param b The second graph. * \return empty string if they are equal, otherwise return error message. */ std::string GraphDeepCompare(const Graph& a, const Graph& b, bool compare_variable_attr); } // namespace compiler } // namespace nnvm #endif // NNVM_COMPILER_GRAPH_HASH_H_