Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ee3c1b09
Commit
ee3c1b09
authored
Jul 12, 2018
by
Yao Wang
Committed by
Tianqi Chen
Jul 12, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TOPI]Add where operator (#1416)
parent
6ea74d41
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
223 additions
and
0 deletions
+223
-0
nnvm/python/nnvm/top/transform.py
+4
-0
nnvm/src/top/tensor/transform.cc
+93
-0
nnvm/tests/python/compiler/test_top_level4.py
+31
-0
topi/include/topi/transform.h
+48
-0
topi/src/topi.cc
+5
-0
topi/tests/python_cpp/test_topi_transform.py
+42
-0
No files found.
nnvm/python/nnvm/top/transform.py
View file @
ee3c1b09
...
@@ -72,3 +72,7 @@ reg.register_schedule("strided_slice", _fschedule_injective)
...
@@ -72,3 +72,7 @@ reg.register_schedule("strided_slice", _fschedule_injective)
# slice_like
# slice_like
reg
.
register_pattern
(
"slice_like"
,
OpPattern
.
INJECTIVE
)
reg
.
register_pattern
(
"slice_like"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"slice_like"
,
_fschedule_injective
)
reg
.
register_schedule
(
"slice_like"
,
_fschedule_injective
)
# where
reg
.
register_pattern
(
"where"
,
OpPattern
.
INJECTIVE
)
reg
.
register_schedule
(
"where"
,
_fschedule_injective
)
nnvm/src/top/tensor/transform.cc
View file @
ee3c1b09
...
@@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like)
...
@@ -1221,5 +1221,98 @@ NNVM_REGISTER_OP(slice_like)
})
})
.
set_support_level
(
4
);
.
set_support_level
(
4
);
// where
inline
bool
WhereShape
(
const
nnvm
::
NodeAttrs
&
attrs
,
std
::
vector
<
TShape
>*
in_attrs
,
std
::
vector
<
TShape
>*
out_attrs
)
{
CHECK_EQ
(
in_attrs
->
size
(),
3U
);
CHECK_EQ
(
out_attrs
->
size
(),
1U
);
const
TShape
&
cond_shape
=
in_attrs
->
at
(
0
);
const
TShape
&
x_shape
=
in_attrs
->
at
(
1
);
const
TShape
&
y_shape
=
in_attrs
->
at
(
2
);
CHECK_EQ
(
x_shape
,
y_shape
)
<<
"x and y must have the same shape: "
<<
x_shape
<<
" vs "
<<
y_shape
;
if
(
cond_shape
!=
x_shape
)
{
CHECK_EQ
(
cond_shape
.
ndim
(),
1
)
<<
"Shape of condition "
<<
cond_shape
<<
" must be either equal to x or has dimension of 1."
;
}
NNVM_ASSIGN_OUTPUT_SHAPE
(
attrs
,
*
out_attrs
,
0
,
x_shape
);
return
true
;
}
inline
bool
WhereInferType
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
int
>
*
in_attrs
,
std
::
vector
<
int
>
*
out_attrs
)
{
DTYPE_ASSIGN
(
out_attrs
->
at
(
0
),
in_attrs
->
at
(
1
));
return
true
;
}
inline
bool
WhereCorrectLayout
(
const
NodeAttrs
&
attrs
,
std
::
vector
<
Layout
>
*
ilayouts
,
const
std
::
vector
<
Layout
>
*
last_ilayouts
,
std
::
vector
<
Layout
>
*
olayouts
)
{
CHECK_EQ
(
ilayouts
->
size
(),
last_ilayouts
->
size
());
CHECK_EQ
(
olayouts
->
size
(),
1U
);
for
(
size_t
i
=
0
;
i
<
ilayouts
->
size
();
++
i
)
{
const
Layout
&
input
=
last_ilayouts
->
at
(
i
).
defined
()
?
last_ilayouts
->
at
(
i
)
:
ilayouts
->
at
(
i
);
NNVM_ASSIGN_LAYOUT
(
*
ilayouts
,
i
,
input
);
}
return
true
;
}
NNVM_REGISTER_OP
(
where
)
.
describe
(
R"code(
Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.
If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.
Note that all non-zero values are interpreted as True in condition.
Examples::
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]
where(cond, x, y) = [[5, 2], [3, 8]]
cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]
)code"
NNVM_ADD_FILELINE
)
.
add_argument
(
"condition"
,
"Tensor"
,
"Condition array"
)
.
add_argument
(
"x"
,
"Tensor"
,
"First array to be selected"
)
.
add_argument
(
"y"
,
"Tensor"
,
"Second array to be selected"
)
.
set_num_inputs
(
3
)
.
set_num_outputs
(
1
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
WhereShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
WhereInferType
)
.
set_attr
<
FCorrectLayout
>
(
"FCorrectLayout"
,
WhereCorrectLayout
)
.
set_attr
<
FTVMCompute
>
(
"FTVMCompute"
,
[](
const
NodeAttrs
&
attrs
,
const
Array
<
Tensor
>&
inputs
,
const
Array
<
Tensor
>&
out_info
)
{
return
Array
<
Tensor
>
{
topi
::
where
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
])
};
})
.
set_attr
<
FListInputNames
>
(
"FListInputNames"
,
[](
const
NodeAttrs
&
attrs
)
{
return
std
::
vector
<
std
::
string
>
{
"condition"
,
"x"
,
"y"
};
})
.
set_support_level
(
4
);
}
// namespace top
}
// namespace top
}
// namespace nnvm
}
// namespace nnvm
nnvm/tests/python/compiler/test_top_level4.py
View file @
ee3c1b09
...
@@ -645,6 +645,36 @@ def test_slice_like():
...
@@ -645,6 +645,36 @@ def test_slice_like():
axis
=
(
2
,
3
)
axis
=
(
2
,
3
)
verify_slice_like
(
np_data
,
np_shape_like
,
axis
)
verify_slice_like
(
np_data
,
np_shape_like
,
axis
)
def
verify_where
(
condition
,
x
,
y
):
dtype
=
"float32"
if
len
(
condition
.
shape
)
==
1
:
np_out
=
np
.
array
([
xv
if
c
else
yv
for
(
c
,
xv
,
yv
)
in
zip
(
condition
,
x
,
y
)])
else
:
np_out
=
np
.
where
(
condition
,
x
,
y
)
cond_var
=
sym
.
Variable
(
"condition"
)
x_var
=
sym
.
Variable
(
"x"
)
y_var
=
sym
.
Variable
(
"y"
)
net
=
sym
.
where
(
cond_var
,
x_var
,
y_var
)
for
target
,
ctx
in
ctx_list
():
graph
,
lib
,
_
=
nnvm
.
compiler
.
build
(
net
,
target
,
{
"condition"
:
condition
.
shape
,
"x"
:
x
.
shape
,
"y"
:
y
.
shape
})
m
=
graph_runtime
.
create
(
graph
,
lib
,
ctx
)
m
.
set_input
(
**
{
"condition"
:
condition
,
"x"
:
x
,
"y"
:
y
})
m
.
run
()
out
=
m
.
get_output
(
0
,
tvm
.
nd
.
empty
(
x
.
shape
,
dtype
))
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
np_out
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_where
():
shape
=
(
13
,
8
,
224
,
224
,
6
)
condition
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
shape
)
.
astype
(
"float32"
)
x
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
y
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
verify_where
(
condition
,
x
,
y
)
condition
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
(
shape
[
0
],))
.
astype
(
"float32"
)
x
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
y
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
verify_where
(
condition
,
x
,
y
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_reshape
()
test_reshape
()
...
@@ -665,4 +695,5 @@ if __name__ == "__main__":
...
@@ -665,4 +695,5 @@ if __name__ == "__main__":
test_multibox_transform_loc
()
test_multibox_transform_loc
()
test_nms
()
test_nms
()
test_slice_like
()
test_slice_like
()
test_where
()
print
(
nnvm
.
compiler
.
engine
.
dump
())
print
(
nnvm
.
compiler
.
engine
.
dump
())
topi/include/topi/transform.h
View file @
ee3c1b09
...
@@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a,
...
@@ -575,5 +575,53 @@ inline Tensor take(const Tensor& a,
},
name
,
tag
);
},
name
,
tag
);
}
}
/*!
* \brief Return the elements, either from x or y, depending on the condition.
*
* \param condition The condition array.
* \param x First array to be selected.
* \param y Second array to be selected.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor selected from x or y depending on condition.
*/
inline
Tensor
where
(
const
Tensor
&
condition
,
const
Tensor
&
x
,
const
Tensor
&
y
,
std
::
string
name
=
"tensor"
,
std
::
string
tag
=
kInjective
)
{
CHECK_EQ
(
x
->
shape
.
size
(),
y
->
shape
.
size
())
<<
"x and y must have the same shape.Got different number of dimension: "
<<
x
->
shape
.
size
()
<<
" vs "
<<
y
->
shape
.
size
();
CHECK_EQ
(
x
->
dtype
,
y
->
dtype
)
<<
"x and y must have the same dtype: "
<<
x
->
dtype
<<
" vs "
<<
y
->
dtype
;
Array
<
Expr
>
oshape
=
x
->
shape
;
Tensor
out
;
if
(
condition
->
shape
.
size
()
!=
1
)
{
CHECK_EQ
(
condition
->
shape
.
size
(),
x
->
shape
.
size
())
<<
"condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<<
condition
->
shape
.
size
()
<<
" vs "
<<
x
->
shape
.
size
();
out
=
compute
(
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
return
tvm
::
select
(
condition
(
indices
)
!=
0
,
x
(
indices
),
y
(
indices
));
},
name
,
tag
);
}
else
{
CHECK_EQ
(
topi
::
GetConstInt
(
condition
->
shape
[
0
]),
topi
::
GetConstInt
(
x
->
shape
[
0
]))
<<
"If condition is 1-D, the first dimension must be the same as x: "
<<
condition
->
shape
[
0
]
<<
" vs "
<<
x
->
shape
[
0
];
out
=
compute
(
oshape
,
[
&
](
const
Array
<
Var
>&
indices
)
{
Array
<
Expr
>
condition_idx
{
indices
[
0
]};
return
tvm
::
select
(
condition
(
condition_idx
)
!=
0
,
x
(
indices
),
y
(
indices
));
},
name
,
tag
);
}
return
out
;
}
}
// namespace topi
}
// namespace topi
#endif // TOPI_TRANSFORM_H_
#endif // TOPI_TRANSFORM_H_
topi/src/topi.cc
View file @
ee3c1b09
...
@@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
...
@@ -280,6 +280,11 @@ TVM_REGISTER_GLOBAL("topi.take")
}
}
});
});
TVM_REGISTER_GLOBAL
(
"topi.where"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
where
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
TVM_REGISTER_GLOBAL
(
"topi.strided_slice"
)
TVM_REGISTER_GLOBAL
(
"topi.strided_slice"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
strided_slice
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
rv
=
strided_slice
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
...
...
topi/tests/python_cpp/test_topi_transform.py
View file @
ee3c1b09
...
@@ -206,6 +206,35 @@ def verify_take(src_shape, indices_src, axis=None):
...
@@ -206,6 +206,35 @@ def verify_take(src_shape, indices_src, axis=None):
for
device
in
[
"llvm"
,
"opencl"
]:
for
device
in
[
"llvm"
,
"opencl"
]:
check_device
(
device
)
check_device
(
device
)
def
verify_where
(
condition
,
x
,
y
):
dtype
=
"float32"
if
len
(
condition
.
shape
)
==
1
:
np_out
=
np
.
array
([
xv
if
c
else
yv
for
(
c
,
xv
,
yv
)
in
zip
(
condition
,
x
,
y
)])
else
:
np_out
=
np
.
where
(
condition
,
x
,
y
)
A
=
tvm
.
placeholder
(
shape
=
condition
.
shape
,
dtype
=
dtype
,
name
=
"condition"
)
B
=
tvm
.
placeholder
(
shape
=
x
.
shape
,
dtype
=
dtype
,
name
=
"x"
)
C
=
tvm
.
placeholder
(
shape
=
y
.
shape
,
dtype
=
dtype
,
name
=
"y"
)
out_tensor
=
topi
.
cpp
.
where
(
A
,
B
,
C
)
def
check_device
(
device
):
ctx
=
tvm
.
context
(
device
,
0
)
if
not
ctx
.
exist
:
print
(
"Skip because
%
s is not enabled"
%
device
)
return
print
(
"Running on target:
%
s"
%
device
)
with
tvm
.
target
.
create
(
device
):
s
=
topi
.
generic
.
schedule_injective
(
out_tensor
)
foo
=
tvm
.
build
(
s
,
[
A
,
B
,
C
,
out_tensor
],
device
,
name
=
"where"
)
tvm_out
=
tvm
.
nd
.
empty
(
x
.
shape
,
ctx
=
ctx
,
dtype
=
dtype
)
foo
(
tvm
.
nd
.
array
(
condition
,
ctx
),
tvm
.
nd
.
array
(
x
,
ctx
),
tvm
.
nd
.
array
(
y
,
ctx
),
tvm_out
)
np
.
testing
.
assert_allclose
(
tvm_out
.
asnumpy
(),
np_out
)
for
device
in
[
"llvm"
,
"nvptx"
,
"cuda"
,
"opencl"
,
"metal"
,
"rocm"
]:
check_device
(
device
)
def
verify_concatenate_split
(
shapes
,
axis
,
indices_or_sections
):
def
verify_concatenate_split
(
shapes
,
axis
,
indices_or_sections
):
tensor_l_concatenate
=
[]
tensor_l_concatenate
=
[]
for
i
,
shape
in
enumerate
(
shapes
):
for
i
,
shape
in
enumerate
(
shapes
):
...
@@ -324,6 +353,18 @@ def test_take():
...
@@ -324,6 +353,18 @@ def test_take():
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
2
,
2
),
[[[
1
,
0
],[
0
,
1
]]],
1
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
verify_take
((
4
,
3
,
5
,
6
),
[[
2
,
1
,
0
,
0
]],
-
2
)
def
test_where
():
shape
=
(
10
,
3
,
7
,
13
)
condition
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
shape
)
.
astype
(
"float32"
)
x
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
y
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
verify_where
(
condition
,
x
,
y
)
condition
=
np
.
random
.
uniform
(
low
=-
1
,
high
=
1
,
size
=
(
shape
[
0
],))
.
astype
(
"float32"
)
x
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
y
=
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
"float32"
)
verify_where
(
condition
,
x
,
y
)
def
test_regression_1
():
def
test_regression_1
():
verify_concatenate_split
([(
2
,
3
,
4
),
(
2
,
2
,
4
),
(
2
,
5
,
4
)],
1
,
[
3
,
7
])
verify_concatenate_split
([(
2
,
3
,
4
),
(
2
,
2
,
4
),
(
2
,
5
,
4
)],
1
,
[
3
,
7
])
verify_concatenate_split
([(
3
,
4
),
(
2
,
4
),
(
3
,
4
)],
0
,
[
1
,
2
,
3
,
4
])
verify_concatenate_split
([(
3
,
4
),
(
2
,
4
),
(
3
,
4
)],
0
,
[
1
,
2
,
3
,
4
])
...
@@ -340,5 +381,6 @@ if __name__ == "__main__":
...
@@ -340,5 +381,6 @@ if __name__ == "__main__":
test_squeeze
()
test_squeeze
()
test_split
()
test_split
()
test_take
()
test_take
()
test_where
()
test_regression_1
()
test_regression_1
()
test_regression_2
()
test_regression_2
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment