Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e286e637
Commit
e286e637
authored
Oct 30, 2018
by
Siju
Committed by
Tianqi Chen
Oct 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY]prelu op support (#2016)
parent
2fb1cc6e
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
130 additions
and
6 deletions
+130
-6
docs/langref/relay_op.rst
+2
-0
include/tvm/relay/attrs/nn.h
+11
-0
include/tvm/relay/type.h
+1
-0
python/tvm/relay/op/nn/nn.py
+27
-0
src/relay/op/nn/nn.cc
+56
-0
tests/python/relay/test_op_level3.py
+33
-6
No files found.
docs/langref/relay_op.rst
View file @
e286e637
...
...
@@ -74,6 +74,7 @@ This level enables additional math and transform operators.
tvm.relay.zeros
tvm.relay.nn.leaky_relu
tvm.relay.nn.prelu
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
...
...
@@ -183,6 +184,7 @@ Level 2 Definitions
Level 3 Definitions
-------------------
.. autofunction:: tvm.relay.nn.leaky_relu
.. autofunction:: tvm.relay.nn.prelu
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.trunc
...
...
include/tvm/relay/attrs/nn.h
View file @
e286e637
...
...
@@ -278,6 +278,17 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
};
/*! \brief Attributes for prelu operator */
struct
PReluAttrs
:
public
tvm
::
AttrsNode
<
PReluAttrs
>
{
int
axis
;
TVM_DECLARE_ATTRS
(
PReluAttrs
,
"relay.attrs.PReluAttrs"
)
{
TVM_ATTR_FIELD
(
axis
).
set_default
(
1
)
.
describe
(
"Specify which shape axis the channel is specified."
);
}
};
/*! \brief Attributes used in dropout operator */
struct
DropoutAttrs
:
public
tvm
::
AttrsNode
<
DropoutAttrs
>
{
double
rate
;
...
...
include/tvm/relay/type.h
View file @
e286e637
...
...
@@ -280,6 +280,7 @@ class TypeReporterNode : public Node {
TVM_DLL
virtual
void
Assign
(
const
Type
&
dst
,
const
Type
&
src
)
=
0
;
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
* \param cond The condition of operation.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
...
...
python/tvm/relay/op/nn/nn.py
View file @
e286e637
...
...
@@ -528,6 +528,33 @@ def leaky_relu(data, alpha):
return
_make
.
leaky_relu
(
data
,
alpha
)
def
prelu
(
data
,
alpha
,
axis
=
1
):
"""This operator takes data as input and does Leaky version
of a Rectified Linear Unit.
.. math::
`y = x > 0 ? x : alpha * x`
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
alpha : tvm.relay.Expr
Slope coefficient for the negative half axis.
axis : int, optional
Specify which shape axis the channel is specified.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return
_make
.
prelu
(
data
,
alpha
,
axis
)
def
pad
(
data
,
pad_width
,
pad_value
=
0.0
):
...
...
src/relay/op/nn/nn.cc
View file @
e286e637
...
...
@@ -171,6 +171,62 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.
add_type_rel
(
"Identity"
,
IdentityRel
);
TVM_REGISTER_NODE_TYPE
(
PReluAttrs
);
bool
PReluRel
(
const
Array
<
Type
>&
types
,
int
num_inputs
,
const
Attrs
&
attrs
,
const
TypeReporter
&
reporter
)
{
CHECK_EQ
(
types
.
size
(),
3
);
const
auto
*
data
=
types
[
0
].
as
<
TensorTypeNode
>
();
if
(
data
==
nullptr
)
return
false
;
const
PReluAttrs
*
param
=
attrs
.
as
<
PReluAttrs
>
();
CHECK
(
param
!=
nullptr
);
CHECK
(
param
->
axis
<
static_cast
<
int
>
(
data
->
shape
.
size
()))
<<
"Wrong axis ("
<<
param
->
axis
<<
")value."
;
// assign alpha type
Array
<
IndexExpr
>
alpha_shape
({
data
->
shape
[
param
->
axis
]});
reporter
->
Assign
(
types
[
1
],
TensorTypeNode
::
make
(
alpha_shape
,
data
->
dtype
));
// assign output type
reporter
->
Assign
(
types
[
2
],
TensorTypeNode
::
make
(
data
->
shape
,
data
->
dtype
));
return
true
;
}
// Positional relay function to create prelu operator used by frontend FFI.
Expr
MakePRelu
(
Expr
data
,
Expr
alpha
,
int
axis
)
{
auto
attrs
=
make_node
<
PReluAttrs
>
();
attrs
->
axis
=
axis
;
static
const
Op
&
op
=
Op
::
Get
(
"nn.prelu"
);
return
CallNode
::
make
(
op
,
{
data
,
alpha
},
Attrs
(
attrs
),
{});
}
TVM_REGISTER_API
(
"relay.op.nn._make.prelu"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
runtime
::
detail
::
unpack_call
<
Expr
,
3
>
(
MakePRelu
,
args
,
rv
);
});
RELAY_REGISTER_OP
(
"nn.prelu"
)
.
describe
(
R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`,
where :math:`*` is an channelwise multiplication for each sample in the batch.
)code"
TVM_ADD_FILELINE
)
.
set_attrs_type_key
(
"relay.attrs.PReluAttrs"
)
.
set_num_inputs
(
2
)
.
add_argument
(
"data"
,
"Tensor"
,
"Input data."
)
.
add_argument
(
"alpha"
,
"Tensor"
,
"Input channelwise alpha."
)
.
set_support_level
(
3
)
.
add_type_rel
(
"PRelu"
,
PReluRel
);
TVM_REGISTER_API
(
"relay.op.nn._make.softmax"
)
.
set_body
([](
const
TVMArgs
&
args
,
TVMRetValue
*
rv
)
{
auto
make_func
=
[](
Expr
data
,
int
axis
)
{
...
...
tests/python/relay/test_op_level3.py
View file @
e286e637
...
...
@@ -188,13 +188,39 @@ def test_full_like():
assert
yy
.
checked_type
==
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
)
def
test_infer_type_leaky_relu
():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
y
=
relay
.
nn
.
leaky_relu
(
x
,
alpha
=
0.1
)
"alpha=0.1"
in
y
.
astext
()
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
)
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
x
=
relay
.
var
(
"x"
,
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
))
y
=
relay
.
nn
.
leaky_relu
(
x
,
alpha
=
0.1
)
"alpha=0.1"
in
y
.
astext
()
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
TensorType
((
n
,
c
,
h
,
w
),
"float32"
)
def
verify_infer_type_prelu
(
data
,
alpha
,
axis
,
output
,
dtype
=
"float32"
):
x
=
relay
.
var
(
"data"
,
relay
.
TensorType
(
data
,
dtype
))
if
alpha
:
y
=
relay
.
var
(
"alpha"
,
relay
.
TensorType
(
alpha
,
dtype
))
else
:
y
=
relay
.
var
(
"alpha"
,
relay
.
IncompleteType
())
z
=
relay
.
nn
.
prelu
(
x
,
y
,
axis
=
axis
)
zz
=
relay
.
ir_pass
.
infer_type
(
z
)
if
axis
!=
1
:
assert
"axis"
in
z
.
astext
()
assert
zz
.
checked_type
==
relay
.
ty
.
TensorType
(
output
,
dtype
)
if
not
alpha
:
axis
=
axis
if
axis
else
1
alpha_shape
=
(
data
[
axis
],)
assert
zz
.
args
[
1
]
.
checked_type
==
relay
.
TensorType
(
alpha_shape
,
"float32"
)
def
test_infer_type_prelu
():
n
,
c
,
h
,
w
=
tvm
.
var
(
"n"
),
tvm
.
var
(
"c"
),
tvm
.
var
(
"h"
),
tvm
.
var
(
"w"
)
verify_infer_type_prelu
((
n
,
c
,
h
,
w
),
(
c
,),
1
,
(
n
,
c
,
h
,
w
))
verify_infer_type_prelu
((
n
,
h
,
w
,
c
),
(
c
,),
3
,
(
n
,
h
,
w
,
c
))
verify_infer_type_prelu
((
n
,
c
,
h
,
w
),
None
,
1
,
(
n
,
c
,
h
,
w
))
verify_infer_type_prelu
((
n
,
h
,
w
,
c
),
None
,
3
,
(
n
,
h
,
w
,
c
))
verify_infer_type_prelu
((
1
,
3
,
2
,
2
),
(
3
,),
1
,
(
1
,
3
,
2
,
2
))
verify_infer_type_prelu
((
1
,
2
,
2
,
3
),
(
3
,),
3
,
(
1
,
2
,
2
,
3
))
verify_infer_type_prelu
((
1
,
3
,
2
,
2
),
None
,
1
,
(
1
,
3
,
2
,
2
))
verify_infer_type_prelu
((
1
,
2
,
2
,
3
),
None
,
3
,
(
1
,
2
,
2
,
3
))
if
__name__
==
"__main__"
:
test_cast
()
...
...
@@ -208,6 +234,7 @@ if __name__ == "__main__":
test_full
()
test_full_like
()
test_infer_type_leaky_relu
()
test_infer_type_prelu
()
test_squeeze_infer_type
()
test_squeeze_bad_axes_infer_type
()
test_split_infer_type
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment