Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6536b356
Unverified
Commit
6536b356
authored
Mar 29, 2020
by
Zhi
Committed by
GitHub
Mar 29, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove AttrsEqual and AttrsHash related code (#5169)
parent
a2edd01b
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
29 additions
and
573 deletions
+29
-573
include/tvm/ir/attrs.h
+3
-168
src/ir/attr_functor.h
+0
-89
src/ir/attrs.cc
+0
-278
src/node/structural_equal.cc
+1
-0
src/relay/transforms/combine_parallel_conv2d.cc
+2
-2
src/relay/transforms/combine_parallel_dense.cc
+1
-1
src/relay/transforms/combine_parallel_op.cc
+2
-1
src/relay/transforms/combine_parallel_op_batch.cc
+2
-2
src/relay/transforms/eliminate_common_subexpr.cc
+1
-1
src/relay/transforms/fold_scale_axis.cc
+2
-2
src/relay/transforms/fuse_ops.cc
+1
-1
src/relay/transforms/pattern_util.h
+1
-1
src/tir/pass/ffi_api.cc
+0
-12
tests/python/relay/test_ir_nodes.py
+0
-1
tests/python/unittest/test_ir_attrs.py
+4
-5
tests/python/unittest/test_tir_pass_attrs_hash_equal.py
+9
-9
No files found.
include/tvm/ir/attrs.h
View file @
6536b356
...
...
@@ -46,6 +46,8 @@
#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
...
...
@@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS
(
AttrFieldInfo
,
ObjectRef
,
AttrFieldInfoNode
);
};
class
AttrsHashHandler
;
class
AttrsEqualHandler
;
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class
AttrsEqual
{
public
:
bool
operator
()(
const
double
&
lhs
,
const
double
&
rhs
)
const
{
// fuzzy float pt comparison
constexpr
double
atol
=
1e-9
;
if
(
lhs
==
rhs
)
return
true
;
double
diff
=
lhs
-
rhs
;
return
diff
>
-
atol
&&
diff
<
atol
;
}
bool
operator
()(
const
int64_t
&
lhs
,
const
int64_t
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
uint64_t
&
lhs
,
const
uint64_t
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
int
&
lhs
,
const
int
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
bool
&
lhs
,
const
bool
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
std
::
string
&
lhs
,
const
std
::
string
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
DataType
&
lhs
,
const
DataType
&
rhs
)
const
{
return
lhs
==
rhs
;
}
// node comparator
TVM_DLL
bool
operator
()(
const
ObjectRef
&
lhs
,
const
ObjectRef
&
rhs
)
const
;
protected
:
friend
class
AttrsEqualHandler
;
/*! \brief internal handle. */
AttrsEqualHandler
*
handler_
{
nullptr
};
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class
AttrsHash
{
public
:
size_t
operator
()(
const
double
&
value
)
const
{
return
std
::
hash
<
double
>
()(
value
);
}
size_t
operator
()(
const
int64_t
&
value
)
const
{
return
std
::
hash
<
int64_t
>
()(
value
);
}
size_t
operator
()(
const
uint64_t
&
value
)
const
{
return
std
::
hash
<
uint64_t
>
()(
value
);
}
size_t
operator
()(
const
int
&
value
)
const
{
return
std
::
hash
<
int
>
()(
value
);
}
size_t
operator
()(
const
bool
&
value
)
const
{
return
std
::
hash
<
bool
>
()(
value
);
}
size_t
operator
()(
const
std
::
string
&
value
)
const
{
return
std
::
hash
<
std
::
string
>
()(
value
);
}
size_t
operator
()(
const
DataType
&
value
)
const
{
return
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
value
.
code
())
|
(
static_cast
<
int
>
(
value
.
bits
())
<<
8
)
|
(
static_cast
<
int
>
(
value
.
lanes
())
<<
16
));
}
TVM_DLL
size_t
operator
()(
const
ObjectRef
&
value
)
const
;
private
:
friend
class
AttrsHashHandler
;
/*! \brief internal handle. */
AttrsHashHandler
*
handler_
{
nullptr
};
};
/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
...
...
@@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
* \note This function throws when the required field is not present.
*/
TVM_DLL
virtual
void
InitByPackedArgs
(
const
TVMArgs
&
kwargs
,
bool
allow_unknown
=
false
)
=
0
;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \param equal The equal comparator
* \return The comparison result.
*/
TVM_DLL
virtual
bool
ContentEqual
(
const
Object
*
other
,
AttrsEqual
equal
)
const
=
0
;
/*!
* \brief Content aware hash.
* \param hasher The hasher to run the hash.
* \return the hash result.
*/
TVM_DLL
virtual
size_t
ContentHash
(
AttrsHash
hasher
)
const
=
0
;
static
constexpr
const
bool
_type_has_method_sequal_reduce
=
true
;
static
constexpr
const
bool
_type_has_method_shash_reduce
=
true
;
...
...
@@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
void
VisitNonDefaultAttrs
(
AttrVisitor
*
v
)
final
;
void
InitByPackedArgs
(
const
runtime
::
TVMArgs
&
args
,
bool
allow_unknown
)
final
;
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
final
;
bool
ContentEqual
(
const
Object
*
other
,
AttrsEqual
equal
)
const
final
;
size_t
ContentHash
(
AttrsHash
hasher
)
const
final
;
// type info
static
constexpr
const
char
*
_type_key
=
"DictAttrs"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
DictAttrsNode
,
BaseAttrsNode
);
...
...
@@ -386,34 +283,6 @@ class AttrNormalVisitor {
AttrVisitor
*
visitor_
;
};
// Wrapper for normal visitor.
class
AttrsEqualVisitor
{
public
:
bool
result_
{
true
};
// constructor
AttrsEqualVisitor
(
const
Object
*
lhs
,
const
Object
*
rhs
,
const
AttrsEqual
&
equal
)
:
lhs_
(
lhs
),
rhs_
(
rhs
),
equal_
(
equal
)
{
}
template
<
typename
T
>
AttrNopEntry
operator
()(
const
char
*
key
,
T
*
lhs_value
)
{
if
(
!
result_
)
return
AttrNopEntry
();
const
T
*
rhs_value
=
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
const
char
*>
(
rhs_
)
+
(
reinterpret_cast
<
const
char
*>
(
lhs_value
)
-
reinterpret_cast
<
const
char
*>
(
lhs_
)));
if
(
!
equal_
(
*
lhs_value
,
*
rhs_value
))
{
result_
=
false
;
}
return
AttrNopEntry
();
}
private
:
const
Object
*
lhs_
;
const
Object
*
rhs_
;
const
AttrsEqual
&
equal_
;
};
class
AttrsSEqualVisitor
{
public
:
bool
result_
{
true
};
...
...
@@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
const
SEqualReducer
&
equal_
;
};
class
AttrsHashVisitor
{
public
:
explicit
AttrsHashVisitor
(
const
AttrsHash
&
hasher
)
:
hasher_
(
hasher
)
{}
size_t
result_
{
0
};
template
<
typename
T
>
AttrNopEntry
operator
()(
const
char
*
key
,
T
*
value
)
{
result_
=
dmlc
::
HashCombine
(
result_
,
hasher_
(
*
value
));
return
AttrNopEntry
();
}
private
:
const
AttrsHash
&
hasher_
;
};
class
AttrsSHashVisitor
{
public
:
explicit
AttrsSHashVisitor
(
const
SHashReducer
&
hash_reducer
)
...
...
@@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
return
*
this
;
}
TSelf
&
set_default
(
const
T
&
value
)
{
if
(
Attrs
Equal
()(
value
,
*
data_
))
{
if
(
tvm
::
Structural
Equal
()(
value
,
*
data_
))
{
trigger_
=
false
;
}
return
*
this
;
...
...
@@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
return
visitor
.
fields_
;
}
bool
ContentEqual
(
const
Object
*
other
,
AttrsEqual
equal
)
const
final
{
DerivedType
*
pself
=
self
();
if
(
pself
==
other
)
return
true
;
if
(
other
==
nullptr
)
return
false
;
if
(
pself
->
type_index
()
!=
other
->
type_index
())
return
false
;
::
tvm
::
detail
::
AttrsEqualVisitor
visitor
(
pself
,
other
,
equal
);
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
result_
;
}
size_t
ContentHash
(
AttrsHash
hasher
)
const
final
{
::
tvm
::
detail
::
AttrsHashVisitor
visitor
(
hasher
);
visitor
.
result_
=
this
->
GetTypeKeyHash
();
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
result_
;
}
private
:
DerivedType
*
self
()
const
{
return
const_cast
<
DerivedType
*>
(
...
...
src/ir/attr_functor.h
View file @
6536b356
...
...
@@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
}
};
class
AttrsEqualHandler
:
protected
AttrFunctor
<
bool
(
const
ObjectRef
&
,
const
ObjectRef
&
)
>
{
public
:
/*!
* \brief Check if lhs equals rhs
* \param lhs The left operand.
* \param rhs The right operand.
*/
bool
Equal
(
const
ObjectRef
&
lhs
,
const
ObjectRef
&
rhs
);
protected
:
bool
VisitAttrDefault_
(
const
Object
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
ArrayNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
StrMapNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
IntImmNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
FloatImmNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
StringImmNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
AddNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
SubNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
MulNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
DivNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
ModNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
FloorDivNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
FloorModNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
MinNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
MaxNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
GENode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
GTNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
LTNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
LENode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
EQNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
NENode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
AndNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
OrNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
NotNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
CastNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
CallNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
bool
VisitAttr_
(
const
tir
::
SelectNode
*
lhs
,
const
ObjectRef
&
other
)
final
;
};
class
AttrsHashHandler
:
protected
AttrFunctor
<
size_t
(
const
ObjectRef
&
)
>
{
public
:
/*!
* \brief Get hash value of node
* \param node The node to be hashed.
*/
size_t
Hash
(
const
ObjectRef
&
node
)
{
if
(
!
node
.
defined
())
return
0
;
return
this
->
VisitAttr
(
node
);
}
protected
:
size_t
VisitAttrDefault_
(
const
Object
*
lhs
)
final
;
size_t
VisitAttr_
(
const
tir
::
IntImmNode
*
lhs
)
final
;
size_t
VisitAttr_
(
const
tir
::
FloatImmNode
*
lhs
)
final
;
size_t
VisitAttr_
(
const
tir
::
StringImmNode
*
lhs
)
final
;
size_t
VisitAttr_
(
const
ArrayNode
*
lhs
)
final
;
size_t
VisitAttr_
(
const
StrMapNode
*
lhs
)
final
;
size_t
VisitAttr_
(
const
tir
::
AddNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
SubNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
MulNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
DivNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
ModNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
FloorDivNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
FloorModNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
MinNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
MaxNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
GENode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
GTNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
LENode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
LTNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
EQNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
NENode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
AndNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
OrNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
NotNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
CastNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
CallNode
*
op
)
final
;
size_t
VisitAttr_
(
const
tir
::
SelectNode
*
op
)
final
;
/*!
* \brief alias of dmlc::HashCombine
* \param lhs The first hash value.
* \param rhs The second hash value.
*/
static
size_t
Combine
(
size_t
lhs
,
size_t
rhs
)
{
return
dmlc
::
HashCombine
(
lhs
,
rhs
);
}
};
}
// namespace tvm
#endif // TVM_IR_ATTR_FUNCTOR_H_
src/ir/attrs.cc
View file @
6536b356
...
...
@@ -74,287 +74,9 @@ TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
return
attrs
->
dict
;
});
using
namespace
tir
;
// Equal handler.
bool
AttrsEqualHandler
::
Equal
(
const
ObjectRef
&
lhs
,
const
ObjectRef
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
!
lhs
.
defined
()
&&
rhs
.
defined
())
return
false
;
if
(
!
rhs
.
defined
()
&&
lhs
.
defined
())
return
false
;
return
this
->
VisitAttr
(
lhs
,
rhs
);
}
bool
AttrsEqualHandler
::
VisitAttrDefault_
(
const
Object
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
lhs
->
IsInstance
<
BaseAttrsNode
>
())
{
AttrsEqual
equal
;
equal
.
handler_
=
this
;
return
static_cast
<
const
BaseAttrsNode
*>
(
lhs
)
->
ContentEqual
(
other
.
get
(),
equal
);
}
return
lhs
==
other
.
get
();
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
IntImmNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
IntImmNode
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
FloatImmNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
FloatImmNode
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
StringImmNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
StringImmNode
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
ArrayNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
ArrayNode
>
())
{
if
(
rhs
->
data
.
size
()
!=
lhs
->
data
.
size
())
return
false
;
for
(
size_t
i
=
0
;
i
<
lhs
->
data
.
size
();
++
i
)
{
if
(
!
Equal
(
lhs
->
data
[
i
],
rhs
->
data
[
i
]))
return
false
;
}
return
true
;
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
StrMapNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
StrMapNode
>
())
{
if
(
rhs
->
data
.
size
()
!=
lhs
->
data
.
size
())
return
false
;
for
(
const
auto
&
kv
:
lhs
->
data
)
{
auto
it
=
rhs
->
data
.
find
(
kv
.
first
);
if
(
it
==
rhs
->
data
.
end
())
return
false
;
if
(
!
Equal
(
kv
.
second
,
it
->
second
))
return
false
;
}
return
true
;
}
else
{
return
false
;
}
}
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
if (const auto* rhs = other.as<NodeName>()) { \
if (!Equal(lhs->a, rhs->a)) return false; \
if (!Equal(lhs->b, rhs->b)) return false; \
return true; \
} else { \
return false; \
} \
} \
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
AddNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
SubNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
MulNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
DivNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
ModNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
FloorDivNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
FloorModNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
MaxNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
MinNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
GENode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
GTNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
LENode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
LTNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
EQNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
NENode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
AndNode
);
TVM_DEFINE_ATTRS_BINOP_EQUAL
(
OrNode
);
bool
AttrsEqualHandler
::
VisitAttr_
(
const
NotNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
NotNode
>
())
{
return
Equal
(
lhs
->
a
,
rhs
->
a
);
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
CastNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
CastNode
>
())
{
if
(
lhs
->
dtype
!=
rhs
->
dtype
)
return
false
;
return
Equal
(
lhs
->
value
,
rhs
->
value
);
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
CallNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
CallNode
>
())
{
return
lhs
->
name
==
rhs
->
name
&&
lhs
->
dtype
==
rhs
->
dtype
&&
lhs
->
call_type
==
rhs
->
call_type
&&
Equal
(
lhs
->
args
,
rhs
->
args
);
}
else
{
return
false
;
}
}
bool
AttrsEqualHandler
::
VisitAttr_
(
const
SelectNode
*
lhs
,
const
ObjectRef
&
other
)
{
if
(
const
auto
*
rhs
=
other
.
as
<
SelectNode
>
())
{
return
Equal
(
lhs
->
condition
,
rhs
->
condition
)
&&
Equal
(
lhs
->
true_value
,
rhs
->
true_value
)
&&
Equal
(
lhs
->
false_value
,
rhs
->
false_value
);
}
else
{
return
false
;
}
}
// Hash Handler.
size_t
AttrsHashHandler
::
VisitAttrDefault_
(
const
Object
*
value
)
{
if
(
value
->
IsInstance
<
BaseAttrsNode
>
())
{
AttrsHash
hasher
;
hasher
.
handler_
=
this
;
return
static_cast
<
const
BaseAttrsNode
*>
(
value
)
->
ContentHash
(
hasher
);
}
else
{
return
ObjectHash
()(
GetRef
<
ObjectRef
>
(
value
));
}
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
IntImmNode
*
op
)
{
return
std
::
hash
<
int64_t
>
()(
op
->
value
);
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
FloatImmNode
*
op
)
{
return
std
::
hash
<
double
>
()(
op
->
value
);
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
StringImmNode
*
op
)
{
return
std
::
hash
<
std
::
string
>
()(
op
->
value
);
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
ArrayNode
*
op
)
{
size_t
result
=
op
->
data
.
size
();
for
(
size_t
i
=
0
;
i
<
op
->
data
.
size
();
++
i
)
{
result
=
Combine
(
result
,
this
->
Hash
(
op
->
data
[
i
]));
}
return
result
;
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
StrMapNode
*
lhs
)
{
using
Entry
=
std
::
pair
<
std
::
string
,
ObjectRef
>
;
std
::
vector
<
Entry
>
data
(
lhs
->
data
.
begin
(),
lhs
->
data
.
end
());
std
::
sort
(
data
.
begin
(),
data
.
end
(),
[](
const
Entry
&
a
,
const
Entry
&
b
)
{
return
a
.
first
<
b
.
first
;
});
size_t
result
=
0
;
for
(
const
Entry
&
kv
:
data
)
{
result
=
Combine
(
result
,
std
::
hash
<
std
::
string
>
()(
kv
.
first
));
result
=
Combine
(
result
,
this
->
Hash
(
kv
.
second
));
}
return
result
;
}
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
static size_t key = std::hash<std::string>()(NodeName::_type_key); \
return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
} \
TVM_DEFINE_ATTRS_BINOP_HASH
(
AddNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
SubNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
MulNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
DivNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
ModNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
FloorDivNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
FloorModNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
MaxNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
MinNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
GENode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
GTNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
LENode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
LTNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
EQNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
NENode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
AndNode
);
TVM_DEFINE_ATTRS_BINOP_HASH
(
OrNode
);
size_t
AttrsHashHandler
::
VisitAttr_
(
const
NotNode
*
op
)
{
static
size_t
key
=
std
::
hash
<
std
::
string
>
()(
NotNode
::
_type_key
);
return
Combine
(
key
,
Hash
(
op
->
a
));
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
CastNode
*
op
)
{
static
size_t
key
=
std
::
hash
<
std
::
string
>
()(
CastNode
::
_type_key
);
AttrsHash
hasher
;
size_t
res
=
key
;
res
=
Combine
(
res
,
hasher
(
op
->
dtype
));
res
=
Combine
(
res
,
Hash
(
op
->
value
));
return
res
;
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
CallNode
*
op
)
{
static
size_t
key
=
std
::
hash
<
std
::
string
>
()(
CallNode
::
_type_key
);
AttrsHash
hasher
;
size_t
res
=
key
;
res
=
Combine
(
res
,
hasher
(
op
->
name
));
res
=
Combine
(
res
,
hasher
(
op
->
dtype
));
res
=
Combine
(
res
,
Hash
(
op
->
args
));
return
res
;
}
size_t
AttrsHashHandler
::
VisitAttr_
(
const
SelectNode
*
op
)
{
static
size_t
key
=
std
::
hash
<
std
::
string
>
()(
SelectNode
::
_type_key
);
size_t
res
=
key
;
res
=
Combine
(
res
,
Hash
(
op
->
condition
));
res
=
Combine
(
res
,
Hash
(
op
->
true_value
));
res
=
Combine
(
res
,
Hash
(
op
->
false_value
));
return
res
;
}
// Default case
bool
AttrsEqual
::
operator
()(
const
ObjectRef
&
lhs
,
const
ObjectRef
&
rhs
)
const
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
handler_
==
nullptr
)
{
return
AttrsEqualHandler
().
Equal
(
lhs
,
rhs
);
}
else
{
return
handler_
->
Equal
(
lhs
,
rhs
);
}
}
size_t
AttrsHash
::
operator
()(
const
ObjectRef
&
node
)
const
{
if
(
!
node
.
defined
())
return
0
;
if
(
handler_
==
nullptr
)
{
return
AttrsHashHandler
().
Hash
(
node
);
}
else
{
return
handler_
->
Hash
(
node
);
}
}
size_t
DictAttrsNode
::
ContentHash
(
AttrsHash
hasher
)
const
{
return
hasher
(
this
->
dict
);
}
bool
DictAttrsNode
::
ContentEqual
(
const
Object
*
other
,
AttrsEqual
equal
)
const
{
if
(
this
==
other
)
return
true
;
if
(
other
==
nullptr
)
return
false
;
if
(
this
->
type_index
()
!=
other
->
type_index
())
return
false
;
return
equal
(
this
->
dict
,
static_cast
<
const
DictAttrsNode
*>
(
other
)
->
dict
);
}
TVM_REGISTER_GLOBAL
(
"ir.AttrsListFieldInfo"
)
.
set_body_typed
([](
Attrs
attrs
)
{
return
attrs
->
ListFieldInfo
();
});
TVM_REGISTER_GLOBAL
(
"ir.AttrsEqual"
)
.
set_body_typed
([](
ObjectRef
lhs
,
ObjectRef
rhs
)
{
return
AttrsEqual
()(
lhs
,
rhs
);
});
}
// namespace tvm
src/node/structural_equal.cc
View file @
6536b356
...
...
@@ -103,6 +103,7 @@ class RemapVarSEqualHandler :
// Function that implements actual equality check.
bool
Equal
(
const
ObjectRef
&
lhs
,
const
ObjectRef
&
rhs
,
bool
map_free_vars
)
{
if
(
!
lhs
.
defined
()
&&
!
rhs
.
defined
())
return
true
;
task_stack_
.
clear
();
pending_tasks_
.
clear
();
equal_map_lhs_
.
clear
();
...
...
src/relay/transforms/combine_parallel_conv2d.cc
View file @
6536b356
...
...
@@ -59,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool
CanOpsBeCombined
(
const
CallNode
*
a
,
const
CallNode
*
b
)
{
Attrs
Equal
eq
;
Structural
Equal
eq
;
const
Layout
kOIHW
(
"OIHW"
);
const
auto
*
attrs_a
=
a
->
attrs
.
as
<
Conv2DAttrs
>
();
const
auto
*
attrs_b
=
b
->
attrs
.
as
<
Conv2DAttrs
>
();
...
...
@@ -112,7 +112,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool
IsArgCompatible
(
const
CallNode
*
a
,
const
CallNode
*
b
,
size_t
index
)
{
Attrs
Equal
eq
;
Structural
Equal
eq
;
auto
ta
=
a
->
args
[
index
]
->
type_as
<
TensorTypeNode
>
();
auto
tb
=
b
->
args
[
index
]
->
type_as
<
TensorTypeNode
>
();
auto
toutput_a
=
a
->
type_as
<
TensorTypeNode
>
();
...
...
src/relay/transforms/combine_parallel_dense.cc
View file @
6536b356
...
...
@@ -54,7 +54,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
protected
:
virtual
bool
CanOpsBeCombined
(
const
CallNode
*
a
,
const
CallNode
*
b
)
{
Attrs
Equal
eq
;
Structural
Equal
eq
;
const
auto
*
attrs_a
=
a
->
attrs
.
as
<
DenseAttrs
>
();
const
auto
*
attrs_b
=
b
->
attrs
.
as
<
DenseAttrs
>
();
CHECK
(
attrs_a
);
...
...
src/relay/transforms/combine_parallel_op.cc
View file @
6536b356
...
...
@@ -23,6 +23,7 @@
* \brief Abstract class to combine parallel ops and their successive element-wise ops.
*/
#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
...
...
@@ -155,7 +156,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) {
bool
ParallelOpCombiner
::
CheckLevel
(
const
Group
&
branches
,
size_t
depth
,
size_t
parent_index
)
{
const
CallNode
*
call
=
branches
[
0
][
depth
];
Attrs
Equal
attrs_equal
;
tvm
::
Structural
Equal
attrs_equal
;
// check if all branches in current depth can be combined
for
(
auto
it
=
branches
.
begin
()
+
1
;
it
!=
branches
.
end
();
it
++
)
{
const
Branch
&
branch
=
*
it
;
...
...
src/relay/transforms/combine_parallel_op_batch.cc
View file @
6536b356
...
...
@@ -76,7 +76,7 @@ bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode
return
false
;
}
Attrs
Equal
eq
;
Structural
Equal
eq
;
for
(
size_t
i
=
0
;
i
<
a
->
args
.
size
();
i
++
)
{
auto
ta
=
a
->
args
[
i
]
->
type_as
<
TensorTypeNode
>
();
auto
tb
=
b
->
args
[
i
]
->
type_as
<
TensorTypeNode
>
();
...
...
@@ -112,7 +112,7 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
}
bool
ParallelOpBatchCombiner
::
IsArgCompatible
(
const
CallNode
*
a
,
const
CallNode
*
b
,
size_t
index
)
{
Attrs
Equal
eq
;
Structural
Equal
eq
;
auto
ta
=
a
->
args
[
index
]
->
type_as
<
TensorTypeNode
>
();
auto
tb
=
b
->
args
[
index
]
->
type_as
<
TensorTypeNode
>
();
...
...
src/relay/transforms/eliminate_common_subexpr.cc
View file @
6536b356
...
...
@@ -45,7 +45,7 @@ class CommonSubexprEliminator : public ExprMutator {
const
CallNode
*
new_call
=
new_expr
.
as
<
CallNode
>
();
CHECK
(
new_call
);
const
OpNode
*
op
=
new_call
->
op
.
as
<
OpNode
>
();
Attrs
Equal
attrs_equal
;
Structural
Equal
attrs_equal
;
if
(
new_call
->
args
.
size
()
==
0
||
op
==
nullptr
||
op_stateful
.
get
(
GetRef
<
Op
>
(
op
),
false
))
{
return
new_expr
;
...
...
src/relay/transforms/fold_scale_axis.cc
View file @
6536b356
...
...
@@ -765,7 +765,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
Message
AddSubBackwardPrep
(
const
Call
&
call
,
const
Array
<
Message
>&
in_messages
)
{
const
auto
*
tlhs
=
call
->
args
[
0
]
->
type_as
<
TensorTypeNode
>
();
const
auto
*
trhs
=
call
->
args
[
1
]
->
type_as
<
TensorTypeNode
>
();
Attrs
Equal
equal
;
Structural
Equal
equal
;
if
(
in_messages
[
0
].
defined
()
&&
MatchBroadcastToLeftAxes
(
tlhs
,
trhs
,
in_messages
[
0
]
->
axes
))
{
return
in_messages
[
0
];
...
...
@@ -795,7 +795,7 @@ Expr AddSubBackwardTransform(const Call& call,
}
Message
lhs_message
=
transformer
->
GetMessage
(
call
->
args
[
0
]);
Message
rhs_message
=
transformer
->
GetMessage
(
call
->
args
[
1
]);
Attrs
Equal
equal
;
Structural
Equal
equal
;
if
(
lhs_message
.
defined
()
&&
rhs_message
.
defined
())
{
CHECK
(
equal
(
lhs_message
->
axes
,
rhs_message
->
axes
));
...
...
src/relay/transforms/fuse_ops.cc
View file @
6536b356
...
...
@@ -162,7 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// The output.
IndexedForwardGraph
graph_
;
// attribute equal comparator
Attrs
Equal
attr_equal_
;
Structural
Equal
attr_equal_
;
// Update the message stored at the node.
void
Update
(
const
Expr
&
node
,
IndexedForwardGraph
::
Node
*
parent
,
...
...
src/relay/transforms/pattern_util.h
View file @
6536b356
...
...
@@ -104,7 +104,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const
Array
<
Integer
>&
lhs_axes
,
Expr
*
rhs_value
=
nullptr
)
{
if
(
tlhs
->
shape
.
size
()
<
trhs
->
shape
.
size
())
return
false
;
Attrs
Equal
equal
;
Structural
Equal
equal
;
size_t
base
=
tlhs
->
shape
.
size
()
-
trhs
->
shape
.
size
();
size_t
j
=
0
;
...
...
src/tir/pass/ffi_api.cc
View file @
6536b356
...
...
@@ -101,18 +101,6 @@ TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
return
RewriteForTensorCore
(
stmt
,
schedule
,
extern_buffer
);
});
TVM_REGISTER_GLOBAL
(
"ir_pass.AttrsEqual"
)
.
set_body_typed
(
[](
const
ObjectRef
&
lhs
,
const
ObjectRef
&
rhs
)
{
return
AttrsEqual
()(
lhs
,
rhs
);
});
TVM_REGISTER_GLOBAL
(
"ir_pass.AttrsHash"
)
.
set_body_typed
([](
const
ObjectRef
&
node
)
->
int64_t
{
return
AttrsHash
()(
node
);
});
TVM_REGISTER_GLOBAL
(
"ir_pass.ExprUseVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ExprUseVar
(
args
[
0
].
operator
PrimExpr
(),
args
[
1
].
operator
Var
());
...
...
tests/python/relay/test_ir_nodes.py
View file @
6536b356
...
...
@@ -106,7 +106,6 @@ def test_function():
check_json_roundtrip
(
fn
)
@pytest.mark.skip
(
reason
=
"AttrsEqualHandler doesn't handle Map so far."
)
def
test_function_attrs
():
param_names
=
[
'a'
,
'b'
,
'c'
,
'd'
]
params
=
tvm
.
runtime
.
convert
([
relay
.
var
(
n
,
shape
=
(
5
,
2
))
for
n
in
param_names
])
...
...
tests/python/unittest/test_ir_attrs.py
View file @
6536b356
...
...
@@ -51,14 +51,13 @@ def test_dict_attrs():
def
test_attrs_equal
():
attr_equal
=
tvm
.
ir
.
_ffi_api
.
AttrsEqual
dattr0
=
tvm
.
ir
.
make_node
(
"DictAttrs"
,
x
=
1
,
y
=
[
10
,
20
])
dattr1
=
tvm
.
ir
.
make_node
(
"DictAttrs"
,
y
=
[
10
,
20
],
x
=
1
)
dattr2
=
tvm
.
ir
.
make_node
(
"DictAttrs"
,
x
=
1
,
y
=
None
)
assert
attr
_equal
(
dattr0
,
dattr1
)
assert
not
attr
_equal
(
dattr0
,
dattr2
)
assert
not
attr
_equal
({
"x"
:
1
},
tvm
.
runtime
.
convert
(
1
))
assert
not
attr
_equal
([
1
,
2
],
tvm
.
runtime
.
convert
(
1
))
assert
tvm
.
ir
.
structural
_equal
(
dattr0
,
dattr1
)
assert
not
tvm
.
ir
.
structural
_equal
(
dattr0
,
dattr2
)
assert
not
tvm
.
ir
.
structural
_equal
({
"x"
:
1
},
tvm
.
runtime
.
convert
(
1
))
assert
not
tvm
.
ir
.
structural
_equal
([
1
,
2
],
tvm
.
runtime
.
convert
(
1
))
...
...
tests/python/unittest/test_tir_pass_attrs_hash_equal.py
View file @
6536b356
...
...
@@ -21,28 +21,28 @@ def test_attrs_equal():
x
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
y
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
z
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
,
1
))
assert
tvm
.
tir
.
ir_pass
.
AttrsE
qual
(
x
,
y
)
assert
not
tvm
.
tir
.
ir_pass
.
AttrsE
qual
(
x
,
z
)
assert
tvm
.
ir
.
structural_e
qual
(
x
,
y
)
assert
not
tvm
.
ir
.
structural_e
qual
(
x
,
z
)
dattr
=
tvm
.
ir
.
make_node
(
"DictAttrs"
,
x
=
1
,
y
=
10
,
name
=
"xyz"
,
padding
=
(
0
,
0
))
assert
not
tvm
.
tir
.
ir_pass
.
AttrsE
qual
(
dattr
,
x
)
assert
not
tvm
.
ir
.
structural_e
qual
(
dattr
,
x
)
dattr2
=
tvm
.
ir
.
make_node
(
"DictAttrs"
,
x
=
1
,
y
=
10
,
name
=
"xyz"
,
padding
=
(
0
,
0
))
assert
tvm
.
tir
.
ir_pass
.
AttrsE
qual
(
dattr
,
dattr2
)
assert
tvm
.
ir
.
structural_e
qual
(
dattr
,
dattr2
)
assert
tvm
.
tir
.
ir_pass
.
AttrsE
qual
({
"x"
:
x
},
{
"x"
:
y
})
assert
tvm
.
ir
.
structural_e
qual
({
"x"
:
x
},
{
"x"
:
y
})
# array related checks
assert
tvm
.
tir
.
ir_pass
.
AttrsE
qual
({
"x"
:
[
x
,
x
]},
{
"x"
:
[
y
,
x
]})
assert
not
tvm
.
tir
.
ir_pass
.
AttrsE
qual
({
"x"
:
[
x
,
1
]},
{
"x"
:
[
y
,
2
]})
assert
tvm
.
ir
.
structural_e
qual
({
"x"
:
[
x
,
x
]},
{
"x"
:
[
y
,
x
]})
assert
not
tvm
.
ir
.
structural_e
qual
({
"x"
:
[
x
,
1
]},
{
"x"
:
[
y
,
2
]})
n
=
te
.
var
(
"n"
)
assert
tvm
.
tir
.
ir_pass
.
AttrsE
qual
({
"x"
:
n
+
1
},
{
"x"
:
n
+
1
})
assert
tvm
.
ir
.
structural_e
qual
({
"x"
:
n
+
1
},
{
"x"
:
n
+
1
})
def
test_attrs_hash
():
fhash
=
tvm
.
tir
.
ir_pass
.
AttrsH
ash
fhash
=
tvm
.
ir
.
structural_h
ash
x
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
y
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
assert
fhash
({
"x"
:
x
})
==
fhash
({
"x"
:
y
})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment