Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0f7aa30b
Commit
0f7aa30b
authored
Oct 25, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Oct 25, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY] Add structural hashing for Relay (#1977)
parent
fc0149d5
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
362 additions
and
4 deletions
+362
-4
include/tvm/relay/pass.h
+21
-0
python/tvm/relay/ir_pass.py
+25
-3
src/relay/ir/hash.cc
+308
-0
tests/python/relay/test_pass_alpha_equal.py
+8
-1
No files found.
include/tvm/relay/pass.h
View file @
0f7aa30b
...
@@ -136,6 +136,27 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
...
@@ -136,6 +136,27 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/
*/
Expr
DeadCodeElimination
(
const
Expr
&
e
);
Expr
DeadCodeElimination
(
const
Expr
&
e
);
/*! \brief Hash a Relay type.
*
* Implements structural hashing of a Relay type.
*
* \param type the type to hash.
*
* \return the hash value.
*/
size_t
StructuralHash
(
const
Type
&
type
);
/*! \brief Hash a Relay expression.
*
* Implements structural hashing of a Relay expression.
*
* \param expr the expression to hash.
*
* \return the hash value.
*/
size_t
StructuralHash
(
const
Expr
&
expr
);
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RELAY_PASS_H_
#endif // TVM_RELAY_PASS_H_
python/tvm/relay/ir_pass.py
View file @
0f7aa30b
# pylint: disable=no-else-return
,
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
"""The set of passes for Relay.
"""The set of passes for Relay.
...
@@ -7,7 +7,8 @@ scripting them in Python.
...
@@ -7,7 +7,8 @@ scripting them in Python.
"""
"""
from
.
import
_ir_pass
from
.
import
_ir_pass
from
.
import
_make
from
.
import
_make
# pylint: disable=invalid-name
from
.expr
import
Expr
from
.ty
import
Type
def
infer_type
(
expr
,
env
=
None
):
def
infer_type
(
expr
,
env
=
None
):
"""Infer the type of expr under the context of env.
"""Infer the type of expr under the context of env.
...
@@ -148,7 +149,6 @@ def alpha_equal(lhs, rhs):
...
@@ -148,7 +149,6 @@ def alpha_equal(lhs, rhs):
"""
"""
return
bool
(
_make
.
_alpha_equal
(
lhs
,
rhs
))
return
bool
(
_make
.
_alpha_equal
(
lhs
,
rhs
))
def
graph_equal
(
lhs
,
rhs
):
def
graph_equal
(
lhs
,
rhs
):
"""Compare two Relay expr for data-flow equivalence.
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
The difference between this and alpha-equality is that
...
@@ -169,3 +169,25 @@ def graph_equal(lhs, rhs):
...
@@ -169,3 +169,25 @@ def graph_equal(lhs, rhs):
True iff lhs is data-flow equivalent to rhs.
True iff lhs is data-flow equivalent to rhs.
"""
"""
return
bool
(
_make
.
_graph_equal
(
lhs
,
rhs
))
return
bool
(
_make
.
_graph_equal
(
lhs
,
rhs
))
def
structural_hash
(
value
):
"""Hash a Relay expression structurally.
Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
The expression to hash.
Returns
-------
result: int
The hash value
"""
if
isinstance
(
value
,
Expr
):
return
int
(
_ir_pass
.
_expr_hash
(
value
))
elif
isinstance
(
value
,
Type
):
return
int
(
_ir_pass
.
_type_hash
(
value
))
else
:
msg
=
(
"found value of type {0} expected"
+
"relay.Expr or relay.Type"
)
.
format
(
type
(
value
))
raise
TypeError
(
msg
)
src/relay/ir/hash.cc
0 → 100644
View file @
0f7aa30b
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/ir/hash.cc
* \brief Hash functions for Relay types and expressions.
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include <tvm/attrs.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace
tvm
{
namespace
relay
{
// Hash handler for Relay.
class
RelayHashHandler
:
public
AttrsHashHandler
,
public
TypeFunctor
<
size_t
(
const
Type
&
)
>
,
public
ExprFunctor
<
size_t
(
const
Expr
&
)
>
{
public
:
explicit
RelayHashHandler
()
{}
/*!
* Compute hash of a node.
* \param ref The node to hash.
* \return the hash value.
*/
size_t
Hash
(
const
NodeRef
&
ref
)
{
if
(
!
ref
.
defined
())
return
ref
.
hash
();
if
(
ref
->
derived_from
<
TypeNode
>
())
{
return
TypeHash
(
Downcast
<
Type
>
(
ref
));
}
if
(
ref
->
derived_from
<
ExprNode
>
())
{
return
ExprHash
(
Downcast
<
Expr
>
(
ref
));
}
return
AttrHash
(
ref
);
}
/*!
* Compute hash of the attributes.
* \param ref The attributes.
* \return the hash value
*/
size_t
AttrHash
(
const
NodeRef
&
ref
)
{
if
(
!
ref
.
defined
())
{
return
ref
.
hash
();
}
return
AttrsHashHandler
::
Hash
(
ref
);
}
/*!
* Compute hash of a Relay type.
* \param ref The type to hash.
* \param rhs The right hand operand.
* \return the hash value.
*/
size_t
TypeHash
(
const
Type
&
type
)
{
if
(
!
type
.
defined
())
{
return
type
.
hash
();
}
auto
found
=
hash_map_
.
find
(
type
);
if
(
found
!=
hash_map_
.
end
())
{
return
found
->
second
;
}
else
{
auto
hash
=
this
->
VisitType
(
type
);
hash_map_
.
insert
({
type
,
hash
});
return
hash
;
}
}
/*!
* Compute the hash of an expression.
*
* \note We run graph structural equality checking when comparing two Exprs.
* This means that AlphaEqualHandler can only be used once for each pair.
* The equality checker checks data-flow equvalence of the Expr DAG.
* This function also runs faster as it memomizes equal_map.
*
* \param expr The expression to hash.
* \return the hash value.
*/
size_t
ExprHash
(
const
Expr
&
expr
)
{
if
(
!
expr
.
defined
())
return
expr
.
hash
();
auto
found
=
hash_map_
.
find
(
expr
);
if
(
found
!=
hash_map_
.
end
())
{
return
found
->
second
;
}
else
{
auto
hash
=
this
->
VisitExpr
(
expr
);
hash_map_
.
insert
({
expr
,
hash
});
return
hash
;
}
}
protected
:
/*!
* \brief Hash a DataType.
* \param dtype The dtype to hash.
* \return the hash value.
*/
size_t
DataTypeHash
(
const
DataType
&
dtype
)
{
return
::
tvm
::
AttrsHash
()(
dtype
);
}
using
AttrsHashHandler
::
VisitAttr_
;
size_t
VisitAttr_
(
const
Variable
*
var
)
final
{
auto
it
=
hash_map_
.
find
(
GetRef
<
VarExpr
>
(
var
));
if
(
it
!=
hash_map_
.
end
())
{
return
it
->
second
;
}
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
var
->
_type_key
);
return
Combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
var
->
name_hint
));
}
// Type hashing
size_t
VisitType_
(
const
TensorTypeNode
*
tensor_type
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
tensor_type
->
_type_key
);
hash
=
Combine
(
hash
,
DataTypeHash
(
tensor_type
->
dtype
));
hash
=
Combine
(
hash
,
Hash
(
tensor_type
->
shape
));
return
hash
;
}
size_t
VisitType_
(
const
IncompleteTypeNode
*
incomplete
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
incomplete
->
_type_key
);
return
Combine
(
hash
,
std
::
hash
<
int
>
()(
incomplete
->
kind
));
}
size_t
VisitType_
(
const
TypeVarNode
*
tyvar
)
final
{
/*
TypeVar/Var/Variable have two locations where they are hashed:
The declaration site of a function, let, or function type.
The first occurence in the term.
We will only reach this code if the TypeVar itself is unbound, we assign
a free variable index to it, meaning this hashing function implements
structural equality for both open (i.e graph equality) and closed terms
(i.e alpha_equality).
*/
return
BindVar
(
GetRef
<
TypeVar
>
(
tyvar
));
}
size_t
VisitType_
(
const
FuncTypeNode
*
func_type
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
func_type
->
_type_key
);
for
(
auto
type_param
:
func_type
->
type_params
)
{
hash
=
Combine
(
hash
,
BindVar
(
type_param
));
}
for
(
auto
arg
:
func_type
->
arg_types
)
{
hash
=
Combine
(
hash
,
TypeHash
(
arg
));
}
hash
=
Combine
(
hash
,
TypeHash
(
func_type
->
ret_type
));
for
(
auto
cs
:
func_type
->
type_constraints
)
{
hash
=
Combine
(
hash
,
TypeHash
(
cs
));
}
return
hash
;
}
size_t
VisitType_
(
const
TypeRelationNode
*
type_rel
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
type_rel
->
_type_key
);
hash
=
Combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
type_rel
->
func
->
name
));
hash
=
Combine
(
hash
,
AttrHash
(
type_rel
->
attrs
));
for
(
auto
arg
:
type_rel
->
args
)
{
hash
=
Combine
(
hash
,
TypeHash
(
arg
));
}
return
hash
;
}
size_t
VisitType_
(
const
TupleTypeNode
*
tuple_type
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
tuple_type
->
_type_key
);
for
(
size_t
i
=
0
;
i
<
tuple_type
->
fields
.
size
();
i
++
)
{
hash
=
Combine
(
hash
,
TypeHash
(
tuple_type
->
fields
[
i
]));
}
return
hash
;
}
// Expr hashing.
size_t
NDArrayHash
(
const
runtime
::
NDArray
&
array
)
{
size_t
hash
=
std
::
hash
<
uint8_t
>
()(
array
->
dtype
.
code
);
hash
=
Combine
(
hash
,
std
::
hash
<
uint8_t
>
()(
array
->
dtype
.
bits
));
hash
=
Combine
(
hash
,
std
::
hash
<
uint16_t
>
()(
array
->
dtype
.
lanes
));
CHECK_EQ
(
array
->
ctx
.
device_type
,
kDLCPU
)
<<
"can only compare CPU tensor"
;
size_t
data_size
=
runtime
::
GetDataSize
(
*
array
.
operator
->
());
uint8_t
*
data
=
reinterpret_cast
<
uint8_t
*>
(
array
->
data
);
for
(
size_t
i
=
0
;
i
<
data_size
;
i
++
)
{
hash
=
Combine
(
hash
,
std
::
hash
<
uint8_t
>
()(
data
[
i
]));
}
return
hash
;
}
size_t
BindVar
(
const
NodeRef
&
var
)
{
size_t
hash
=
std
::
hash
<
int
>
()(
var_counter
++
);
CHECK_EQ
(
hash_map_
.
count
(
var
),
0
);
hash_map_
[
var
]
=
hash
;
const
auto
*
ty_param
=
var
.
as
<
TypeVarNode
>
();
if
(
ty_param
&&
ty_param
->
kind
==
TypeVarNode
::
Kind
::
kShapeVar
)
{
hash_map_
[
ty_param
->
var
]
=
hash
;
}
return
hash
;
}
size_t
VisitExpr_
(
const
VarNode
*
var
)
final
{
size_t
name_hash
=
std
::
hash
<
std
::
string
>
()(
var
->
name_hint
);
return
Combine
(
name_hash
,
TypeHash
(
var
->
type_annotation
));
}
size_t
VisitExpr_
(
const
GlobalVarNode
*
global
)
final
{
return
std
::
hash
<
std
::
string
>
()(
global
->
name_hint
);
}
size_t
VisitExpr_
(
const
TupleNode
*
tuple
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
tuple
->
_type_key
);
for
(
size_t
i
=
0
;
i
<
tuple
->
fields
.
size
();
i
++
)
{
hash
=
Combine
(
hash
,
ExprHash
(
tuple
->
fields
[
i
]));
}
return
hash
;
}
size_t
VisitExpr_
(
const
FunctionNode
*
func
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
func
->
_type_key
);
for
(
auto
type_param
:
func
->
type_params
)
{
hash
=
Combine
(
hash
,
BindVar
(
type_param
));
}
for
(
auto
param
:
func
->
params
)
{
hash
=
Combine
(
hash
,
BindVar
(
param
));
}
hash
=
Combine
(
hash
,
TypeHash
(
func
->
ret_type
));
hash
=
Combine
(
hash
,
ExprHash
(
func
->
body
));
return
hash
;
}
size_t
VisitExpr_
(
const
CallNode
*
call
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
call
->
_type_key
);
hash
=
Combine
(
hash
,
ExprHash
(
call
->
op
));
for
(
auto
arg
:
call
->
args
)
{
hash
=
Combine
(
hash
,
ExprHash
(
arg
));
}
hash
=
Combine
(
hash
,
AttrHash
(
call
->
attrs
));
return
hash
;
}
size_t
VisitExpr_
(
const
LetNode
*
let
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
let
->
_type_key
);
hash
=
Combine
(
hash
,
BindVar
(
let
->
var
));
hash
=
Combine
(
hash
,
ExprHash
(
let
->
value
));
hash
=
Combine
(
hash
,
ExprHash
(
let
->
body
));
return
hash
;
}
size_t
VisitExpr_
(
const
IfNode
*
ite
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
ite
->
_type_key
);
hash
=
Combine
(
hash
,
ExprHash
(
ite
->
cond
));
hash
=
Combine
(
hash
,
ExprHash
(
ite
->
true_branch
));
hash
=
Combine
(
hash
,
ExprHash
(
ite
->
false_branch
));
return
hash
;
}
size_t
VisitExpr_
(
const
OpNode
*
op
)
final
{
return
GetRef
<
Op
>
(
op
).
hash
();
}
size_t
VisitExpr_
(
const
ConstantNode
*
rconst
)
final
{
return
NDArrayHash
(
rconst
->
data
);
}
size_t
VisitExpr_
(
const
TupleGetItemNode
*
get_item
)
final
{
size_t
hash
=
std
::
hash
<
std
::
string
>
()(
get_item
->
_type_key
);
hash
=
Combine
(
hash
,
ExprHash
(
get_item
->
tuple
));
hash
=
Combine
(
hash
,
std
::
hash
<
int
>
()(
get_item
->
index
));
return
hash
;
}
private
:
// renaming of NodeRef to indicate two nodes equals to each other
std
::
unordered_map
<
NodeRef
,
size_t
,
NodeHash
,
NodeEqual
>
hash_map_
;
int
var_counter
=
0
;
};
size_t
StructuralHash
(
const
Type
&
type
)
{
return
RelayHashHandler
().
TypeHash
(
type
);
}
size_t
StructuralHash
(
const
Expr
&
expr
)
{
return
RelayHashHandler
().
ExprHash
(
expr
);
}
TVM_REGISTER_API
(
"relay._ir_pass._expr_hash"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
static_cast
<
int64_t
>
(
RelayHashHandler
().
Hash
(
args
[
0
]));
});
TVM_REGISTER_API
(
"relay._ir_pass._type_hash"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
static_cast
<
int64_t
>
(
RelayHashHandler
().
TypeHash
(
args
[
0
]));
});
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_pass_alpha_equal.py
View file @
0f7aa30b
import
tvm
import
tvm
import
numpy
as
np
import
numpy
as
np
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay.ir_pass
import
alpha_equal
from
tvm.relay
import
ir_pass
def
alpha_equal
(
x
,
y
):
"""
Wrapper around alpha equality which ensures that
the hash function respects equality.
"""
return
ir_pass
.
alpha_equal
(
x
,
y
)
and
ir_pass
.
structural_hash
(
x
)
==
ir_pass
.
structural_hash
(
y
)
def
test_tensor_type_alpha_equal
():
def
test_tensor_type_alpha_equal
():
t1
=
relay
.
TensorType
((
3
,
4
),
"float32"
)
t1
=
relay
.
TensorType
((
3
,
4
),
"float32"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment