Commit 0f7aa30b by Jared Roesch, committed by Tianqi Chen

[RELAY] Add structural hashing for Relay (#1977)

parent fc0149d5
@@ -136,6 +136,27 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/
Expr DeadCodeElimination(const Expr& e);
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t StructuralHash(const Type& type);
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t StructuralHash(const Expr& expr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_H_
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""The set of passes for Relay.
@@ -7,7 +7,8 @@ scripting them in Python.
"""
from . import _ir_pass
from . import _make
# pylint: disable=invalid-name
from .expr import Expr
from .ty import Type
def infer_type(expr, env=None):
"""Infer the type of expr under the context of env.
@@ -148,7 +149,6 @@ def alpha_equal(lhs, rhs):
"""
return bool(_make._alpha_equal(lhs, rhs))
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
@@ -169,3 +169,25 @@ def graph_equal(lhs, rhs):
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
def structural_hash(value):
"""Hash a Relay expression structurally.
Parameters
----------
value: tvm.relay.Expr or tvm.relay.Type
The expression or type to hash.
Returns
-------
result: int
The hash value
"""
if isinstance(value, Expr):
return int(_ir_pass._expr_hash(value))
elif isinstance(value, Type):
return int(_ir_pass._type_hash(value))
else:
msg = ("found value of type {0} expected" +
"relay.Expr or relay.Type").format(type(value))
raise TypeError(msg)
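For illustration, a minimal usage sketch of this wrapper; the shape, dtype, and variable name below are made up for the example (assuming relay.Var takes a name hint and an optional type annotation, as the VarNode case in the C++ handler suggests) and are not part of the patch:

from tvm import relay
from tvm.relay import ir_pass

# A type and an expression; both dispatch through the same C++ handler.
ty = relay.TensorType((3, 4), "float32")
x = relay.Var("x", ty)

h_ty = ir_pass.structural_hash(ty)   # routed to _type_hash
h_x = ir_pass.structural_hash(x)     # routed to _expr_hash
assert isinstance(h_ty, int) and isinstance(h_x, int)

Passing any other value raises the TypeError shown above.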
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/ir/hash.cc
* \brief Hash functions for Relay types and expressions.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include <tvm/attrs.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
// Hash handler for Relay.
class RelayHashHandler:
public AttrsHashHandler,
public TypeFunctor<size_t(const Type&)>,
public ExprFunctor<size_t(const Expr&)> {
public:
explicit RelayHashHandler() {}
/*!
* Compute hash of a node.
* \param ref The node to hash.
* \return the hash value.
*/
size_t Hash(const NodeRef& ref) {
if (!ref.defined()) return ref.hash();
if (ref->derived_from<TypeNode>()) {
return TypeHash(Downcast<Type>(ref));
}
if (ref->derived_from<ExprNode>()) {
return ExprHash(Downcast<Expr>(ref));
}
return AttrHash(ref);
}
/*!
* Compute hash of the attributes.
* \param ref The attributes.
* \return the hash value
*/
size_t AttrHash(const NodeRef& ref) {
if (!ref.defined()) { return ref.hash(); }
return AttrsHashHandler::Hash(ref);
}
/*!
* Compute hash of a Relay type.
* \param type The type to hash.
* \return the hash value.
*/
size_t TypeHash(const Type& type) {
if (!type.defined()) { return type.hash(); }
auto found = hash_map_.find(type);
if (found != hash_map_.end()) {
return found->second;
} else {
auto hash = this->VisitType(type);
hash_map_.insert({type, hash});
return hash;
}
}
/*!
* Compute the hash of an expression.
*
* \note The expression is hashed as a data-flow graph: each node is
* visited once and its hash is memoized in hash_map_, so shared
* sub-expressions contribute a single hash value and the traversal
* stays linear in the size of the DAG.
*
* \param expr The expression to hash.
* \return the hash value.
*/
size_t ExprHash(const Expr& expr) {
if (!expr.defined()) return expr.hash();
auto found = hash_map_.find(expr);
if (found != hash_map_.end()) {
return found->second;
} else {
auto hash = this->VisitExpr(expr);
hash_map_.insert({expr, hash});
return hash;
}
}
protected:
/*!
* \brief Hash a DataType.
* \param dtype The dtype to hash.
* \return the hash value.
*/
size_t DataTypeHash(const DataType& dtype) {
return ::tvm::AttrsHash()(dtype);
}
using AttrsHashHandler::VisitAttr_;
size_t VisitAttr_(const Variable* var) final {
auto it = hash_map_.find(GetRef<VarExpr>(var));
if (it != hash_map_.end()) {
return it->second;
}
size_t hash = std::hash<std::string>()(var->_type_key);
return Combine(hash, std::hash<std::string>()(var->name_hint));
}
// Type hashing
size_t VisitType_(const TensorTypeNode* tensor_type) final {
size_t hash = std::hash<std::string>()(tensor_type->_type_key);
hash = Combine(hash, DataTypeHash(tensor_type->dtype));
hash = Combine(hash, Hash(tensor_type->shape));
return hash;
}
size_t VisitType_(const IncompleteTypeNode* incomplete) final {
size_t hash = std::hash<std::string>()(incomplete->_type_key);
return Combine(hash, std::hash<int>()(incomplete->kind));
}
size_t VisitType_(const TypeVarNode* tyvar) final {
/*
TypeVar/Var/Variable can be hashed in two places:
at the declaration site of a function, let, or function type, or
at their first occurrence in the term.
We only reach this code if the TypeVar itself is unbound; we then assign
a free-variable index to it, so this hashing function implements
structural hashing for both open terms (i.e. graph equality) and closed
terms (i.e. alpha equality).
*/
return BindVar(GetRef<TypeVar>(tyvar));
}
size_t VisitType_(const FuncTypeNode* func_type) final {
size_t hash = std::hash<std::string>()(func_type->_type_key);
for (auto type_param : func_type->type_params) {
hash = Combine(hash, BindVar(type_param));
}
for (auto arg : func_type->arg_types) {
hash = Combine(hash, TypeHash(arg));
}
hash = Combine(hash, TypeHash(func_type->ret_type));
for (auto cs : func_type->type_constraints) {
hash = Combine(hash, TypeHash(cs));
}
return hash;
}
size_t VisitType_(const TypeRelationNode* type_rel) final {
size_t hash = std::hash<std::string>()(type_rel->_type_key);
hash = Combine(hash, std::hash<std::string>()(type_rel->func->name));
hash = Combine(hash, AttrHash(type_rel->attrs));
for (auto arg : type_rel->args) {
hash = Combine(hash, TypeHash(arg));
}
return hash;
}
size_t VisitType_(const TupleTypeNode* tuple_type) final {
size_t hash = std::hash<std::string>()(tuple_type->_type_key);
for (size_t i = 0; i < tuple_type->fields.size(); i++) {
hash = Combine(hash, TypeHash(tuple_type->fields[i]));
}
return hash;
}
// Expr hashing.
size_t NDArrayHash(const runtime::NDArray& array) {
size_t hash = std::hash<uint8_t>()(array->dtype.code);
hash = Combine(hash, std::hash<uint8_t>()(array->dtype.bits));
hash = Combine(hash, std::hash<uint16_t>()(array->dtype.lanes));
CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only hash CPU tensors";
size_t data_size = runtime::GetDataSize(*array.operator->());
uint8_t * data = reinterpret_cast<uint8_t*>(array->data);
for (size_t i = 0; i < data_size; i++) {
hash = Combine(hash, std::hash<uint8_t>()(data[i]));
}
return hash;
}
size_t BindVar(const NodeRef& var) {
size_t hash = std::hash<int>()(var_counter++);
CHECK_EQ(hash_map_.count(var), 0);
hash_map_[var] = hash;
const auto* ty_param = var.as<TypeVarNode>();
if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) {
hash_map_[ty_param->var] = hash;
}
return hash;
}
size_t VisitExpr_(const VarNode* var) final {
size_t name_hash = std::hash<std::string>()(var->name_hint);
return Combine(name_hash, TypeHash(var->type_annotation));
}
size_t VisitExpr_(const GlobalVarNode* global) final {
return std::hash<std::string>()(global->name_hint);
}
size_t VisitExpr_(const TupleNode* tuple) final {
size_t hash = std::hash<std::string>()(tuple->_type_key);
for (size_t i = 0; i < tuple->fields.size(); i++) {
hash = Combine(hash, ExprHash(tuple->fields[i]));
}
return hash;
}
size_t VisitExpr_(const FunctionNode* func) final {
size_t hash = std::hash<std::string>()(func->_type_key);
for (auto type_param : func->type_params) {
hash = Combine(hash, BindVar(type_param));
}
for (auto param : func->params) {
hash = Combine(hash, BindVar(param));
}
hash = Combine(hash, TypeHash(func->ret_type));
hash = Combine(hash, ExprHash(func->body));
return hash;
}
size_t VisitExpr_(const CallNode* call) final {
size_t hash = std::hash<std::string>()(call->_type_key);
hash = Combine(hash, ExprHash(call->op));
for (auto arg : call->args) {
hash = Combine(hash, ExprHash(arg));
}
hash = Combine(hash, AttrHash(call->attrs));
return hash;
}
size_t VisitExpr_(const LetNode* let) final {
size_t hash = std::hash<std::string>()(let->_type_key);
hash = Combine(hash, BindVar(let->var));
hash = Combine(hash, ExprHash(let->value));
hash = Combine(hash, ExprHash(let->body));
return hash;
}
size_t VisitExpr_(const IfNode* ite) final {
size_t hash = std::hash<std::string>()(ite->_type_key);
hash = Combine(hash, ExprHash(ite->cond));
hash = Combine(hash, ExprHash(ite->true_branch));
hash = Combine(hash, ExprHash(ite->false_branch));
return hash;
}
size_t VisitExpr_(const OpNode* op) final {
return GetRef<Op>(op).hash();
}
size_t VisitExpr_(const ConstantNode* rconst) final {
return NDArrayHash(rconst->data);
}
size_t VisitExpr_(const TupleGetItemNode* get_item) final {
size_t hash = std::hash<std::string>()(get_item->_type_key);
hash = Combine(hash, ExprHash(get_item->tuple));
hash = Combine(hash, std::hash<int>()(get_item->index));
return hash;
}
private:
// memoized hash values for nodes that have already been visited
std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
int var_counter = 0;
};
size_t StructuralHash(const Type& type) {
return RelayHashHandler().TypeHash(type);
}
size_t StructuralHash(const Expr& expr) {
return RelayHashHandler().ExprHash(expr);
}
TVM_REGISTER_API("relay._ir_pass._expr_hash")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(RelayHashHandler().Hash(args[0]));
});
TVM_REGISTER_API("relay._ir_pass._type_hash")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(RelayHashHandler().TypeHash(args[0]));
});
} // namespace relay
} // namespace tvm
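Because ExprHash memoizes results in hash_map_ and variables receive binder-relative indices (BindVar) or are hashed by name when free, two expressions with the same data-flow structure hash to the same value. A sketch of that invariant at the Python level, assuming the relay.Var and relay.Tuple constructors behave as in this revision:

from tvm import relay
from tvm.relay import ir_pass

ty = relay.TensorType((3, 4), "float32")

# Two independently built expressions with identical structure;
# each Tuple shares its sub-expression, so the handler visits it once.
a = relay.Var("x", ty)
b = relay.Var("x", ty)
e1 = relay.Tuple([a, a])
e2 = relay.Tuple([b, b])

assert ir_pass.structural_hash(e1) == ir_pass.structural_hash(e2)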
import tvm
import numpy as np
from tvm import relay
from tvm.relay.ir_pass import alpha_equal
from tvm.relay import ir_pass
def alpha_equal(x, y):
"""
Wrapper around alpha equality which ensures that
the hash function respects equality.
"""
return ir_pass.alpha_equal(x, y) and ir_pass.structural_hash(x) == ir_pass.structural_hash(y)
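The wrapper is used by the test cases that follow; a hypothetical extra case in the same style (not part of the original patch) could read:

def test_tuple_type_structural_hash():
    # Identical tuple types should be alpha-equal and, via the wrapper
    # above, must also receive the same structural hash.
    t1 = relay.TupleType([relay.TensorType((3, 4), "float32")])
    t2 = relay.TupleType([relay.TensorType((3, 4), "float32")])
    assert alpha_equal(t1, t2)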
def test_tensor_type_alpha_equal():
t1 = relay.TensorType((3, 4), "float32")
......