Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c2b36154
Commit
c2b36154
authored
Oct 22, 2018
by
雾雨魔理沙
Committed by
Tianqi Chen
Oct 22, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Op]BroadcastToLike CollapseSumLike (#1886)
parent
c51268c3
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
141 additions
and
0 deletions
+141
-0
docs/langref/relay_op.rst
+18
-0
python/tvm/relay/op/transform.py
+38
-0
src/relay/op/tensor/transform.cc
+61
-0
tests/python/relay/test_op_level10.py
+23
-0
tests/python/relay/test_pass_alpha_equal.py
+1
-0
No files found.
docs/langref/relay_op.rst
View file @
c2b36154
...
...
@@ -123,6 +123,17 @@ This level enables additional math and transform operators.
tvm.relay.image.resize
**Level 10: Temporary Operators**
This level support backpropagation of broadcast operators. It is temporary.
.. autosummary::
:nosignatures:
tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like
Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
...
...
@@ -199,6 +210,13 @@ Level 4 Definitions
.. autofunction:: tvm.relay.prod
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like
python/tvm/relay/op/transform.py
View file @
c2b36154
...
...
@@ -242,3 +242,41 @@ def where(condition, x, y):
Note that the shape of condition, x, and y needs to be the same.
"""
return
_make
.
where
(
condition
,
x
,
y
)
def
broadcast_to_like
(
data
,
broadcast_type
):
"""Return an scalar value array with the same shape and type as the input array.
Parameters
----------
data : relay.Expr
The input tensor.
broadcast_type : relay.Expr
Provide the type to broadcast to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return
_make
.
broadcast_to_like
(
data
,
broadcast_type
)
def
collapse_sum_like
(
data
,
collapse_type
):
"""Return an scalar value array with the same shape and type as the input array.
Parameters
----------
data : relay.Expr
The input tensor.
collapse_type : relay.Expr
Provide the type to collapse to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return
_make
.
collapse_sum_like
(
data
,
collapse_type
)
src/relay/op/tensor/transform.cc
View file @
c2b36154
...
...
@@ -718,5 +718,66 @@ RELAY_REGISTER_OP("squeeze")
.
set_support_level
(
3
)
.
add_type_rel
(
"Squeeze"
,
SqueezeRel
);
// Have no idea how to assert the constraint.
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool
CollapseSumLikeRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
3
);
reporter
->
Assign
(
types
[
2
],
types
[
1
]);
return
true
;
}
Expr
MakeCollapseSumLike
(
Expr
data
,
Expr
collapse_type
)
{
static
const
Op
&
op
=
Op
::
Get
(
"collapse_sum_like"
);
return
CallNode
::
make
(
op
,
{
data
,
collapse_type
},
Attrs
(),
{});
}
TVM_REGISTER_API
(
"relay.op._make.collapse_sum_like"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
2
>
(
MakeCollapseSumLike
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"collapse_sum_like"
)
.
describe
(
R"code(Collapse the first input to match the shape of the second input.
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"collapse_type"
,
"Tensor"
,
"Provide the type to collapse to."
)
.
set_support_level
(
10
)
.
add_type_rel
(
"CollapseSumLike"
,
CollapseSumLikeRel
);
// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool
BroadCastToLikeRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
3
);
reporter
->
Assign
(
types
[
2
],
types
[
1
]);
return
true
;
}
Expr
MakeBroadCastToLike
(
Expr
data
,
Expr
broadcast_type
)
{
static
const
Op
&
op
=
Op
::
Get
(
"broadcast_to_like"
);
return
CallNode
::
make
(
op
,
{
data
,
broadcast_type
},
Attrs
(),
{});
}
TVM_REGISTER_API
(
"relay.op._make.broadcast_to_like"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
2
>
(
MakeBroadCastToLike
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"broadcast_to_like"
)
.
describe
(
R"code(Broadcast the first input to match the shape of the second input.
)code"
TVM_ADD_FILELINE
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"The input tensor."
)
.
add_argument
(
"broadcast_type"
,
"Tensor"
,
"Provide the type to broadcast to."
)
.
set_support_level
(
10
)
.
add_type_rel
(
"BroadCastToLike"
,
BroadCastToLikeRel
);
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_op_level10.py
0 → 100644
View file @
c2b36154
""" Support level10 operator test cases.
"""
import
tvm
from
tvm
import
relay
def
test_collapse_sum_like
():
x
=
relay
.
Var
(
"x"
,
relay
.
ty
.
TensorType
((
3
,
4
,
5
,
6
),
"int8"
))
y
=
relay
.
Var
(
"y"
,
relay
.
ty
.
TensorType
((
4
,
1
,
6
),
"int8"
))
z
=
relay
.
collapse_sum_like
(
x
,
y
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
((
4
,
1
,
6
),
"int8"
)
def
test_broadcast_to_like
():
x
=
relay
.
Var
(
"x"
,
relay
.
ty
.
TensorType
((
3
,
4
,
5
,
6
),
"int8"
))
y
=
relay
.
Var
(
"y"
,
relay
.
ty
.
TensorType
((
4
,
1
,
6
),
"int8"
))
z
=
relay
.
broadcast_to_like
(
y
,
x
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
((
3
,
4
,
5
,
6
),
"int8"
)
if
__name__
==
"__main__"
:
test_collapse_sum_like
()
test_broadcast_to_like
()
tests/python/relay/test_pass_alpha_equal.py
View file @
c2b36154
...
...
@@ -461,3 +461,4 @@ if __name__ == "__main__":
test_let_alpha_equal
()
test_if_alpha_equal
()
test_op_alpha_equal
()
test_var_alpha_equal
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment