Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
96b2c082
Commit
96b2c082
authored
May 22, 2018
by
Pariksheet Pinjari
Committed by
Tianqi Chen
May 21, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added equality check and upgraded concatenate op (#1172)
parent
05e806e0
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
104 additions
and
7 deletions
+104
-7
topi/include/topi/broadcast.h
+2
-1
topi/include/topi/detail/broadcast.h
+5
-4
topi/include/topi/detail/constant_utils.h
+18
-0
topi/include/topi/nn.h
+2
-2
topi/include/topi/transform.h
+2
-0
topi/python/topi/transform.py
+1
-0
topi/tests/python_cpp/test_topi_transform.py
+74
-0
No files found.
topi/include/topi/broadcast.h
View file @
96b2c082
...
...
@@ -9,6 +9,7 @@
#include <string>
#include "topi/detail/broadcast.h"
#include "topi/detail/constant_utils.h"
#include "topi/tags.h"
namespace
topi
{
...
...
@@ -34,7 +35,7 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
auto
bh
=
detail
::
BroadcastShape
(
output_shape
,
t
->
shape
);
CHECK_EQ
(
output_shape
.
size
(),
bh
.
common_shape
.
size
());
for
(
size_t
i
=
0
;
i
<
output_shape
.
size
();
++
i
)
{
CHECK
(
t
vm
::
ir
::
Equal
(
output_shape
[
i
],
bh
.
common_shape
[
i
]));
CHECK
(
t
opi
::
detail
::
EqualCheck
(
output_shape
[
i
],
bh
.
common_shape
[
i
]));
}
auto
l
=
[
&
](
tvm
::
Array
<
tvm
::
Var
>
ovars
)
{
return
t
(
detail
::
InputIndexFromBroadcast
(
ovars
,
t
,
bh
.
vars2
,
bh
.
all_vars
));
...
...
topi/include/topi/detail/broadcast.h
View file @
96b2c082
...
...
@@ -12,6 +12,7 @@
#include "tvm/ir_pass.h"
#include "tvm/tvm.h"
#include "topi/detail/constant_utils.h"
namespace
topi
{
namespace
detail
{
...
...
@@ -32,15 +33,15 @@ inline BroadcastHelper BroadcastShape(const tvm::Array<tvm::Expr>& shape1,
int
i
;
for
(
i
=
1
;
i
<=
std
::
min
(
s1_size
,
s2_size
);
++
i
)
{
bh
.
all_vars
.
push_front
(
tvm
::
Var
());
if
(
t
vm
::
ir
::
Equal
(
shape1
[
s1_size
-
i
],
shape2
[
s2_size
-
i
]))
{
if
(
t
opi
::
detail
::
EqualCheck
(
shape1
[
s1_size
-
i
],
shape2
[
s2_size
-
i
]))
{
bh
.
common_shape
.
push_front
(
shape1
[
s1_size
-
i
]);
bh
.
vars1
.
push_front
(
bh
.
all_vars
[
0
]);
bh
.
vars2
.
push_front
(
bh
.
all_vars
[
0
]);
}
else
if
(
t
vm
::
ir
::
Equal
(
one
,
shape1
[
s1_size
-
i
]))
{
CHECK
(
!
t
vm
::
ir
::
Equal
(
one
,
shape2
[
s2_size
-
i
]));
}
else
if
(
t
opi
::
detail
::
EqualCheck
(
one
,
shape1
[
s1_size
-
i
]))
{
CHECK
(
!
t
opi
::
detail
::
EqualCheck
(
one
,
shape2
[
s2_size
-
i
]));
bh
.
common_shape
.
push_front
(
shape2
[
s2_size
-
i
]);
bh
.
vars2
.
push_front
(
bh
.
all_vars
[
0
]);
}
else
if
(
t
vm
::
ir
::
Equal
(
one
,
shape2
[
s2_size
-
i
]))
{
}
else
if
(
t
opi
::
detail
::
EqualCheck
(
one
,
shape2
[
s2_size
-
i
]))
{
bh
.
common_shape
.
push_front
(
shape1
[
s1_size
-
i
]);
bh
.
vars1
.
push_front
(
bh
.
all_vars
[
0
]);
}
else
{
...
...
topi/include/topi/detail/constant_utils.h
View file @
96b2c082
...
...
@@ -65,6 +65,24 @@ inline std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string&
return
result
;
}
/*!
* \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again
* \note This is stronger equality check than tvm::ir::Equal
*
* \param lhs First expreesion
* \param rhs Second expreesion
*
* \return result True if both expressions are equal, else false
*/
inline
bool
EqualCheck
(
Expr
lhs
,
Expr
rhs
)
{
bool
result
=
tvm
::
ir
::
Equal
(
lhs
,
rhs
);
if
(
!
result
)
{
Expr
zero
(
0
);
result
=
tvm
::
ir
::
Equal
(
tvm
::
ir
::
CanonicalSimplify
(
lhs
-
rhs
),
zero
);
}
return
result
;
}
}
// namespace detail
}
// namespace topi
#endif // TOPI_DETAIL_CONSTANT_UTILS_H_
topi/include/topi/nn.h
View file @
96b2c082
...
...
@@ -186,13 +186,13 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
indices
.
push_back
(
ovars
[
i
]);
continue
;
}
if
(
!
t
vm
::
ir
::
Equal
(
pad_before
[
i
],
0
))
{
if
(
!
t
opi
::
detail
::
EqualCheck
(
pad_before
[
i
],
0
))
{
sel
.
push_back
(
ovars
[
i
]
>=
pad_before
[
i
]);
indices
.
push_back
(
ovars
[
i
]
-
pad_before
[
i
]);
}
else
{
indices
.
push_back
(
ovars
[
i
]);
}
if
(
!
t
vm
::
ir
::
Equal
(
pad_after
[
i
],
0
))
{
if
(
!
t
opi
::
detail
::
EqualCheck
(
pad_after
[
i
],
0
))
{
sel
.
push_back
(
tvm
::
ir
::
Simplify
(
ovars
[
i
]
<
pad_before
[
i
]
+
t
->
shape
[
i
]));
}
}
...
...
topi/include/topi/transform.h
View file @
96b2c082
...
...
@@ -15,6 +15,7 @@
#include "topi/detail/ravel_unravel.h"
#include "topi/detail/constant_utils.h"
#include "tvm/tvm.h"
#include "tvm/ir_pass.h"
namespace
topi
{
using
namespace
tvm
;
...
...
@@ -260,6 +261,7 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
for
(
size_t
i
=
1
;
i
<
axis_sizes
.
size
();
++
i
)
{
join_size
+=
axis_sizes
[
i
];
}
join_size
=
tvm
::
ir
::
Simplify
(
join_size
);
Array
<
Expr
>
out_shape
;
for
(
size_t
i
=
0
;
i
<
inputs
[
0
]
->
shape
.
size
();
++
i
)
{
out_shape
.
push_back
(
i
==
static_cast
<
size_t
>
(
axis
)
?
join_size
:
inputs
[
0
]
->
shape
[
i
]);
...
...
topi/python/topi/transform.py
View file @
96b2c082
...
...
@@ -226,6 +226,7 @@ def concatenate(a_tuple, axis=0):
axis_sizes
=
[
a_tuple
[
i
]
.
shape
[
axis
]
for
i
in
range
(
len
(
a_tuple
))]
out_shape
=
[
a_tuple
[
0
]
.
shape
[
i
]
for
i
in
range
(
0
,
axis
)]
+
[
sum
(
axis_sizes
)]
\
+
[
a_tuple
[
0
]
.
shape
[
i
]
for
i
in
range
(
axis
+
1
,
len
(
a_tuple
[
0
]
.
shape
))]
out_shape
[
axis
]
=
tvm
.
ir_pass
.
Simplify
(
out_shape
[
axis
])
def
_compute
(
*
indices
):
ret
=
a_tuple
[
0
](
*
indices
)
...
...
topi/tests/python_cpp/test_topi_transform.py
View file @
96b2c082
...
...
@@ -206,6 +206,70 @@ def verify_take(src_shape, indices_src, axis=None):
for
device
in
[
"llvm"
,
"opencl"
]:
check_device
(
device
)
def
verify_concatenate_split
(
shapes
,
axis
,
indices_or_sections
):
tensor_l_concatenate
=
[]
for
i
,
shape
in
enumerate
(
shapes
):
tensor_l_concatenate
.
append
(
tvm
.
placeholder
(
shape
,
name
=
"A"
+
str
(
i
)))
out_tensor
=
topi
.
cpp
.
concatenate
(
tensor_l_concatenate
,
axis
)
tensor_l
=
topi
.
cpp
.
split
(
out_tensor
,
indices_or_sections
,
axis
)
tensor_l
=
list
(
tensor_l
)
def
check_device
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
target
=
topi
.
cpp
.
TEST_create_target
(
device
)
if
device
==
"llvm"
:
s
=
topi
.
cpp
.
generic
.
schedule_injective
(
target
,
tensor_l
)
else
:
s
=
topi
.
cpp
.
cuda
.
schedule_injective
(
target
,
tensor_l
)
ctx
=
tvm
.
context
(
device
,
0
)
foo
=
tvm
.
build
(
s
,
tensor_l_concatenate
+
tensor_l
,
device
,
name
=
"concatenate_split"
)
data_npys
=
[
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
tensor_l_concatenate
[
0
]
.
dtype
)
for
shape
in
shapes
]
out_npy_conc
=
np
.
concatenate
(
data_npys
,
axis
=
axis
)
out_npys_split
=
np
.
split
(
out_npy_conc
,
indices_or_sections
,
axis
=
axis
)
data_nds
=
[
tvm
.
nd
.
array
(
data_npy
,
ctx
)
for
data_npy
in
data_npys
]
out_nds
=
[
tvm
.
nd
.
empty
(
out_npy
.
shape
,
ctx
=
ctx
,
dtype
=
tensor_l
[
0
]
.
dtype
)
for
out_npy
in
out_npys_split
]
foo
(
*
(
data_nds
+
out_nds
))
for
out_nd
,
out_npy
in
zip
(
out_nds
,
out_npys_split
):
np
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npy
)
for
device
in
[
"llvm"
,
"cuda"
,
"opencl"
,
"metal"
,
"rocm"
]:
check_device
(
device
)
def
verify_concatenate_broadcast
(
shapes
,
axis
,
rhs_shape
):
B
=
tvm
.
placeholder
(
shape
=
rhs_shape
,
name
=
"B"
)
tensor_l
=
[]
for
i
,
shape
in
enumerate
(
shapes
):
tensor_l
.
append
(
tvm
.
placeholder
(
shape
,
name
=
"A"
+
str
(
i
)))
out_tensor
=
topi
.
cpp
.
concatenate
(
tensor_l
,
axis
)
C
=
topi
.
cpp
.
broadcast_add
(
out_tensor
,
B
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
target
=
topi
.
cpp
.
TEST_create_target
(
device
)
if
device
==
"llvm"
:
s
=
topi
.
cpp
.
generic
.
schedule_injective
(
target
,
[
C
])
else
:
s
=
topi
.
cpp
.
cuda
.
schedule_injective
(
target
,
[
C
])
ctx
=
tvm
.
context
(
device
,
0
)
foo
=
tvm
.
build
(
s
,
tensor_l
+
[
B
,
C
],
device
,
name
=
"broadcast_binary_add"
)
data_npys
=
[
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
tensor_l
[
0
]
.
dtype
)
for
shape
in
shapes
]
lhs_npy
=
np
.
concatenate
(
data_npys
,
axis
=
axis
)
rhs_npy
=
np
.
random
.
uniform
(
size
=
rhs_shape
)
.
astype
(
B
.
dtype
)
out_npy
=
lhs_npy
+
rhs_npy
data_nds
=
[
tvm
.
nd
.
array
(
data_npy
,
ctx
)
for
data_npy
in
data_npys
]
rhs_nd
=
tvm
.
nd
.
array
(
rhs_npy
,
ctx
)
out_nd
=
tvm
.
nd
.
array
(
np
.
empty
(
out_npy
.
shape
)
.
astype
(
B
.
dtype
),
ctx
)
for
_
in
range
(
1
):
foo
(
*
(
data_nds
+
[
rhs_nd
]
+
[
out_nd
]))
np
.
testing
.
assert_allclose
(
out_nd
.
asnumpy
(),
out_npy
,
rtol
=
1E-4
,
atol
=
1E-4
)
for
device
in
[
"llvm"
,
"cuda"
,
"opencl"
,
"metal"
,
"rocm"
]:
check_device
(
device
)
def
test_expand_dims
():
verify_expand_dims
((
3
,
10
),
(
3
,
10
,
1
,
1
),
2
,
2
)
...
...
@@ -258,6 +322,14 @@ def test_take():
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
def
test_regression_1
():
verify_concatenate_split
([(
2
,
3
,
4
),
(
2
,
2
,
4
),
(
2
,
5
,
4
)],
1
,
[
3
,
7
])
verify_concatenate_split
([(
3
,
4
),
(
2
,
4
),
(
3
,
4
)],
0
,
[
1
,
2
,
3
,
4
])
def
test_regression_2
():
verify_concatenate_broadcast
([(
5
,
1
,
3
),
(
5
,
1
,
3
)],
1
,
[
2
,
1
])
verify_concatenate_broadcast
([(
5
,
1
,
2
),
(
5
,
1
,
3
)],
2
,
[
1
,
5
])
if
__name__
==
"__main__"
:
test_concatenate
()
test_tranpose
()
...
...
@@ -266,3 +338,5 @@ if __name__ == "__main__":
test_squeeze
()
test_split
()
test_take
()
test_regression_1
()
test_regression_2
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment