wenyuanbo / tic · Commits

Commit cb06a184
authored Dec 02, 2018 by Siva, committed by Tianqi Chen on Dec 01, 2018
[RELAY][OP] end to end support for pad op. (#2213)
parent 285e8d54
Showing 3 changed files with 48 additions and 1 deletion (+48, -1)
python/tvm/relay/op/nn/_nn.py         +3  -0
src/relay/op/nn/pad.cc                +29 -1
tests/python/relay/test_op_level2.py  +16 -0
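Taken together, the three changes wire the pad operator through Relay end to end: the Python registry file attaches a schedule to "nn.pad", the C++ file adds the compute (PadCompute) and the operator attributes, and the test file runs the op on real data. The following is a minimal usage sketch adapted from the test added below; it assumes the TVM Python API of this era (relay.create_executor with the "graph" backend) and an available "llvm" target.

import numpy as np
import tvm
from tvm import relay

# Build a small Relay function that pads a 4-D tensor.
dshape = (4, 10, 7, 7)
x = relay.var("x", shape=dshape)
y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4)))
func = relay.Function([x], y)

# Run it on CPU and check the padded output shape.
data = np.random.uniform(size=dshape).astype("float32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
out = intrp.evaluate(func)(data).asnumpy()
assert out.shape == (6, 14, 13, 15)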
python/tvm/relay/op/nn/_nn.py

@@ -251,3 +251,6 @@ def schedule_upsampling(_, outs, target):
         return topi.generic.schedule_injective(outs)

 reg.register_pattern("nn.upsampling", OpPattern.INJECTIVE)
+
+# pad
+reg.register_schedule("nn.pad", schedule_broadcast)
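The hunk above only registers a schedule (schedule_broadcast) for "nn.pad"; the compute itself is attached on the C++ side below. A quick way to confirm the operator is present in the registry from Python is sketched here, assuming relay.op.get behaves as in the TVM code base of this period.

from tvm import relay

op = relay.op.get("nn.pad")   # look up the operator node in the Relay op registry
assert op.name == "nn.pad"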
src/relay/op/nn/pad.cc

@@ -6,8 +6,10 @@
 #include <tvm/ir_operator.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
+#include <topi/nn.h>
+#include <vector>
 #include "../layout.h"
 #include "../op_common.h"

 namespace tvm {
 namespace relay {
@@ -60,6 +62,30 @@ bool PadRel(const Array<Type>& types,
   return true;
 }

+Array<Tensor> PadCompute(const Attrs& attrs,
+                         const Array<Tensor>& inputs,
+                         const Type& out_type,
+                         const Target& target) {
+  const auto* param = attrs.as<PadAttrs>();
+  CHECK(param != nullptr);
+
+  auto pad_width = param->pad_width;
+  CHECK(pad_width.size() == inputs[0].ndim() &&
+        pad_width[0].size() == 2)
+    << "Illegal pad_width";
+  Array<IndexExpr> pad_before;
+  for (size_t i = 0; i < pad_width.size(); ++i) {
+    pad_before.push_back(pad_width[i][0]);
+  }
+  Array<IndexExpr> pad_after;
+  for (size_t i = 0; i < pad_width.size(); ++i) {
+    pad_after.push_back(pad_width[i][1]);
+  }
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return Array<Tensor>{ topi::pad(inputs[0], pad_before, pad_after,
+                                  tvm::make_const(out_ttype->dtype, param->pad_value)) };
+}
+
 // Handler to create a call to the padding op used by front-end FFI
 Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
   auto attrs = make_node<PadAttrs>();
@@ -82,7 +108,9 @@ RELAY_REGISTER_OP("nn.pad")
 .set_num_inputs(1)
 .add_argument("data", "Tensor", "The input tensor.")
 .set_support_level(2)
-.add_type_rel("Pad", PadRel);
+.add_type_rel("Pad", PadRel)
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FTVMCompute>("FTVMCompute", PadCompute);

 }  // namespace relay
 }  // namespace tvm
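PadCompute splits the per-axis (before, after) pairs in pad_width into two separate arrays before handing them to topi::pad, and checks that every axis carries exactly one pair. The same split, written in plain Python for illustration only (these names are not TVM API), looks like this:

pad_width = ((1, 1), (2, 2), (3, 3), (4, 4))   # one (before, after) pair per axis
assert all(len(p) == 2 for p in pad_width)     # mirrors the "Illegal pad_width" check
pad_before = [p[0] for p in pad_width]         # -> [1, 2, 3, 4]
pad_after = [p[1] for p in pad_width]          # -> [1, 2, 3, 4]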
tests/python/relay/test_op_level2.py

@@ -330,6 +330,21 @@ def test_pad_infer_type():
     yy = relay.ir_pass.infer_type(y)
     assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")

+def test_pad_run():
+    def _test_run(dtype):
+        dshape = (4, 10, 7, 7)
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4)))
+        func = relay.Function([x], y)
+        data = np.random.uniform(size=dshape).astype(dtype)
+        ref_res = np.pad(data, ((1, 1), (2, 2), (3, 3), (4, 4)), 'constant')
+        for target, ctx in ctx_list():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    _test_run('float32')
+    _test_run('int32')

 def test_lrn():
     n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
@@ -457,6 +472,7 @@ if __name__ == "__main__":
     test_upsampling_infer_type()
     test_flatten_infer_type()
     test_pad_infer_type()
+    test_pad_run()
     test_conv2d_transpose_infer_type()
     test_conv2d_transpose_run()
     test_conv2d_run()
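test_pad_run compares the Relay result against np.pad. As a sanity check on that reference, here is a short NumPy-only sketch of the output shape the test implies: padding a (4, 10, 7, 7) tensor by ((1, 1), (2, 2), (3, 3), (4, 4)) grows each axis by the sum of its pair, giving (6, 14, 13, 15).

import numpy as np

data = np.zeros((4, 10, 7, 7), dtype="float32")
ref = np.pad(data, ((1, 1), (2, 2), (3, 3), (4, 4)), 'constant')
assert ref.shape == (6, 14, 13, 15)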