Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
b7a8af8d
Unverified
Commit
b7a8af8d
authored
Oct 15, 2018
by
Tianqi Chen
Committed by
GitHub
Oct 15, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG][ATTRS] Enable deep equality comparison and hash of Attrs (#1903)
parent
b64f3f1c
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
462 additions
and
35 deletions
+462
-35
include/tvm/attrs.h
+153
-3
src/api/api_pass.cc
+9
-0
src/lang/attr_functor.h
+76
-0
src/lang/attrs.cc
+155
-0
tests/python/relay/test_op_level4.py
+0
-32
tests/python/relay/test_op_level5.py
+36
-0
tests/python/unittest/test_pass_attrs_hash_equal.py
+33
-0
No files found.
include/tvm/attrs.h
View file @
b7a8af8d
...
@@ -27,8 +27,10 @@
...
@@ -27,8 +27,10 @@
#ifndef TVM_ATTRS_H_
#ifndef TVM_ATTRS_H_
#define TVM_ATTRS_H_
#define TVM_ATTRS_H_
#include <dmlc/common.h>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include <functional>
#include <type_traits>
#include <type_traits>
#include <string>
#include <string>
#include "ir.h"
#include "ir.h"
...
@@ -129,8 +131,8 @@ class BaseAttrsNode : public Node {
...
@@ -129,8 +131,8 @@ class BaseAttrsNode : public Node {
*/
*/
inline
void
PrintDocString
(
std
::
ostream
&
os
)
const
;
// NOLINT(*)
inline
void
PrintDocString
(
std
::
ostream
&
os
)
const
;
// NOLINT(*)
/*!
/*!
* \brief Get the field information
about the
* \brief Get the field information
* \
note This function throws when the required a field is not present
.
* \
return The fields in the Attrs
.
*/
*/
TVM_DLL
virtual
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
=
0
;
TVM_DLL
virtual
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
=
0
;
/*!
/*!
...
@@ -138,9 +140,20 @@ class BaseAttrsNode : public Node {
...
@@ -138,9 +140,20 @@ class BaseAttrsNode : public Node {
* \param kwargs The key value pairs for initialization.
* \param kwargs The key value pairs for initialization.
* [key0, value0, key1, value1, ..., key_n, value_n]
* [key0, value0, key1, value1, ..., key_n, value_n]
* \param allow_unknown Whether allow additional unknown fields.
* \param allow_unknown Whether allow additional unknown fields.
* \note This function throws when the required
a
field is not present.
* \note This function throws when the required field is not present.
*/
*/
TVM_DLL
virtual
void
InitByPackedArgs
(
const
TVMArgs
&
kwargs
,
bool
allow_unknown
=
false
)
=
0
;
TVM_DLL
virtual
void
InitByPackedArgs
(
const
TVMArgs
&
kwargs
,
bool
allow_unknown
=
false
)
=
0
;
/*!
* \brief Whether this attribute's content equals to another node.
* \param other The pointer to another node.
* \return The comparison result.
*/
TVM_DLL
virtual
bool
ContentEqual
(
const
Node
*
other
)
const
=
0
;
/*!
* \brief Content aware hash.
* \return the hash result.
*/
TVM_DLL
virtual
size_t
ContentHash
()
const
=
0
;
static
constexpr
const
char
*
_type_key
=
"Attrs"
;
static
constexpr
const
char
*
_type_key
=
"Attrs"
;
TVM_DECLARE_BASE_NODE_INFO
(
BaseAttrsNode
,
Node
);
TVM_DECLARE_BASE_NODE_INFO
(
BaseAttrsNode
,
Node
);
...
@@ -188,11 +201,93 @@ class DictAttrsNode : public BaseAttrsNode {
...
@@ -188,11 +201,93 @@ class DictAttrsNode : public BaseAttrsNode {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
;
void
InitByPackedArgs
(
const
runtime
::
TVMArgs
&
args
,
bool
allow_unknown
)
final
;
void
InitByPackedArgs
(
const
runtime
::
TVMArgs
&
args
,
bool
allow_unknown
)
final
;
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
final
;
Array
<
AttrFieldInfo
>
ListFieldInfo
()
const
final
;
bool
ContentEqual
(
const
Node
*
other
)
const
final
;
size_t
ContentHash
()
const
final
;
// type info
// type info
static
constexpr
const
char
*
_type_key
=
"DictAttrs"
;
static
constexpr
const
char
*
_type_key
=
"DictAttrs"
;
TVM_DECLARE_NODE_TYPE_INFO
(
DictAttrsNode
,
BaseAttrsNode
);
TVM_DECLARE_NODE_TYPE_INFO
(
DictAttrsNode
,
BaseAttrsNode
);
};
};
/*!
* \brief Content-aware Equality comparator for attrs.
*
* This comparator will recursively deep compare the following Attributes.
*
* - IntImm, UIntImm, FloatImm, StringImm
* - Any subclass of BaseAttrsNode
* - Array of Attributes.
* - Map from string to Attributes.
*/
class
AttrsEqual
{
public
:
bool
operator
()(
const
double
&
lhs
,
const
double
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
int64_t
&
lhs
,
const
int64_t
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
uint64_t
&
lhs
,
const
uint64_t
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
int
&
lhs
,
const
int
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
bool
&
lhs
,
const
bool
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
std
::
string
&
lhs
,
const
std
::
string
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
Type
&
lhs
,
const
Type
&
rhs
)
const
{
return
lhs
==
rhs
;
}
bool
operator
()(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
const
{
return
AttrsEqual
::
Equal
(
lhs
,
rhs
);
}
// comparator of NodeRef types.
static
TVM_DLL
bool
Equal
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
);
};
/*!
* \brief Content-aware hash function.
*
* This hash functor will recursively hash the content of the Attributes.
* It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
*/
class
AttrsHash
{
public
:
size_t
operator
()(
const
double
&
value
)
const
{
return
std
::
hash
<
double
>
()(
value
);
}
size_t
operator
()(
const
int64_t
&
value
)
const
{
return
std
::
hash
<
int64_t
>
()(
value
);
}
size_t
operator
()(
const
uint64_t
&
value
)
const
{
return
std
::
hash
<
uint64_t
>
()(
value
);
}
size_t
operator
()(
const
int
&
value
)
const
{
return
std
::
hash
<
int
>
()(
value
);
}
size_t
operator
()(
const
bool
&
value
)
const
{
return
std
::
hash
<
bool
>
()(
value
);
}
size_t
operator
()(
const
std
::
string
&
value
)
const
{
return
std
::
hash
<
std
::
string
>
()(
value
);
}
size_t
operator
()(
const
Type
&
value
)
const
{
return
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
value
.
code
())
|
(
static_cast
<
int
>
(
value
.
bits
())
<<
8
)
|
(
static_cast
<
int
>
(
value
.
lanes
())
<<
16
));
}
size_t
operator
()(
const
NodeRef
&
value
)
const
{
return
AttrsHash
::
Hash
(
value
);
}
// hash function of the attribute and attribute fields.
static
TVM_DLL
size_t
Hash
(
const
NodeRef
&
lhs
);
};
// Namespace containing detail implementations
// Namespace containing detail implementations
namespace
detail
{
namespace
detail
{
using
runtime
::
TVMArgValue
;
using
runtime
::
TVMArgValue
;
...
@@ -234,6 +329,44 @@ class AttrNormalVisitor {
...
@@ -234,6 +329,44 @@ class AttrNormalVisitor {
AttrVisitor
*
visitor_
;
AttrVisitor
*
visitor_
;
};
};
// Wrapper for normal visitor.
class
AttrsEqualVisitor
{
public
:
bool
result_
{
true
};
// constructor
AttrsEqualVisitor
(
const
Node
*
lhs
,
const
Node
*
rhs
)
:
lhs_
(
lhs
),
rhs_
(
rhs
)
{
}
template
<
typename
T
>
AttrNopEntry
operator
()(
const
char
*
key
,
T
*
lhs_value
)
{
if
(
!
result_
)
return
AttrNopEntry
();
const
T
*
rhs_value
=
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
const
char
*>
(
rhs_
)
+
(
reinterpret_cast
<
const
char
*>
(
lhs_value
)
-
reinterpret_cast
<
const
char
*>
(
lhs_
)));
if
(
!
AttrsEqual
()(
*
lhs_value
,
*
rhs_value
))
{
result_
=
false
;
}
return
AttrNopEntry
();
}
private
:
const
Node
*
lhs_
;
const
Node
*
rhs_
;
};
class
AttrsHashVisitor
{
public
:
size_t
result_
{
0
};
template
<
typename
T
>
AttrNopEntry
operator
()(
const
char
*
key
,
T
*
value
)
{
result_
=
dmlc
::
HashCombine
(
result_
,
AttrsHash
()(
*
value
));
return
AttrNopEntry
();
}
};
// helper entry that does initialization, set default.
// helper entry that does initialization, set default.
template
<
typename
T
>
template
<
typename
T
>
struct
AttrInitEntry
{
struct
AttrInitEntry
{
...
@@ -596,6 +729,23 @@ class AttrsNode : public BaseAttrsNode {
...
@@ -596,6 +729,23 @@ class AttrsNode : public BaseAttrsNode {
return
visitor
.
fields_
;
return
visitor
.
fields_
;
}
}
bool
ContentEqual
(
const
Node
*
other
)
const
final
{
DerivedType
*
pself
=
self
();
if
(
pself
==
other
)
return
true
;
if
(
other
==
nullptr
)
return
false
;
if
(
pself
->
type_index
()
!=
other
->
type_index
())
return
false
;
detail
::
AttrsEqualVisitor
visitor
(
pself
,
other
);
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
result_
;
}
size_t
ContentHash
()
const
final
{
detail
::
AttrsHashVisitor
visitor
;
visitor
.
result_
=
std
::
hash
<
std
::
string
>
()(
this
->
type_key
());
self
()
->
__VisitAttrs__
(
visitor
);
return
visitor
.
result_
;
}
private
:
private
:
DerivedType
*
self
()
const
{
DerivedType
*
self
()
const
{
return
const_cast
<
DerivedType
*>
(
return
const_cast
<
DerivedType
*>
(
...
...
src/api/api_pass.cc
View file @
b7a8af8d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
*/
*/
#include <tvm/expr.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
...
@@ -65,6 +66,14 @@ TVM_REGISTER_API("ir_pass.Equal")
...
@@ -65,6 +66,14 @@ TVM_REGISTER_API("ir_pass.Equal")
}
}
});
});
TVM_REGISTER_API
(
"ir_pass.AttrsEqual"
)
.
set_body_typed
<
bool
(
const
NodeRef
&
,
const
NodeRef
&
)
>
(
AttrsEqual
::
Equal
);
TVM_REGISTER_API
(
"ir_pass.AttrsHash"
)
.
set_body_typed
<
int64_t
(
const
NodeRef
&
)
>
(
AttrsHash
::
Hash
);
TVM_REGISTER_API
(
"ir_pass.ExprUseVar"
)
TVM_REGISTER_API
(
"ir_pass.ExprUseVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ExprUseVar
(
args
[
0
].
operator
Expr
(),
args
[
1
].
operator
Var
());
*
ret
=
ExprUseVar
(
args
[
0
].
operator
Expr
(),
args
[
1
].
operator
Var
());
...
...
src/lang/attr_functor.h
0 → 100644
View file @
b7a8af8d
/*!
* Copyright (c) 2018 by Contributors
* \file attr_functor.h
* \brief A way to define arbitrary function signature
* with dispatch on common attributes.
*
* Common attributes include:
* - int, float, str constants
* - array of attributes
* - map of attributes
*/
#ifndef TVM_LANG_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_
namespace
tvm
{
template
<
typename
FType
>
class
AttrFunctor
;
#define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->Visit_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \
// A functor for common attribute information.
template
<
typename
R
,
typename
...
Args
>
class
AttrFunctor
<
R
(
const
NodeRef
&
n
,
Args
...)
>
{
private
:
using
TSelf
=
AttrFunctor
<
R
(
const
NodeRef
&
n
,
Args
...)
>
;
using
FType
=
tvm
::
IRFunctor
<
R
(
const
NodeRef
&
n
,
TSelf
*
self
,
Args
...)
>
;
public
:
/*! \brief the result type of this functor */
using
result_type
=
R
;
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual
R
Visit
(
const
NodeRef
&
n
,
Args
...
args
)
{
static
FType
vtable
=
InitVTable
();
if
(
vtable
.
can_dispatch
(
n
))
{
return
vtable
(
n
,
this
,
std
::
forward
<
Args
>
(
args
)...);
}
else
{
return
VisitDefault_
(
n
,
std
::
forward
<
Args
>
(
args
)...);
}
}
virtual
R
Visit_
(
const
ArrayNode
*
op
,
Args
...
args
)
=
0
;
virtual
R
Visit_
(
const
StrMapNode
*
op
,
Args
...
args
)
=
0
;
virtual
R
Visit_
(
const
ir
::
IntImm
*
op
,
Args
...
args
)
=
0
;
virtual
R
Visit_
(
const
ir
::
UIntImm
*
op
,
Args
...
args
)
=
0
;
virtual
R
Visit_
(
const
ir
::
FloatImm
*
op
,
Args
...
args
)
=
0
;
virtual
R
Visit_
(
const
ir
::
StringImm
*
op
,
Args
...
args
)
=
0
;
virtual
R
VisitDefault_
(
const
NodeRef
&
n
,
Args
...
args
)
=
0
;
private
:
// initialize the vtable.
static
FType
InitVTable
()
{
using
namespace
ir
;
FType
vtable
;
// Set dispatch
ATTR_FUNCTOR_DISPATCH
(
StrMapNode
);
ATTR_FUNCTOR_DISPATCH
(
ArrayNode
);
ATTR_FUNCTOR_DISPATCH
(
IntImm
);
ATTR_FUNCTOR_DISPATCH
(
UIntImm
);
ATTR_FUNCTOR_DISPATCH
(
FloatImm
);
ATTR_FUNCTOR_DISPATCH
(
StringImm
);
return
vtable
;
}
};
}
// namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_
src/lang/attrs.cc
View file @
b7a8af8d
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
* \file attrs.cc
* \file attrs.cc
*/
*/
#include <tvm/attrs.h>
#include <tvm/attrs.h>
#include "attr_functor.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -44,4 +45,158 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode);
...
@@ -44,4 +45,158 @@ TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE
(
AttrFieldInfoNode
);
TVM_REGISTER_NODE_TYPE
(
AttrFieldInfoNode
);
using
namespace
ir
;
class
AttrsEqualChecker
:
public
AttrFunctor
<
bool
(
const
NodeRef
&
,
const
NodeRef
&
)
>
{
public
:
bool
Check
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
!
equal_
)
return
false
;
if
(
lhs
.
same_as
(
rhs
))
return
true
;
if
(
!
lhs
.
defined
()
||
!
rhs
.
defined
())
return
false
;
if
(
!
this
->
Visit
(
lhs
,
rhs
))
{
equal_
=
false
;
}
return
equal_
;
}
bool
VisitDefault_
(
const
NodeRef
&
lhs
,
const
NodeRef
&
other
)
final
{
if
(
lhs
->
derived_from
<
BaseAttrsNode
>
())
{
return
static_cast
<
const
BaseAttrsNode
*>
(
lhs
.
get
())
->
ContentEqual
(
other
.
get
());
}
return
lhs
.
same_as
(
other
);
}
bool
Visit_
(
const
IntImm
*
lhs
,
const
NodeRef
&
other
)
final
{
if
(
const
auto
*
rhs
=
other
.
as
<
IntImm
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
return
false
;
}
bool
Visit_
(
const
UIntImm
*
lhs
,
const
NodeRef
&
other
)
final
{
if
(
const
auto
*
rhs
=
other
.
as
<
UIntImm
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
return
false
;
}
bool
Visit_
(
const
FloatImm
*
lhs
,
const
NodeRef
&
other
)
final
{
if
(
const
auto
*
rhs
=
other
.
as
<
FloatImm
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
return
false
;
}
bool
Visit_
(
const
StringImm
*
lhs
,
const
NodeRef
&
other
)
final
{
if
(
const
auto
*
rhs
=
other
.
as
<
StringImm
>
())
{
return
lhs
->
value
==
rhs
->
value
;
}
return
false
;
}
bool
Visit_
(
const
ArrayNode
*
lhs
,
const
NodeRef
&
other
)
final
{
if
(
const
auto
*
rhs
=
other
.
as
<
ArrayNode
>
())
{
if
(
rhs
->
data
.
size
()
!=
lhs
->
data
.
size
())
return
false
;
for
(
size_t
i
=
0
;
i
<
lhs
->
data
.
size
();
++
i
)
{
if
(
!
Check
(
NodeRef
(
lhs
->
data
[
i
]),
NodeRef
(
rhs
->
data
[
i
])))
return
false
;
}
}
return
true
;
}
bool
Visit_
(
const
StrMapNode
*
lhs
,
const
NodeRef
&
other
)
final
{
if
(
const
auto
*
rhs
=
other
.
as
<
StrMapNode
>
())
{
if
(
rhs
->
data
.
size
()
!=
lhs
->
data
.
size
())
return
false
;
for
(
const
auto
&
kv
:
lhs
->
data
)
{
auto
it
=
rhs
->
data
.
find
(
kv
.
first
);
if
(
it
==
rhs
->
data
.
end
())
return
false
;
if
(
!
Check
(
NodeRef
(
kv
.
second
),
NodeRef
(
it
->
second
)))
return
false
;
}
}
return
true
;
}
private
:
bool
equal_
{
true
};
};
class
AttrContentHasher
:
public
AttrFunctor
<
void
(
const
NodeRef
&
)
>
{
public
:
size_t
result_
{
0
};
void
VisitDefault_
(
const
NodeRef
&
value
)
final
{
if
(
value
->
derived_from
<
BaseAttrsNode
>
())
{
Update
(
static_cast
<
const
BaseAttrsNode
*>
(
value
.
get
())
->
ContentHash
());
}
else
{
Update
(
NodeHash
()(
value
));
}
}
void
Visit_
(
const
IntImm
*
op
)
final
{
Update
(
std
::
hash
<
int64_t
>
()(
op
->
value
));
}
void
Visit_
(
const
UIntImm
*
op
)
final
{
Update
(
std
::
hash
<
uint64_t
>
()(
op
->
value
));
}
void
Visit_
(
const
FloatImm
*
op
)
final
{
Update
(
std
::
hash
<
double
>
()(
op
->
value
));
}
void
Visit_
(
const
StringImm
*
op
)
final
{
Update
(
std
::
hash
<
std
::
string
>
()(
op
->
value
));
}
void
Visit_
(
const
ArrayNode
*
op
)
final
{
Update
(
op
->
data
.
size
());
for
(
size_t
i
=
0
;
i
<
op
->
data
.
size
();
++
i
)
{
this
->
Visit
(
NodeRef
(
op
->
data
[
i
]));
}
}
void
Visit_
(
const
StrMapNode
*
lhs
)
final
{
using
Entry
=
std
::
pair
<
std
::
string
,
NodePtr
<
Node
>
>
;
std
::
vector
<
Entry
>
data
(
lhs
->
data
.
begin
(),
lhs
->
data
.
end
());
std
::
sort
(
data
.
begin
(),
data
.
end
(),
[](
const
Entry
&
a
,
const
Entry
&
b
)
{
return
a
.
first
<
b
.
first
;
});
for
(
const
Entry
&
kv
:
data
)
{
Update
(
std
::
hash
<
std
::
string
>
()(
kv
.
first
));
this
->
Visit
(
NodeRef
(
kv
.
second
));
}
}
void
Update
(
size_t
value
)
{
result_
=
dmlc
::
HashCombine
(
result_
,
value
);
}
};
bool
AttrsEqual
::
Equal
(
const
NodeRef
&
lhs
,
const
NodeRef
&
rhs
)
{
if
(
lhs
.
same_as
(
rhs
))
return
true
;
AttrsEqualChecker
checker
;
return
checker
.
Check
(
lhs
,
rhs
);
}
size_t
AttrsHash
::
Hash
(
const
NodeRef
&
node
)
{
if
(
!
node
.
defined
())
return
0
;
AttrContentHasher
hasher
;
hasher
.
Visit
(
node
);
return
hasher
.
result_
;
}
size_t
DictAttrsNode
::
ContentHash
()
const
{
return
AttrsHash
()(
this
->
dict
);
}
bool
DictAttrsNode
::
ContentEqual
(
const
Node
*
other
)
const
{
if
(
this
==
other
)
return
true
;
if
(
other
==
nullptr
)
return
false
;
if
(
this
->
type_index
()
!=
other
->
type_index
())
return
false
;
return
AttrsEqual
()(
this
->
dict
,
static_cast
<
const
DictAttrsNode
*>
(
other
)
->
dict
);
}
}
// namespace tvm
}
// namespace tvm
tests/python/relay/test_op_level4.py
View file @
b7a8af8d
...
@@ -124,38 +124,6 @@ def test_binary_broadcast():
...
@@ -124,38 +124,6 @@ def test_binary_broadcast():
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"int32"
)
assert
ftype
.
ret_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"int32"
)
def
test_multibox_prior
():
sizes
=
(
0.3
,
1.5
,
0.7
)
ratios
=
(
1.3
,
2.4
)
steps
=
(
2.0
,
1.5
)
offsets
=
(
0.2
,
0.3
)
clip
=
True
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
3
,
56
,
56
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
vision
.
multibox_prior
(
x
.
var
,
sizes
,
ratios
,
steps
,
offsets
,
clip
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
(
1
,
h
*
w
*
(
len
(
sizes
)
+
len
(
ratios
)
-
1
),
4
),
"float32"
)
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
24
,
32
,
32
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
vision
.
multibox_prior
(
x
.
var
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
(
1
,
h
*
w
,
4
),
"float32"
)
def
test_where
():
def
test_where
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
ib
=
relay
.
ir_builder
.
IRBuilder
()
cond
=
ib
.
param
(
"cond"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
cond
=
ib
.
param
(
"cond"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
...
...
tests/python/relay/test_op_level5.py
View file @
b7a8af8d
...
@@ -25,5 +25,41 @@ def test_resize_infer_type():
...
@@ -25,5 +25,41 @@ def test_resize_infer_type():
ftype
=
func
.
checked_type
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
n
,
c
,
100
,
200
),
"int8"
)
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
((
n
,
c
,
100
,
200
),
"int8"
)
def
test_multibox_prior
():
sizes
=
(
0.3
,
1.5
,
0.7
)
ratios
=
(
1.3
,
2.4
)
steps
=
(
2.0
,
1.5
)
offsets
=
(
0.2
,
0.3
)
clip
=
True
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
3
,
56
,
56
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
vision
.
multibox_prior
(
x
,
sizes
,
ratios
,
steps
,
offsets
,
clip
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
(
1
,
h
*
w
*
(
len
(
sizes
)
+
len
(
ratios
)
-
1
),
4
),
"float32"
)
ib
=
relay
.
ir_builder
.
IRBuilder
()
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
24
,
32
,
32
x
=
ib
.
param
(
"x"
,
relay
.
ty
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
with
ib
.
function
(
x
)
as
func
:
ib
.
ret
(
relay
.
vision
.
multibox_prior
(
x
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
ty
.
TensorType
(
(
1
,
h
*
w
,
4
),
"float32"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_resize_infer_type
()
test_resize_infer_type
()
test_multibox_prior
()
tests/python/unittest/test_pass_attrs_hash_equal.py
0 → 100644
View file @
b7a8af8d
import
tvm
def
test_attrs_equal
():
x
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
y
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
z
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
,
1
))
assert
tvm
.
ir_pass
.
AttrsEqual
(
x
,
y
)
assert
not
tvm
.
ir_pass
.
AttrsEqual
(
x
,
z
)
dattr
=
tvm
.
make
.
node
(
"DictAttrs"
,
x
=
1
,
y
=
10
,
name
=
"xyz"
,
padding
=
(
0
,
0
))
assert
not
tvm
.
ir_pass
.
AttrsEqual
(
dattr
,
x
)
dattr2
=
tvm
.
make
.
node
(
"DictAttrs"
,
x
=
1
,
y
=
10
,
name
=
"xyz"
,
padding
=
(
0
,
0
))
assert
tvm
.
ir_pass
.
AttrsEqual
(
dattr
,
dattr2
)
assert
tvm
.
ir_pass
.
AttrsEqual
({
"x"
:
x
},
{
"x"
:
y
})
# array related checks
assert
tvm
.
ir_pass
.
AttrsEqual
({
"x"
:
[
x
,
x
]},
{
"x"
:
[
y
,
x
]})
assert
not
tvm
.
ir_pass
.
AttrsEqual
({
"x"
:
[
x
,
1
]},
{
"x"
:
[
y
,
2
]})
def
test_attrs_hash
():
fhash
=
tvm
.
ir_pass
.
AttrsHash
x
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
y
=
tvm
.
make
.
node
(
"attrs.TestAttrs"
,
name
=
"xx"
,
padding
=
(
3
,
4
))
assert
fhash
({
"x"
:
x
})
==
fhash
({
"x"
:
y
})
assert
fhash
({
"x"
:
x
})
!=
fhash
({
"x"
:
[
y
,
1
]})
assert
fhash
({
"x"
:
[
x
,
1
]})
==
fhash
({
"x"
:
[
y
,
1
]})
assert
fhash
({
"x"
:
[
x
,
2
]})
==
fhash
({
"x"
:
[
y
,
2
]})
if
__name__
==
"__main__"
:
test_attrs_equal
()
test_attrs_hash
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment