Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
493fc040
Commit
493fc040
authored
Oct 11, 2018
by
Zhi
Committed by
Tianqi Chen
Oct 11, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add relay.where (#1869)
parent
65016b65
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
135 additions
and
0 deletions
+135
-0
docs/langref/relay_op.rst
+2
-0
python/tvm/relay/op/transform.py
+39
-0
src/relay/op/tensor/transform.cc
+80
-0
tests/python/relay/test_op_level4.py
+14
-0
No files found.
docs/langref/relay_op.rst
View file @
493fc040
...
...
@@ -98,6 +98,7 @@ This level enables additional math and transform operators.
tvm.relay.maximum
tvm.relay.minimum
tvm.relay.pow
tvm.relay.where
**Level 5: Vision/Image Operators**
...
...
@@ -173,6 +174,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.maximum
.. autofunction:: tvm.relay.minimum
.. autofunction:: tvm.relay.pow
.. autofunction:: tvm.relay.where
Level 5 Definitions
-------------------
...
...
python/tvm/relay/op/transform.py
View file @
493fc040
...
...
@@ -180,3 +180,42 @@ def full_like(data, fill_value):
The resulting tensor.
"""
return
_make
.
full_like
(
data
,
fill_value
)
def
where
(
condition
,
x
,
y
):
"""Selecting elements from either x or y depending on the value of the
condition.
Parameters
----------
condition : relay.Expr
The condition array. The n-th element in `y` is selected when the n-th
value in the `condition` array is zero. Otherwise, the corresponding
element from `x` will be picked.
x : relay.Expr
The first array to be selected.
y : relay.Expr
The second array to be selected.
Returns
-------
result : relay.Expr
The selected array.
Examples
--------
.. code-block:: python
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
condition = [[0, 1], [-1, 0]]
relay.where(conditon, x, y) = [[5, 2], [3, 8]]
condition = [1, 0]
relay.where(conditon, x, y) = [[1, 2], [7, 8]]
Note that the shape of condition, x, and y needs to be the same.
"""
return
_make
.
where
(
condition
,
x
,
y
)
src/relay/op/tensor/transform.cc
View file @
493fc040
...
...
@@ -498,5 +498,85 @@ and type as the input array.
.
set_support_level
(
3
)
.
add_type_rel
(
"FullLike"
,
FullLikeRel
);
// where operator
bool
WhereRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
4U
);
const
auto
*
condition
=
types
[
0
].
as
<
TensorTypeNode
>
();
const
auto
*
x
=
types
[
1
].
as
<
TensorTypeNode
>
();
const
auto
*
y
=
types
[
2
].
as
<
TensorTypeNode
>
();
CHECK
(
condition
!=
nullptr
&&
x
!=
nullptr
&&
y
!=
nullptr
);
const
auto
&
cond_shape
=
condition
->
shape
;
const
auto
&
x_shape
=
x
->
shape
;
const
auto
&
y_shape
=
y
->
shape
;
CHECK
(
x_shape
.
size
()
==
y_shape
.
size
())
<<
"x and y must have the same size"
;
if
(
cond_shape
.
size
()
!=
x_shape
.
size
())
{
CHECK_EQ
(
cond_shape
.
size
(),
1
)
<<
"Shape of condition "
<<
condition
->
shape
<<
" must be either equal to x or has dimension of 1."
;
}
for
(
size_t
i
=
0
;
i
<
x_shape
.
size
();
i
++
)
{
CHECK
(
reporter
->
AssertEQ
(
x_shape
[
i
],
y_shape
[
i
]))
<<
"x and y must have the same shape: "
<<
x_shape
<<
" vs "
<<
y_shape
;
CHECK
(
reporter
->
AssertEQ
(
cond_shape
[
i
],
x_shape
[
i
]))
<<
"Shape of condition "
<<
condition
->
shape
<<
" must be either equal to x or has dimension of 1."
;
}
reporter
->
Assign
(
types
[
3
],
TensorTypeNode
::
make
(
x_shape
,
x
->
dtype
));
return
true
;
}
// Positional relay function to create where operator.
Expr
MakeWhere
(
const
Expr
&
condition
,
const
Expr
&
x
,
const
Expr
&
y
)
{
static
const
Op
&
op
=
Op
::
Get
(
"where"
);
return
CallNode
::
make
(
op
,
{
condition
,
x
,
y
});
}
TVM_REGISTER_API
(
"relay.op._make.where"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
3
>
(
MakeWhere
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"where"
)
.
describe
(
R"code(
Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.
If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.
Note that all non-zero values are interpreted as True in condition.
Examples::
x = [[1, 2], [3, 4]]
y = [[5, 6], [7, 8]]
cond = [[0, 1], [-1, 0]]
where(cond, x, y) = [[5, 2], [3, 8]]
cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]
)code"
TVM_ADD_FILELINE
)
.
add_argument
(
"condition"
,
"Tensor"
,
"Condition array"
)
.
add_argument
(
"x"
,
"Tensor"
,
"First array to be selected"
)
.
add_argument
(
"y"
,
"Tensor"
,
"Second array to be selected"
)
.
set_num_inputs
(
3
)
.
set_support_level
(
4
)
.
add_type_rel
(
"Where"
,
WhereRel
);
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_op_level4.py
View file @
493fc040
...
...
@@ -125,8 +125,22 @@ def test_binary_broadcast():
assert
ftype
.
ret_type
==
relay
.
TensorType
((
5
,
10
,
4
),
"int32"
)
def
test_where
():
ib
=
relay
.
ir_builder
.
IRBuilder
()
cond
=
ib
.
param
(
"cond"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
x
=
ib
.
param
(
"x"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
y
=
ib
.
param
(
"y"
,
relay
.
TensorType
((
3
,
4
),
"float32"
))
with
ib
.
function
(
cond
,
x
,
y
)
as
func
:
ib
.
ret
(
relay
.
where
(
cond
.
var
,
x
.
var
,
y
.
var
))
ib
.
ret
(
func
)
func
=
relay
.
ir_pass
.
infer_type
(
ib
.
env
,
func
.
to_func
())
ftype
=
func
.
checked_type
assert
ftype
.
ret_type
==
relay
.
TensorType
((
3
,
4
),
"float32"
)
if
__name__
==
"__main__"
:
test_cmp_type
()
test_binary_broadcast
()
test_binary_op
()
test_binary_broadcast_op
()
test_where
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment