Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0858c5ad
Commit
0858c5ad
authored
Jul 26, 2019
by
Lianmin Zheng
Committed by
Tianqi Chen
Jul 25, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Make iterators compatible with constructors of STL containers (#3624)
parent
97e333ca
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
27 deletions
+44
-27
include/tvm/node/container.h
+16
-6
src/relay/op/op_common.h
+0
-10
src/relay/op/tensor/reduce.cc
+2
-2
src/relay/op/tensor/transform.cc
+8
-8
src/relay/op/vision/yolo.cc
+1
-1
tests/cpp/container_test.cc
+17
-0
No files found.
include/tvm/node/container.h
View file @
0858c5ad
...
@@ -110,18 +110,28 @@ template<typename Converter,
...
@@ -110,18 +110,28 @@ template<typename Converter,
typename
TIter
>
typename
TIter
>
class
IterAdapter
{
class
IterAdapter
{
public
:
public
:
using
difference_type
=
typename
std
::
iterator_traits
<
TIter
>::
difference_type
;
using
value_type
=
typename
std
::
iterator_traits
<
TIter
>::
value_type
;
using
pointer
=
typename
std
::
iterator_traits
<
TIter
>::
pointer
;
using
reference
=
typename
std
::
iterator_traits
<
TIter
>::
reference
;
using
iterator_category
=
typename
std
::
iterator_traits
<
TIter
>::
iterator_category
;
explicit
IterAdapter
(
TIter
iter
)
:
iter_
(
iter
)
{}
explicit
IterAdapter
(
TIter
iter
)
:
iter_
(
iter
)
{}
inline
IterAdapter
&
operator
++
()
{
// NOLINT(*)
inline
IterAdapter
&
operator
++
()
{
++
iter_
;
return
*
this
;
}
inline
IterAdapter
&
operator
++
(
int
)
{
// NOLINT(*)
++
iter_
;
++
iter_
;
return
*
this
;
return
*
this
;
}
}
inline
IterAdapter
operator
+
(
int
offset
)
const
{
// NOLINT(*)
inline
IterAdapter
operator
+
(
difference_type
offset
)
const
{
return
IterAdapter
(
iter_
+
offset
);
return
IterAdapter
(
iter_
+
offset
);
}
}
template
<
typename
T
=
IterAdapter
>
typename
std
::
enable_if
<
std
::
is_same
<
iterator_category
,
std
::
random_access_iterator_tag
>::
value
,
typename
T
::
difference_type
>::
type
inline
operator
-
(
const
IterAdapter
&
rhs
)
const
{
return
iter_
-
rhs
.
iter_
;
}
inline
bool
operator
==
(
IterAdapter
other
)
const
{
inline
bool
operator
==
(
IterAdapter
other
)
const
{
return
iter_
==
other
.
iter_
;
return
iter_
==
other
.
iter_
;
}
}
...
...
src/relay/op/op_common.h
View file @
0858c5ad
...
@@ -35,16 +35,6 @@
...
@@ -35,16 +35,6 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
template
<
typename
T
>
inline
std
::
vector
<
T
>
AsVector
(
const
Array
<
T
>
&
array
)
{
std
::
vector
<
T
>
result
;
result
.
reserve
(
array
.
size
());
for
(
const
T
&
ele
:
array
)
{
result
.
push_back
(
ele
);
}
return
result
;
}
/*! Quick helper macro
/*! Quick helper macro
* - Expose a positional make function to construct the node.
* - Expose a positional make function to construct the node.
* - Register op to the registry.
* - Register op to the registry.
...
...
src/relay/op/tensor/reduce.cc
View file @
0858c5ad
...
@@ -229,7 +229,7 @@ bool ArgReduceRel(const Array<Type>& types,
...
@@ -229,7 +229,7 @@ bool ArgReduceRel(const Array<Type>& types,
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
return
false
;
if
(
data
==
nullptr
)
return
false
;
CHECK
(
static_cast
<
int
>
(
data
->
shape
.
size
())
!=
0
);
CHECK
(
static_cast
<
int
>
(
data
->
shape
.
size
())
!=
0
);
std
::
vector
<
IndexExpr
>
&&
in_shape
=
AsVector
(
data
->
shape
);
std
::
vector
<
IndexExpr
>
in_shape
(
data
->
shape
.
begin
(),
data
->
shape
.
end
()
);
const
ReduceAttrs
*
param
=
attrs
.
as
<
ReduceAttrs
>
();
const
ReduceAttrs
*
param
=
attrs
.
as
<
ReduceAttrs
>
();
CHECK
(
param
!=
nullptr
);
CHECK
(
param
!=
nullptr
);
...
@@ -254,7 +254,7 @@ bool ReduceRel(const Array<Type>& types,
...
@@ -254,7 +254,7 @@ bool ReduceRel(const Array<Type>& types,
CHECK_EQ
(
types
.
size
(),
2
);
CHECK_EQ
(
types
.
size
(),
2
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
return
false
;
if
(
data
==
nullptr
)
return
false
;
std
::
vector
<
IndexExpr
>
&&
in_shape
=
AsVector
(
data
->
shape
);
std
::
vector
<
IndexExpr
>
in_shape
(
data
->
shape
.
begin
(),
data
->
shape
.
end
()
);
const
ReduceAttrs
*
param
=
attrs
.
as
<
ReduceAttrs
>
();
const
ReduceAttrs
*
param
=
attrs
.
as
<
ReduceAttrs
>
();
CHECK
(
param
!=
nullptr
);
CHECK
(
param
!=
nullptr
);
...
...
src/relay/op/tensor/transform.cc
View file @
0858c5ad
...
@@ -265,7 +265,7 @@ bool ConcatenateRel(const Array<Type>& types,
...
@@ -265,7 +265,7 @@ bool ConcatenateRel(const Array<Type>& types,
}
}
axis
=
axis
<
0
?
ndim
+
axis
:
axis
;
axis
=
axis
<
0
?
ndim
+
axis
:
axis
;
// Calculate shape
// Calculate shape
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
first
->
shape
);
std
::
vector
<
IndexExpr
>
oshape
(
first
->
shape
.
begin
(),
first
->
shape
.
end
()
);
IndexExpr
&
concat_dim
=
oshape
[
axis
];
IndexExpr
&
concat_dim
=
oshape
[
axis
];
bool
has_any
=
false
;
bool
has_any
=
false
;
if
(
concat_dim
.
as
<
Any
>
())
{
if
(
concat_dim
.
as
<
Any
>
())
{
...
@@ -834,7 +834,7 @@ bool TakeRel(const Array<Type>& types,
...
@@ -834,7 +834,7 @@ bool TakeRel(const Array<Type>& types,
CHECK
(
param
!=
nullptr
);
CHECK
(
param
!=
nullptr
);
if
(
!
param
->
axis
.
defined
())
{
if
(
!
param
->
axis
.
defined
())
{
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
indices
->
shape
);
std
::
vector
<
IndexExpr
>
oshape
(
indices
->
shape
.
begin
(),
indices
->
shape
.
end
()
);
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
));
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
));
return
true
;
return
true
;
}
}
...
@@ -1990,7 +1990,7 @@ bool SplitRel(const Array<Type>& types,
...
@@ -1990,7 +1990,7 @@ bool SplitRel(const Array<Type>& types,
<<
"indices_or_sections need to be able to divide input.shape[axis]"
;
<<
"indices_or_sections need to be able to divide input.shape[axis]"
;
std
::
vector
<
Type
>
fields
;
std
::
vector
<
Type
>
fields
;
for
(
int
i
=
0
;
i
<
sections
->
value
;
++
i
)
{
for
(
int
i
=
0
;
i
<
sections
->
value
;
++
i
)
{
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
data
->
shape
);
std
::
vector
<
IndexExpr
>
oshape
(
data
->
shape
.
begin
(),
data
->
shape
.
end
()
);
oshape
[
axis
]
/=
int32_t
(
sections
->
value
);
oshape
[
axis
]
/=
int32_t
(
sections
->
value
);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
fields
.
push_back
(
vec_type
);
fields
.
push_back
(
vec_type
);
...
@@ -2003,7 +2003,7 @@ bool SplitRel(const Array<Type>& types,
...
@@ -2003,7 +2003,7 @@ bool SplitRel(const Array<Type>& types,
for
(
unsigned
int
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
CHECK
(
reporter
->
Assert
(
IndexExpr
(
indices
[
i
])
>
begin
))
CHECK
(
reporter
->
Assert
(
IndexExpr
(
indices
[
i
])
>
begin
))
<<
"indices_or_sections need to be a sorted ascending list"
;
<<
"indices_or_sections need to be a sorted ascending list"
;
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
data
->
shape
);
std
::
vector
<
IndexExpr
>
oshape
(
data
->
shape
.
begin
(),
data
->
shape
.
end
()
);
oshape
[
axis
]
=
IndexExpr
(
indices
[
i
])
-
begin
;
oshape
[
axis
]
=
IndexExpr
(
indices
[
i
])
-
begin
;
begin
=
IndexExpr
(
indices
[
i
]);
begin
=
IndexExpr
(
indices
[
i
]);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
...
@@ -2011,7 +2011,7 @@ bool SplitRel(const Array<Type>& types,
...
@@ -2011,7 +2011,7 @@ bool SplitRel(const Array<Type>& types,
}
}
CHECK
(
reporter
->
Assert
(
begin
<
data
->
shape
[
axis
]))
CHECK
(
reporter
->
Assert
(
begin
<
data
->
shape
[
axis
]))
<<
"The sum of sections must match the input.shape[axis]"
;
<<
"The sum of sections must match the input.shape[axis]"
;
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
data
->
shape
);
std
::
vector
<
IndexExpr
>
oshape
(
data
->
shape
.
begin
(),
data
->
shape
.
end
()
);
oshape
[
axis
]
=
data
->
shape
[
axis
]
-
begin
;
oshape
[
axis
]
=
data
->
shape
[
axis
]
-
begin
;
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
auto
vec_type
=
TensorTypeNode
::
make
(
oshape
,
data
->
dtype
);
fields
.
push_back
(
vec_type
);
fields
.
push_back
(
vec_type
);
...
@@ -2105,9 +2105,9 @@ bool SliceLikeRel(const Array<Type>& types,
...
@@ -2105,9 +2105,9 @@ bool SliceLikeRel(const Array<Type>& types,
const
auto
param
=
attrs
.
as
<
SliceLikeAttrs
>
();
const
auto
param
=
attrs
.
as
<
SliceLikeAttrs
>
();
CHECK
(
param
!=
nullptr
);
CHECK
(
param
!=
nullptr
);
const
Array
<
IndexExpr
>
dshape
=
data
->
shape
;
const
Array
<
IndexExpr
>
&
dshape
=
data
->
shape
;
const
Array
<
IndexExpr
>
target_shape
=
target
->
shape
;
const
Array
<
IndexExpr
>
&
target_shape
=
target
->
shape
;
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
dshape
);
std
::
vector
<
IndexExpr
>
oshape
(
dshape
.
begin
(),
dshape
.
end
()
);
if
(
!
param
->
axes
.
defined
())
{
if
(
!
param
->
axes
.
defined
())
{
for
(
size_t
i
=
0
;
i
<
dshape
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dshape
.
size
();
++
i
)
{
...
...
src/relay/op/vision/yolo.cc
View file @
0858c5ad
...
@@ -53,7 +53,7 @@ bool YoloReorgRel(const Array<Type>& types,
...
@@ -53,7 +53,7 @@ bool YoloReorgRel(const Array<Type>& types,
CHECK
(
param
!=
nullptr
);
CHECK
(
param
!=
nullptr
);
CHECK
(
data
->
shape
.
size
()
==
4
)
<<
"Yolo reorg supports only 4 dimension."
;
CHECK
(
data
->
shape
.
size
()
==
4
)
<<
"Yolo reorg supports only 4 dimension."
;
std
::
vector
<
IndexExpr
>
&&
oshape
=
AsVector
(
data
->
shape
);
std
::
vector
<
IndexExpr
>
oshape
(
data
->
shape
.
begin
(),
data
->
shape
.
end
()
);
oshape
[
1
]
=
oshape
[
1
]
*
param
->
stride
*
param
->
stride
;
oshape
[
1
]
=
oshape
[
1
]
*
param
->
stride
*
param
->
stride
;
oshape
[
2
]
=
oshape
[
2
]
/
param
->
stride
;
oshape
[
2
]
=
oshape
[
2
]
/
param
->
stride
;
oshape
[
3
]
=
oshape
[
3
]
/
param
->
stride
;
oshape
[
3
]
=
oshape
[
3
]
/
param
->
stride
;
...
...
tests/cpp/container_test.cc
View file @
0858c5ad
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
* under the License.
* under the License.
*/
*/
#include <vector>
#include <unordered_map>
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <tvm/packed_func_ext.h>
#include <tvm/packed_func_ext.h>
...
@@ -42,6 +44,13 @@ TEST(Array, Mutate) {
...
@@ -42,6 +44,13 @@ TEST(Array, Mutate) {
CHECK
(
list2
[
1
].
same_as
(
z
));
CHECK
(
list2
[
1
].
same_as
(
z
));
}
}
TEST
(
Array
,
Iterator
)
{
using
namespace
tvm
;
Array
<
Expr
>
array
{
1
,
2
,
3
};
std
::
vector
<
Expr
>
vector
(
array
.
begin
(),
array
.
end
());
CHECK
(
vector
[
1
].
as
<
IntImm
>
()
->
value
==
2
);
}
TEST
(
Map
,
Expr
)
{
TEST
(
Map
,
Expr
)
{
using
namespace
tvm
;
using
namespace
tvm
;
Var
x
(
"x"
);
Var
x
(
"x"
);
...
@@ -86,6 +95,14 @@ TEST(Map, Mutate) {
...
@@ -86,6 +95,14 @@ TEST(Map, Mutate) {
LOG
(
INFO
)
<<
dict
;
LOG
(
INFO
)
<<
dict
;
}
}
TEST
(
Map
,
Iterator
)
{
using
namespace
tvm
;
Expr
a
=
1
,
b
=
2
;
Map
<
Expr
,
Expr
>
map1
{{
a
,
b
}};
std
::
unordered_map
<
Expr
,
Expr
,
NodeHash
,
NodeEqual
>
map2
(
map1
.
begin
(),
map1
.
end
());
CHECK
(
map2
[
a
].
as
<
IntImm
>
()
->
value
==
2
);
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment