Unverified Commit a2edd01b by Zhi Committed by GitHub

relay::StructuralHash to tvm::StructuralHash (#5166)

parent 919ae889
...@@ -225,28 +225,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr); ...@@ -225,28 +225,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/ */
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod); TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t operator()(const Type& type) const;
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t operator()(const Expr& expr) const;
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -20,11 +20,10 @@ ...@@ -20,11 +20,10 @@
This file contains the set of passes for Relay, which exposes an interface for This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python. configuring the passes and scripting them in Python.
""" """
from tvm.ir import RelayExpr, IRModule from tvm.ir import IRModule
from . import _ffi_api from . import _ffi_api
from .feature import Feature from .feature import Feature
from ..ty import Type
def post_order_visit(expr, fvisit): def post_order_visit(expr, fvisit):
...@@ -314,29 +313,6 @@ def detect_feature(a, b=None): ...@@ -314,29 +313,6 @@ def detect_feature(a, b=None):
return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)} return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}
def structural_hash(value):
"""Hash a Relay expression structurally.
Parameters
----------
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result : int
The hash value
"""
if isinstance(value, RelayExpr):
return int(_ffi_api._expr_hash(value))
elif isinstance(value, Type):
return int(_ffi_api._type_hash(value))
else:
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
def extract_fused_functions(mod): def extract_fused_functions(mod):
"""Pass to extract IRModule of only fused primitive functions. """Pass to extract IRModule of only fused primitive functions.
......
...@@ -27,7 +27,7 @@ import tvm ...@@ -27,7 +27,7 @@ import tvm
from tvm.ir import IRModule from tvm.ir import IRModule
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.analysis import structural_hash as s_hash from tvm.ir import structural_hash as s_hash
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
......
...@@ -238,7 +238,7 @@ class PythonConverter(ExprFunctor): ...@@ -238,7 +238,7 @@ class PythonConverter(ExprFunctor):
# compile the function and register globally # compile the function and register globally
cc_key = compile_engine.CCacheKey(op, self.tgt) cc_key = compile_engine.CCacheKey(op, self.tgt)
func_hash = relay.analysis.structural_hash(op) func_hash = tvm.ir.structural_hash(op)
op_name = '_lowered_op_{}'.format(func_hash) op_name = '_lowered_op_{}'.format(func_hash)
if not tvm.get_global_func(op_name, allow_missing=True): if not tvm.get_global_func(op_name, allow_missing=True):
jitted = self.engine.jit(cc_key, self.tgt) jitted = self.engine.jit(cc_key, self.tgt)
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* \file extract_fused_functions.cc * \file extract_fused_functions.cc
* \brief Apply fusion and extract fused primitive functions from an IRModule * \brief Apply fusion and extract fused primitive functions from an IRModule
*/ */
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
...@@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor { ...@@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor {
if (n->HasNonzeroAttr(attr::kPrimitive)) { if (n->HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string // Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs); Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
size_t hash_ = StructuralHash()(func); size_t hash_ = tvm::StructuralHash()(func);
this->functions.Set(std::to_string(hash_), func); this->functions.Set(std::to_string(hash_), func);
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/node/structural_equal.h> #include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/tir/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
...@@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty); ...@@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty);
inline size_t CCacheKeyNode::Hash() const { inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_; if (hash_ != 0) return hash_;
// do structral hash, avoid 0. // do structral hash, avoid 0.
hash_ = StructuralHash()(this->source_func); hash_ = tvm::StructuralHash()(this->source_func);
hash_ = dmlc::HashCombine( hash_ = dmlc::HashCombine(
hash_, std::hash<std::string>()(target->str())); hash_, std::hash<std::string>()(target->str()));
if (hash_ == 0) hash_ = 1; if (hash_ == 0) hash_ = 1;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <tvm/node/structural_equal.h> #include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h> #include <tvm/support/logging.h>
...@@ -39,7 +40,7 @@ namespace relay { ...@@ -39,7 +40,7 @@ namespace relay {
namespace vm { namespace vm {
inline std::string GenerateName(const Function& func) { inline std::string GenerateName(const Function& func) {
size_t hash = StructuralHash()(func); size_t hash = tvm::StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash); return std::string("lifted_name") + std::to_string(hash);
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/ir/hash.cc
* \brief Hash functions for Relay types and expressions.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/attrs.h>
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
// Hash handler for Relay.
class RelayHashHandler:
public AttrsHashHandler,
public TypeFunctor<size_t(const Type&)>,
public ExprFunctor<size_t(const Expr&)>,
public PatternFunctor<size_t(const Pattern&)> {
public:
explicit RelayHashHandler() {}
/*!
* Compute hash of a node.
* \param ref The node to hash.
* \return the hash value.
*/
size_t Hash(const ObjectRef& ref) {
if (!ref.defined()) return ObjectHash()(ref);
if (ref->IsInstance<TypeNode>()) {
return TypeHash(Downcast<Type>(ref));
}
if (ref->IsInstance<ExprNode>()) {
return ExprHash(Downcast<Expr>(ref));
}
return AttrHash(ref);
}
/*!
* Compute hash of the attributes.
* \param ref The attributes.
* \return the hash value
*/
size_t AttrHash(const ObjectRef& ref) {
if (!ref.defined()) {
return ObjectHash()(ref);
}
return AttrsHashHandler::Hash(ref);
}
/*!
* Compute hash of a Relay type.
* \param ref The type to hash.
* \param rhs The right hand operand.
* \return the hash value.
*/
size_t TypeHash(const Type& type) {
if (!type.defined()) {
return ObjectHash()(type);
}
auto found = hash_map_.find(type);
if (found != hash_map_.end()) {
return found->second;
} else {
auto hash = this->VisitType(type);
hash_map_.insert({type, hash});
return hash;
}
}
/*!
* Compute the hash of an expression.
*
* \note We run graph structural equality checking when comparing two Exprs.
* This means that AlphaEqualHandler can only be used once for each pair.
* The equality checker checks data-flow equvalence of the Expr DAG.
* This function also runs faster as it memomizes equal_map.
*
* \param expr The expression to hash.
* \return the hash value.
*/
size_t ExprHash(const Expr& expr) {
if (!expr.defined()) {
return ObjectHash()(expr);
}
auto found = hash_map_.find(expr);
if (found != hash_map_.end()) {
return found->second;
} else {
auto hash = this->VisitExpr(expr);
hash_map_.insert({expr, hash});
return hash;
}
}
protected:
/*!
* \brief Hash a DataType.
* \param dtype The dtype to hash.
* \return the hash value.
*/
size_t DataTypeHash(const DataType& dtype) {
return ::tvm::AttrsHash()(dtype);
}
using AttrsHashHandler::VisitAttr_;
size_t VisitAttr_(const tvm::tir::VarNode* var) final {
size_t hash = std::hash<std::string>()(VarNode::_type_key);
auto it = hash_map_.find(GetRef<tvm::tir::Var>(var));
if (it != hash_map_.end()) {
return it->second;
}
return Combine(hash, std::hash<std::string>()(var->name_hint));
}
// Type hashing
size_t VisitType_(const TensorTypeNode* tensor_type) final {
size_t hash = std::hash<std::string>()(TensorTypeNode::_type_key);
hash = Combine(hash, DataTypeHash(tensor_type->dtype));
hash = Combine(hash, Hash(tensor_type->shape));
return hash;
}
size_t VisitType_(const IncompleteTypeNode* incomplete) final {
size_t hash = std::hash<std::string>()(IncompleteTypeNode::_type_key);
return Combine(hash, std::hash<int>()(incomplete->kind));
}
size_t VisitType_(const TypeVarNode* tyvar) final {
/*
TypeVar/Var/Variable have two locations where they are hashed:
The declaration site of a function, let, or function type.
The first occurence in the term.
We will only reach this code if the TypeVar itself is unbound, we assign
a free variable index to it, meaning this hashing function implements
structural equality for both open (i.e graph equality) and closed terms
(i.e alpha_equality).
*/
return BindVar(GetRef<TypeVar>(tyvar));
}
size_t VisitType_(const FuncTypeNode* func_type) final {
size_t hash = std::hash<std::string>()(FuncTypeNode::_type_key);
for (auto type_param : func_type->type_params) {
hash = Combine(hash, BindVar(type_param));
}
for (auto arg : func_type->arg_types) {
hash = Combine(hash, TypeHash(arg));
}
hash = Combine(hash, TypeHash(func_type->ret_type));
for (auto cs : func_type->type_constraints) {
hash = Combine(hash, TypeHash(cs));
}
return hash;
}
size_t VisitType_(const TypeRelationNode* type_rel) final {
size_t hash = std::hash<std::string>()(TypeRelationNode::_type_key);
hash = Combine(hash, std::hash<std::string>()(type_rel->func->name));
hash = Combine(hash, AttrHash(type_rel->attrs));
for (auto arg : type_rel->args) {
hash = Combine(hash, TypeHash(arg));
}
return hash;
}
size_t VisitType_(const TupleTypeNode* tuple_type) final {
size_t hash = std::hash<std::string>()(TupleTypeNode::_type_key);
for (size_t i = 0; i < tuple_type->fields.size(); i++) {
hash = Combine(hash, TypeHash(tuple_type->fields[i]));
}
return hash;
}
size_t VisitType_(const RelayRefTypeNode* rtn) final {
size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
hash = Combine(hash, TypeHash(rtn->value));
return hash;
}
// Expr hashing.
size_t NDArrayHash(const runtime::NDArray& array) {
size_t hash = std::hash<uint8_t>()(array->dtype.code);
hash = Combine(hash, std::hash<uint8_t>()(array->dtype.bits));
hash = Combine(hash, std::hash<uint16_t>()(array->dtype.lanes));
CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
size_t data_size = runtime::GetDataSize(*array.operator->());
uint8_t * data = reinterpret_cast<uint8_t*>(array->data);
for (size_t i = 0; i < data_size; i++) {
hash = Combine(hash, std::hash<uint8_t>()(data[i]));
}
return hash;
}
size_t BindVar(const ObjectRef& var) {
size_t hash = std::hash<int>()(var_counter++);
CHECK_EQ(hash_map_.count(var), 0);
if (auto var_node = var.as<VarNode>()) {
hash = Combine(hash, TypeHash(var_node->type_annotation));
}
hash_map_[var] = hash;
return hash;
}
size_t VisitExpr_(const VarNode* var) final {
// hash free variable
size_t name_hash = std::hash<const Object*>()(var->vid.get());
return Combine(name_hash, TypeHash(var->type_annotation));
}
size_t VisitExpr_(const GlobalVarNode* global) final {
return std::hash<std::string>()(global->name_hint);
}
size_t VisitExpr_(const TupleNode* tuple) final {
size_t hash = std::hash<std::string>()(TupleNode::_type_key);
for (size_t i = 0; i < tuple->fields.size(); i++) {
hash = Combine(hash, ExprHash(tuple->fields[i]));
}
return hash;
}
size_t VisitExpr_(const FunctionNode* func) final {
size_t hash = std::hash<std::string>()(FunctionNode::_type_key);
for (auto type_param : func->type_params) {
hash = Combine(hash, BindVar(type_param));
}
for (auto param : func->params) {
hash = Combine(hash, BindVar(param));
}
hash = Combine(hash, TypeHash(func->ret_type));
hash = Combine(hash, ExprHash(func->body));
hash = Combine(hash, AttrHash(func->attrs));
return hash;
}
size_t VisitExpr_(const CallNode* call) final {
size_t hash = std::hash<std::string>()(CallNode::_type_key);
hash = Combine(hash, ExprHash(call->op));
for (auto arg : call->args) {
hash = Combine(hash, ExprHash(arg));
}
for (auto t : call->type_args) {
CHECK(t.defined());
hash = Combine(hash, TypeHash(t));
}
hash = Combine(hash, AttrHash(call->attrs));
return hash;
}
size_t VisitExpr_(const LetNode* let) final {
size_t hash = std::hash<std::string>()(LetNode::_type_key);
hash = Combine(hash, BindVar(let->var));
hash = Combine(hash, ExprHash(let->value));
hash = Combine(hash, ExprHash(let->body));
return hash;
}
size_t VisitExpr_(const IfNode* ite) final {
size_t key = std::hash<std::string>()(IfNode::_type_key);
size_t hash = key;
hash = Combine(hash, ExprHash(ite->cond));
hash = Combine(hash, ExprHash(ite->true_branch));
hash = Combine(hash, ExprHash(ite->false_branch));
return hash;
}
size_t VisitExpr_(const OpNode* op) final {
return ObjectHash()(GetRef<Op>(op));
}
size_t VisitExpr_(const ConstantNode* rconst) final {
return NDArrayHash(rconst->data);
}
size_t VisitExpr_(const TupleGetItemNode* get_item) final {
size_t hash = std::hash<std::string>()(TupleGetItemNode::_type_key);
hash = Combine(hash, ExprHash(get_item->tuple));
hash = Combine(hash, std::hash<int>()(get_item->index));
return hash;
}
size_t VisitExpr_(const RefCreateNode* rn) final {
size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
size_t VisitExpr_(const RefReadNode* rn) final {
size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
hash = Combine(hash, ExprHash(rn->ref));
return hash;
}
size_t VisitExpr_(const RefWriteNode* rn) final {
size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
hash = Combine(hash, ExprHash(rn->ref));
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
size_t VisitExpr_(const MatchNode* mn) final {
size_t hash = std::hash<std::string>()(MatchNode::_type_key);
hash = Combine(hash, ExprHash(mn->data));
for (const auto& c : mn->clauses) {
hash = Combine(hash, PatternHash(c->lhs));
hash = Combine(hash, ExprHash(c->rhs));
}
hash = Combine(hash, std::hash<bool>()(mn->complete));
return hash;
}
size_t VisitExpr_(const ConstructorNode* cn) final {
size_t hash = std::hash<std::string>()(ConstructorNode::_type_key);
hash = Combine(hash, std::hash<std::string>()(cn->name_hint));
return hash;
}
size_t VisitType_(const TypeCallNode* tcn) final {
size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
hash = Combine(hash, TypeHash(tcn->func));
for (const auto& t : tcn->args) {
hash = Combine(hash, TypeHash(t));
}
return hash;
}
size_t VisitType_(const TypeDataNode* tdn) final {
size_t hash = std::hash<std::string>()(TypeDataNode::_type_key);
hash = Combine(hash, TypeHash(tdn->header));
for (const auto& tv : tdn->type_vars) {
hash = Combine(hash, TypeHash(tv));
}
for (const auto& cn : tdn->constructors) {
hash = Combine(hash, ExprHash(cn));
}
return hash;
}
size_t VisitType_(const GlobalTypeVarNode* tvn) final {
return BindVar(GetRef<GlobalTypeVar>(tvn));
}
size_t PatternHash(const Pattern& p) {
return VisitPattern(p);
}
size_t VisitPattern_(const PatternConstructorNode* pcn) final {
size_t hash = std::hash<std::string>()(PatternConstructorNode::_type_key);
hash = Combine(hash, ExprHash(pcn->constructor));
for (const auto& p : pcn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}
size_t VisitPattern_(const PatternTupleNode* ptn) final {
size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
for (const auto& p : ptn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}
size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var));
return hash;
}
size_t VisitPattern_(const PatternWildcardNode* pwn) final {
size_t hash = std::hash<std::string>()(PatternWildcardNode::_type_key);
return hash;
}
private:
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<ObjectRef, size_t, ObjectHash, ObjectEqual> hash_map_;
int var_counter = 0;
};
size_t StructuralHash::operator()(const Type& type) const {
return RelayHashHandler().TypeHash(type);
}
size_t StructuralHash::operator()(const Expr& expr) const {
return RelayHashHandler().ExprHash(expr);
}
TVM_REGISTER_GLOBAL("relay.analysis._expr_hash")
.set_body_typed([](ObjectRef ref) {
return static_cast<int64_t>(RelayHashHandler().Hash(ref));
});
TVM_REGISTER_GLOBAL("relay.analysis._type_hash")
.set_body_typed([](Type type) {
return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
});
} // namespace relay
} // namespace tvm
...@@ -31,7 +31,8 @@ def alpha_equal(x, y): ...@@ -31,7 +31,8 @@ def alpha_equal(x, y):
""" """
x = x['main'] x = x['main']
y = y['main'] y = y['main']
return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) return tvm.ir.structural_equal(x, y) and \
tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)
def run_opt_pass(expr, passes): def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes] passes = passes if isinstance(passes, list) else [passes]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment