wenyuanbo / tic · Commit a2edd01b (unverified)

relay::StructuralHash to tvm::StructuralHash (#5166)

Authored Mar 29, 2020 by Zhi; committed by GitHub on Mar 29, 2020.
Parent: 919ae889
Showing 9 changed files with 11 additions and 490 deletions (+11 / -490):

  include/tvm/relay/analysis.h                    +0   -22
  python/tvm/relay/analysis/analysis.py           +1   -25
  python/tvm/relay/frontend/tensorflow.py         +1   -1
  python/tvm/relay/testing/py_converter.py        +1   -1
  src/relay/analysis/extract_fused_functions.cc   +2   -1
  src/relay/backend/compile_engine.h              +2   -1
  src/relay/backend/vm/lambda_lift.cc             +2   -1
  src/relay/ir/hash.cc                            +0   -437
  tests/python/relay/test_pass_qnn_legalize.py    +2   -1
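For users of the Python API, the practical effect of this commit is that structural hashing moves from tvm.relay.analysis to tvm.ir. A minimal sketch of the migration (the expression below is a hypothetical example, not taken from the diff):

import tvm
from tvm import relay

# Hypothetical expression just to illustrate the call-site change.
x = relay.var("x", shape=(4,), dtype="float32")
expr = relay.add(x, x)

# Before this commit: relay.analysis.structural_hash(expr)
# After this commit: the unified hasher lives in tvm.ir.
h = tvm.ir.structural_hash(expr)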
include/tvm/relay/analysis.h

@@ -225,28 +225,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
  */
 TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);

-/*! \brief A hashing structure in the style of std::hash. */
-struct StructuralHash {
-  /*! \brief Hash a Relay type.
-   *
-   * Implements structural hashing of a Relay type.
-   *
-   * \param type the type to hash.
-   *
-   * \return the hash value.
-   */
-  size_t operator()(const Type& type) const;
-
-  /*! \brief Hash a Relay expression.
-   *
-   * Implements structural hashing of a Relay expression.
-   *
-   * \param expr the expression to hash.
-   *
-   * \return the hash value.
-   */
-  size_t operator()(const Expr& expr) const;
-};
-
 }  // namespace relay
 }  // namespace tvm
python/tvm/relay/analysis/analysis.py

@@ -20,11 +20,10 @@
 This file contains the set of passes for Relay, which exposes an interface for
 configuring the passes and scripting them in Python.
 """
-from tvm.ir import RelayExpr, IRModule
+from tvm.ir import IRModule

 from . import _ffi_api
 from .feature import Feature
-from ..ty import Type


 def post_order_visit(expr, fvisit):
@@ -314,29 +313,6 @@ def detect_feature(a, b=None):
     return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}


-def structural_hash(value):
-    """Hash a Relay expression structurally.
-
-    Parameters
-    ----------
-    expr : Union[tvm.relay.Expr, tvm.relay.Type]
-      The expression to hash.
-
-    Returns
-    -------
-    result : int
-      The hash value
-    """
-    if isinstance(value, RelayExpr):
-        return int(_ffi_api._expr_hash(value))
-    elif isinstance(value, Type):
-        return int(_ffi_api._type_hash(value))
-    else:
-        msg = ("found value of type {0} expected" +
-               "relay.Expr or relay.Type").format(type(value))
-        raise TypeError(msg)
-
-
 def extract_fused_functions(mod):
     """Pass to extract IRModule of only fused primitive functions.
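The removed Python helper dispatched on RelayExpr versus Type before calling separate FFI hooks; tvm.ir.structural_hash accepts either kind of node directly. A small sketch of the replacement (the tensor type and variable are hypothetical):

import tvm
from tvm import relay

# One entry point for both types and expressions; no isinstance dispatch needed.
ty = relay.TensorType((2, 3), "float32")
var = relay.var("v", ty)

type_hash = tvm.ir.structural_hash(ty)
expr_hash = tvm.ir.structural_hash(var)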
python/tvm/relay/frontend/tensorflow.py

@@ -27,7 +27,7 @@ import tvm
 from tvm.ir import IRModule
 from tvm.relay.prelude import Prelude
-from tvm.relay.analysis import structural_hash as s_hash
+from tvm.ir import structural_hash as s_hash

 from .. import analysis
 from .. import expr as _expr
python/tvm/relay/testing/py_converter.py

@@ -238,7 +238,7 @@ class PythonConverter(ExprFunctor):
         # compile the function and register globally
         cc_key = compile_engine.CCacheKey(op, self.tgt)
-        func_hash = relay.analysis.structural_hash(op)
+        func_hash = tvm.ir.structural_hash(op)
         op_name = '_lowered_op_{}'.format(func_hash)
         if not tvm.get_global_func(op_name, allow_missing=True):
             jitted = self.engine.jit(cc_key, self.tgt)
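The converter keeps deriving a deterministic global name from the hash of the lowered function; only the hash call changes. A sketch of the same naming scheme in isolation (the function below is a made-up stand-in for `op`):

import tvm
from tvm import relay

# Stand-in for the function being lowered.
x = relay.var("x", shape=(8,), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

# Same naming scheme as the converter, using the unified hasher.
op_name = '_lowered_op_{}'.format(tvm.ir.structural_hash(func))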
src/relay/analysis/extract_fused_functions.cc

@@ -21,6 +21,7 @@
  * \file extract_fused_functions.cc
  * \brief Apply fusion and extract fused primitive functions from an IRModule
  */
+#include <tvm/node/structural_hash.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
@@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor {
     if (n->HasNonzeroAttr(attr::kPrimitive)) {
       // Add function to functions, keyed by function hash string
       Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
-      size_t hash_ = StructuralHash()(func);
+      size_t hash_ = tvm::StructuralHash()(func);
       this->functions.Set(std::to_string(hash_), func);
     }
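The extractor keeps keying collected primitive functions by the decimal string of their hash; only the hasher is now the IR-wide tvm::StructuralHash. A rough Python analogue of that keying scheme (the function below is hypothetical, not from the diff):

import tvm
from tvm import relay

# Index functions by the string form of their structural hash, mirroring
# functions.Set(std::to_string(hash_), func) in the C++ extractor.
x = relay.var("x", shape=(4,), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

functions = {str(tvm.ir.structural_hash(func)): func}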
src/relay/backend/compile_engine.h

@@ -26,6 +26,7 @@
 #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

 #include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
 #include <tvm/tir/lowered_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/relay/analysis.h>
@@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty);
 inline size_t CCacheKeyNode::Hash() const {
   if (hash_ != 0) return hash_;
   // do structral hash, avoid 0.
-  hash_ = StructuralHash()(this->source_func);
+  hash_ = tvm::StructuralHash()(this->source_func);
   hash_ = dmlc::HashCombine(
       hash_, std::hash<std::string>()(target->str()));
   if (hash_ == 0) hash_ = 1;
src/relay/backend/vm/lambda_lift.cc

@@ -23,6 +23,7 @@
  */

 #include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/support/logging.h>
@@ -39,7 +40,7 @@ namespace relay {
 namespace vm {

 inline std::string GenerateName(const Function& func) {
-  size_t hash = StructuralHash()(func);
+  size_t hash = tvm::StructuralHash()(func);
   return std::string("lifted_name") + std::to_string(hash);
 }
src/relay/ir/hash.cc  deleted (100644 → 0)

The entire file is removed. Its previous contents:

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/ir/hash.cc
 * \brief Hash functions for Relay types and expressions.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/ir/attrs.h>
#include "../../ir/attr_functor.h"

namespace tvm {
namespace relay {

// Hash handler for Relay.
class RelayHashHandler :
    public AttrsHashHandler,
    public TypeFunctor<size_t(const Type&)>,
    public ExprFunctor<size_t(const Expr&)>,
    public PatternFunctor<size_t(const Pattern&)> {
 public:
  explicit RelayHashHandler() {}

  /*!
   * Compute hash of a node.
   * \param ref The node to hash.
   * \return the hash value.
   */
  size_t Hash(const ObjectRef& ref) {
    if (!ref.defined()) return ObjectHash()(ref);

    if (ref->IsInstance<TypeNode>()) {
      return TypeHash(Downcast<Type>(ref));
    }
    if (ref->IsInstance<ExprNode>()) {
      return ExprHash(Downcast<Expr>(ref));
    }
    return AttrHash(ref);
  }

  /*!
   * Compute hash of the attributes.
   * \param ref The attributes.
   * \return the hash value
   */
  size_t AttrHash(const ObjectRef& ref) {
    if (!ref.defined()) {
      return ObjectHash()(ref);
    }
    return AttrsHashHandler::Hash(ref);
  }

  /*!
   * Compute hash of a Relay type.
   * \param ref The type to hash.
   * \param rhs The right hand operand.
   * \return the hash value.
   */
  size_t TypeHash(const Type& type) {
    if (!type.defined()) {
      return ObjectHash()(type);
    }
    auto found = hash_map_.find(type);
    if (found != hash_map_.end()) {
      return found->second;
    } else {
      auto hash = this->VisitType(type);
      hash_map_.insert({type, hash});
      return hash;
    }
  }

  /*!
   * Compute the hash of an expression.
   *
   * \note We run graph structural equality checking when comparing two Exprs.
   *       This means that AlphaEqualHandler can only be used once for each pair.
   *       The equality checker checks data-flow equvalence of the Expr DAG.
   *       This function also runs faster as it memomizes equal_map.
   *
   * \param expr The expression to hash.
   * \return the hash value.
   */
  size_t ExprHash(const Expr& expr) {
    if (!expr.defined()) {
      return ObjectHash()(expr);
    }
    auto found = hash_map_.find(expr);
    if (found != hash_map_.end()) {
      return found->second;
    } else {
      auto hash = this->VisitExpr(expr);
      hash_map_.insert({expr, hash});
      return hash;
    }
  }

 protected:
  /*!
   * \brief Hash a DataType.
   * \param dtype The dtype to hash.
   * \return the hash value.
   */
  size_t DataTypeHash(const DataType& dtype) {
    return ::tvm::AttrsHash()(dtype);
  }

  using AttrsHashHandler::VisitAttr_;
  size_t VisitAttr_(const tvm::tir::VarNode* var) final {
    size_t hash = std::hash<std::string>()(VarNode::_type_key);
    auto it = hash_map_.find(GetRef<tvm::tir::Var>(var));
    if (it != hash_map_.end()) {
      return it->second;
    }
    return Combine(hash, std::hash<std::string>()(var->name_hint));
  }

  // Type hashing
  size_t VisitType_(const TensorTypeNode* tensor_type) final {
    size_t hash = std::hash<std::string>()(TensorTypeNode::_type_key);
    hash = Combine(hash, DataTypeHash(tensor_type->dtype));
    hash = Combine(hash, Hash(tensor_type->shape));
    return hash;
  }

  size_t VisitType_(const IncompleteTypeNode* incomplete) final {
    size_t hash = std::hash<std::string>()(IncompleteTypeNode::_type_key);
    return Combine(hash, std::hash<int>()(incomplete->kind));
  }

  size_t VisitType_(const TypeVarNode* tyvar) final {
    /*
      TypeVar/Var/Variable have two locations where they are hashed:

        The declaration site of a function, let, or function type.
        The first occurence in the term.

      We will only reach this code if the TypeVar itself is unbound, we assign
      a free variable index to it, meaning this hashing function implements
      structural equality for both open (i.e graph equality) and closed terms
      (i.e alpha_equality).
    */
    return BindVar(GetRef<TypeVar>(tyvar));
  }

  size_t VisitType_(const FuncTypeNode* func_type) final {
    size_t hash = std::hash<std::string>()(FuncTypeNode::_type_key);

    for (auto type_param : func_type->type_params) {
      hash = Combine(hash, BindVar(type_param));
    }

    for (auto arg : func_type->arg_types) {
      hash = Combine(hash, TypeHash(arg));
    }

    hash = Combine(hash, TypeHash(func_type->ret_type));
    for (auto cs : func_type->type_constraints) {
      hash = Combine(hash, TypeHash(cs));
    }

    return hash;
  }

  size_t VisitType_(const TypeRelationNode* type_rel) final {
    size_t hash = std::hash<std::string>()(TypeRelationNode::_type_key);
    hash = Combine(hash, std::hash<std::string>()(type_rel->func->name));
    hash = Combine(hash, AttrHash(type_rel->attrs));

    for (auto arg : type_rel->args) {
      hash = Combine(hash, TypeHash(arg));
    }

    return hash;
  }

  size_t VisitType_(const TupleTypeNode* tuple_type) final {
    size_t hash = std::hash<std::string>()(TupleTypeNode::_type_key);
    for (size_t i = 0; i < tuple_type->fields.size(); i++) {
      hash = Combine(hash, TypeHash(tuple_type->fields[i]));
    }
    return hash;
  }

  size_t VisitType_(const RelayRefTypeNode* rtn) final {
    size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
    hash = Combine(hash, TypeHash(rtn->value));
    return hash;
  }

  // Expr hashing.
  size_t NDArrayHash(const runtime::NDArray& array) {
    size_t hash = std::hash<uint8_t>()(array->dtype.code);
    hash = Combine(hash, std::hash<uint8_t>()(array->dtype.bits));
    hash = Combine(hash, std::hash<uint16_t>()(array->dtype.lanes));
    CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
    size_t data_size = runtime::GetDataSize(*array.operator->());
    uint8_t* data = reinterpret_cast<uint8_t*>(array->data);
    for (size_t i = 0; i < data_size; i++) {
      hash = Combine(hash, std::hash<uint8_t>()(data[i]));
    }
    return hash;
  }

  size_t BindVar(const ObjectRef& var) {
    size_t hash = std::hash<int>()(var_counter++);
    CHECK_EQ(hash_map_.count(var), 0);
    if (auto var_node = var.as<VarNode>()) {
      hash = Combine(hash, TypeHash(var_node->type_annotation));
    }
    hash_map_[var] = hash;

    return hash;
  }

  size_t VisitExpr_(const VarNode* var) final {
    // hash free variable
    size_t name_hash = std::hash<const Object*>()(var->vid.get());
    return Combine(name_hash, TypeHash(var->type_annotation));
  }

  size_t VisitExpr_(const GlobalVarNode* global) final {
    return std::hash<std::string>()(global->name_hint);
  }

  size_t VisitExpr_(const TupleNode* tuple) final {
    size_t hash = std::hash<std::string>()(TupleNode::_type_key);
    for (size_t i = 0; i < tuple->fields.size(); i++) {
      hash = Combine(hash, ExprHash(tuple->fields[i]));
    }
    return hash;
  }

  size_t VisitExpr_(const FunctionNode* func) final {
    size_t hash = std::hash<std::string>()(FunctionNode::_type_key);
    for (auto type_param : func->type_params) {
      hash = Combine(hash, BindVar(type_param));
    }

    for (auto param : func->params) {
      hash = Combine(hash, BindVar(param));
    }

    hash = Combine(hash, TypeHash(func->ret_type));
    hash = Combine(hash, ExprHash(func->body));
    hash = Combine(hash, AttrHash(func->attrs));

    return hash;
  }

  size_t VisitExpr_(const CallNode* call) final {
    size_t hash = std::hash<std::string>()(CallNode::_type_key);
    hash = Combine(hash, ExprHash(call->op));

    for (auto arg : call->args) {
      hash = Combine(hash, ExprHash(arg));
    }

    for (auto t : call->type_args) {
      CHECK(t.defined());
      hash = Combine(hash, TypeHash(t));
    }

    hash = Combine(hash, AttrHash(call->attrs));

    return hash;
  }

  size_t VisitExpr_(const LetNode* let) final {
    size_t hash = std::hash<std::string>()(LetNode::_type_key);
    hash = Combine(hash, BindVar(let->var));
    hash = Combine(hash, ExprHash(let->value));
    hash = Combine(hash, ExprHash(let->body));
    return hash;
  }

  size_t VisitExpr_(const IfNode* ite) final {
    size_t key = std::hash<std::string>()(IfNode::_type_key);
    size_t hash = key;
    hash = Combine(hash, ExprHash(ite->cond));
    hash = Combine(hash, ExprHash(ite->true_branch));
    hash = Combine(hash, ExprHash(ite->false_branch));
    return hash;
  }

  size_t VisitExpr_(const OpNode* op) final {
    return ObjectHash()(GetRef<Op>(op));
  }

  size_t VisitExpr_(const ConstantNode* rconst) final {
    return NDArrayHash(rconst->data);
  }

  size_t VisitExpr_(const TupleGetItemNode* get_item) final {
    size_t hash = std::hash<std::string>()(TupleGetItemNode::_type_key);
    hash = Combine(hash, ExprHash(get_item->tuple));
    hash = Combine(hash, std::hash<int>()(get_item->index));
    return hash;
  }

  size_t VisitExpr_(const RefCreateNode* rn) final {
    size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
    hash = Combine(hash, ExprHash(rn->value));
    return hash;
  }

  size_t VisitExpr_(const RefReadNode* rn) final {
    size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
    hash = Combine(hash, ExprHash(rn->ref));
    return hash;
  }

  size_t VisitExpr_(const RefWriteNode* rn) final {
    size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
    hash = Combine(hash, ExprHash(rn->ref));
    hash = Combine(hash, ExprHash(rn->value));
    return hash;
  }

  size_t VisitExpr_(const MatchNode* mn) final {
    size_t hash = std::hash<std::string>()(MatchNode::_type_key);
    hash = Combine(hash, ExprHash(mn->data));
    for (const auto& c : mn->clauses) {
      hash = Combine(hash, PatternHash(c->lhs));
      hash = Combine(hash, ExprHash(c->rhs));
    }
    hash = Combine(hash, std::hash<bool>()(mn->complete));
    return hash;
  }

  size_t VisitExpr_(const ConstructorNode* cn) final {
    size_t hash = std::hash<std::string>()(ConstructorNode::_type_key);
    hash = Combine(hash, std::hash<std::string>()(cn->name_hint));
    return hash;
  }

  size_t VisitType_(const TypeCallNode* tcn) final {
    size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
    hash = Combine(hash, TypeHash(tcn->func));
    for (const auto& t : tcn->args) {
      hash = Combine(hash, TypeHash(t));
    }
    return hash;
  }

  size_t VisitType_(const TypeDataNode* tdn) final {
    size_t hash = std::hash<std::string>()(TypeDataNode::_type_key);
    hash = Combine(hash, TypeHash(tdn->header));
    for (const auto& tv : tdn->type_vars) {
      hash = Combine(hash, TypeHash(tv));
    }

    for (const auto& cn : tdn->constructors) {
      hash = Combine(hash, ExprHash(cn));
    }
    return hash;
  }

  size_t VisitType_(const GlobalTypeVarNode* tvn) final {
    return BindVar(GetRef<GlobalTypeVar>(tvn));
  }

  size_t PatternHash(const Pattern& p) {
    return VisitPattern(p);
  }

  size_t VisitPattern_(const PatternConstructorNode* pcn) final {
    size_t hash = std::hash<std::string>()(PatternConstructorNode::_type_key);
    hash = Combine(hash, ExprHash(pcn->constructor));
    for (const auto& p : pcn->patterns) {
      hash = Combine(hash, PatternHash(p));
    }
    return hash;
  }

  size_t VisitPattern_(const PatternTupleNode* ptn) final {
    size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
    for (const auto& p : ptn->patterns) {
      hash = Combine(hash, PatternHash(p));
    }
    return hash;
  }

  size_t VisitPattern_(const PatternVarNode* pvn) final {
    size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
    hash = Combine(hash, BindVar(pvn->var));
    return hash;
  }

  size_t VisitPattern_(const PatternWildcardNode* pwn) final {
    size_t hash = std::hash<std::string>()(PatternWildcardNode::_type_key);
    return hash;
  }

 private:
  // renaming of NodeRef to indicate two nodes equals to each other
  std::unordered_map<ObjectRef, size_t, ObjectHash, ObjectEqual> hash_map_;
  int var_counter = 0;
};

size_t StructuralHash::operator()(const Type& type) const {
  return RelayHashHandler().TypeHash(type);
}

size_t StructuralHash::operator()(const Expr& expr) const {
  return RelayHashHandler().ExprHash(expr);
}

TVM_REGISTER_GLOBAL("relay.analysis._expr_hash")
.set_body_typed([](ObjectRef ref) {
  return static_cast<int64_t>(RelayHashHandler().Hash(ref));
});

TVM_REGISTER_GLOBAL("relay.analysis._type_hash")
.set_body_typed([](Type type) {
  return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
});

}  // namespace relay
}  // namespace tvm
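With RelayHashHandler gone, the bound-variable bookkeeping it did (the BindVar counter above) is expected to be covered by the generic structural hasher. A quick sanity sketch of that property, assuming the default handling of bound variables in tvm.ir.structural_hash:

import tvm
from tvm import relay

# Two functions differing only in bound-variable names should hash alike.
a = relay.var("a", shape=(2,), dtype="float32")
b = relay.var("b", shape=(2,), dtype="float32")
f1 = relay.Function([a], relay.add(a, a))
f2 = relay.Function([b], relay.add(b, b))

assert tvm.ir.structural_hash(f1) == tvm.ir.structural_hash(f2)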
tests/python/relay/test_pass_qnn_legalize.py

@@ -31,7 +31,8 @@ def alpha_equal(x, y):
     """
     x = x['main']
     y = y['main']
-    return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
+    return tvm.ir.structural_equal(x, y) and \
+        tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)


 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
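The updated test helper now checks structural equality and hash agreement entirely through tvm.ir. A standalone version of that helper, assuming both arguments are IRModules that define a 'main' function:

import tvm

def alpha_equal(x, y):
    """Compare the 'main' functions of two IRModules, as in the updated test."""
    x = x['main']
    y = y['main']
    return tvm.ir.structural_equal(x, y) and \
        tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)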