wenyuanbo / tic · Commits

Commit bdfcec0e
authored Aug 05, 2018 by masahi, committed by Tianqi Chen on Aug 04, 2018

update topi schedules (#1556)

parent: 7b59b8ef

Showing 23 changed files with 147 additions and 33 deletions.
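
The change is the same across all 23 files: each schedule function gains a scheduled_ops list, the graph-traversal helpers skip any operator already in that list, and every operator is appended once it has been handled, so a stage that feeds several consumers is traversed and scheduled only once.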

topi/python/topi/arm_cpu/bitserial_conv2d.py      +4   -1
topi/python/topi/arm_cpu/conv2d.py                +4   -1
topi/python/topi/arm_cpu/depthwise_conv2d.py      +5   -1
topi/python/topi/cuda/conv2d_hwcn.py              +5   -1
topi/python/topi/cuda/conv2d_nchw.py              +5   -1
topi/python/topi/cuda/conv2d_transpose_nchw.py    +5   -1
topi/python/topi/cuda/dense.py                    +5   -1
topi/python/topi/cuda/depthwise_conv2d.py         +10  -2
topi/python/topi/cuda/pooling.py                  +10  -2
topi/python/topi/cuda/reduction.py                +13  -4
topi/python/topi/cuda/vision.py                   +5   -1
topi/python/topi/intel_graphics/conv2d.py         +8   -2
topi/python/topi/mali/conv2d.py                   +4   -1
topi/python/topi/mali/dense.py                    +5   -1
topi/python/topi/mali/depthwise_conv2d.py         +5   -1
topi/python/topi/opengl/conv2d_nchw.py            +5   -1
topi/python/topi/opengl/dense.py                  +5   -1
topi/python/topi/opengl/pooling.py                +10  -2
topi/python/topi/x86/binary_dense.py              +4   -1
topi/python/topi/x86/bitserial_conv2d.py          +3   -1
topi/python/topi/x86/conv2d.py                    +12  -3
topi/python/topi/x86/nn.py                        +4   -1
topi/python/topi/x86/pooling.py                   +11  -2
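
Distilled from the hunks below, this is the traversal pattern the commit converges on. It is a sketch, not code from the repository; the function name and the elided dispatch logic are illustrative, and it assumes the 2018-era TVM API used throughout this diff:

import tvm

def schedule_example(outs):
    """Sketch of the memoized traversal used by the updated schedules."""
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []  # operators that have already been visited

    def traverse(op):
        # inline injective ops, dispatch on op.tag, etc. (elided here)
        for tensor in op.input_tensors:
            # the guard added by this commit: an operator reachable through
            # several consumers is traversed, and scheduled, only once
            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                traverse(tensor.op)
        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s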

topi/python/topi/arm_cpu/bitserial_conv2d.py

@@ -327,6 +327,8 @@ def _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
 def schedule_bitserial_conv2d_nhwc(outs):
     """Raspverry pi schedule for bitserial conv2d"""
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def traverse(op):
         """Traverse operators from computation graph"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -334,7 +336,7 @@ def schedule_bitserial_conv2d_nhwc(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'spatial_bitserial_conv_nhwc' in op.tag:
@@ -360,6 +362,7 @@ def schedule_bitserial_conv2d_nhwc(outs):
             _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
                                           kernel, kernel_q, kernel_vec,
                                           conv_out, output, outs[0])
+        scheduled_ops.append(op)
     traverse(outs[0].op)
     return s
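
For context, a hypothetical example (not part of the commit) of the graph shape this guard short-circuits: a diamond where one producer is reachable along two paths.

import tvm

# B feeds both C and D, so a naive recursive traversal starting from E
# reaches B twice; the scheduled_ops check makes the second visit a no-op.
A = tvm.placeholder((16,), name='A')
B = tvm.compute((16,), lambda i: A[i] * 2, name='B')
C = tvm.compute((16,), lambda i: B[i] + 1, name='C')
D = tvm.compute((16,), lambda i: B[i] - 1, name='D')
E = tvm.compute((16,), lambda i: C[i] + D[i], name='E')
s = tvm.create_schedule(E.op)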

topi/python/topi/arm_cpu/conv2d.py

@@ -39,10 +39,11 @@ def decl_spatial_pack(cfg, data, kernel, strides, padding, layout, out_dtype):
 def schedule_conv2d_nchw_arm_cpu(cfg, outs):
     """TOPI schedule callback"""
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def _callback(op):
         # schedule conv2d
-        if 'spatial_conv_output' in op.tag:
+        if 'spatial_conv_output' in op.tag and op not in scheduled_ops:
             output = op.output(0)
             conv = op.input_tensors[0]
@@ -64,6 +65,8 @@ def schedule_conv2d_nchw_arm_cpu(cfg, outs):
             output = op.output(0)
             _schedule_winograd(cfg, s, output, outs[0])
 
+        scheduled_ops.append(op)
+
     traverse_inline(s, outs[0].op, _callback)
     return s

topi/python/topi/arm_cpu/depthwise_conv2d.py

@@ -79,8 +79,10 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
         return s
 
+    scheduled_ops = []
+
     def _callback(op):
-        if op.tag == 'depthwise_conv2d_nchw':
+        if op.tag == 'depthwise_conv2d_nchw' and op not in scheduled_ops:
             output = op.output(0)
             kernel = op.input_tensors[1]
             data = op.input_tensors[0]
@@ -90,5 +92,7 @@ def schedule_depthwise_conv2d_nchw_(cfg, outs):
                 data = data_pad.op.input_tensors[0]
             _schedule(cfg, s, data, data_pad, kernel, output)
 
+        scheduled_ops.append(op)
+
     traverse_inline(s, outs[0].op, _callback)
     return s

topi/python/topi/cuda/conv2d_hwcn.py

@@ -99,13 +99,15 @@ def schedule_conv2d_hwcn(outs):
         sch[WW].bind(tx, thread_x)
         sch[WW].vectorize(fi)
 
+    scheduled_ops = []
+
     def traverse(operator):
         """Traverse operators from computation graph"""
         if tag.is_broadcast(operator.tag):
             if operator not in sch.outputs:
                 sch[operator].compute_inline()
             for tensor in operator.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         elif operator.tag == 'conv2d_hwcn':
             Apad = operator.input_tensors[0]
@@ -117,5 +119,7 @@ def schedule_conv2d_hwcn(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
 
+        scheduled_ops.append(operator)
+
     traverse(outs[0].op)
     return sch

topi/python/topi/cuda/conv2d_nchw.py

@@ -492,6 +492,8 @@ def schedule_conv2d_small_batch(outs):
         else:
             conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Traverse operators from computation graph"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -499,7 +501,7 @@ def schedule_conv2d_small_batch(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule conv2d
         if 'conv2d_nchw' in OP.tag:
@@ -510,6 +512,8 @@ def schedule_conv2d_small_batch(outs):
             Output = OP.output(0)
             schedule(temp, Filter, Output)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/cuda/conv2d_transpose_nchw.py

@@ -73,6 +73,8 @@ def schedule_conv2d_transpose_small_batch(outs):
         else:
             conv2d_56_64_64(s, Filter, temp_S, Filter_S, Out, Out_L)
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -80,7 +82,7 @@ def schedule_conv2d_transpose_small_batch(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule conv2d_transpose_nchw
         if 'conv2d_transpose_nchw' in OP.tag:
@@ -91,6 +93,8 @@ def schedule_conv2d_transpose_small_batch(outs):
             Output = OP.output(0)
             schedule(temp, Filter, Output)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/cuda/dense.py

@@ -86,6 +86,8 @@ def schedule_dense(outs):
         s[Dense].set_store_predicate(thread_x.var.equal(0))
         s[Out].set_store_predicate(thread_x.var.equal(0))
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -93,7 +95,7 @@ def schedule_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule dense
         elif OP.tag == 'dense':
@@ -102,5 +104,7 @@ def schedule_dense(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/cuda/depthwise_conv2d.py

@@ -101,6 +101,8 @@ def schedule_depthwise_conv2d_nchw(outs):
             s[FS].bind(ty, thread_y)
             s[FS].bind(tx, thread_x)
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -108,7 +110,7 @@ def schedule_depthwise_conv2d_nchw(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule depthwise_conv2d
         if OP.tag == 'depthwise_conv2d_nchw':
@@ -119,6 +121,8 @@ def schedule_depthwise_conv2d_nchw(outs):
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

@@ -180,6 +184,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
         fused = s[FS].fuse(fi, ci)
         s[FS].bind(fused, thread_x)
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -187,7 +193,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule depthwise_conv2d
         if OP.tag == 'depthwise_conv2d_nhwc':
@@ -198,6 +204,8 @@ def schedule_depthwise_conv2d_nhwc(outs):
             DepthwiseConv2d = OP.output(0)
             _schedule(PaddedInput, Filter, DepthwiseConv2d)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/cuda/pooling.py

@@ -45,6 +45,8 @@ def schedule_global_pool(outs):
         else:
             s[Pool].compute_at(s[Out], tx)
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -52,7 +54,7 @@ def schedule_global_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith('global_pool'):
@@ -61,6 +63,8 @@ def schedule_global_pool(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

@@ -101,6 +105,8 @@ def schedule_pool(outs):
         else:
             s[Pool].compute_at(s[Out], tx)
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -108,7 +114,7 @@ def schedule_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('pool'):
@@ -118,5 +124,7 @@ def schedule_pool(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/cuda/reduction.py

@@ -88,6 +88,7 @@ def schedule_reduce(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     sch = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def traverse_before_reduce(operator):
         """Internal travserse function"""
@@ -96,10 +97,13 @@ def schedule_reduce(outs):
         elif tag.is_injective(operator.tag):
             sch[operator].compute_inline()
             for tensor in operator.input_tensors:
-                traverse_before_reduce(tensor.op)
+                if tensor.op not in scheduled_ops:
+                    traverse_before_reduce(tensor.op)
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
 
+        scheduled_ops.append(operator)
+
     def traverse_after_reduce(operator):
         """Internal travserse function"""
         if tag.is_broadcast(operator.tag):
@@ -107,13 +111,18 @@ def schedule_reduce(outs):
         elif operator.tag == 'comm_reduce':
             _schedule_reduce(operator, sch, is_idx_reduce=False)
             for tensor in operator.input_tensors:
-                traverse_before_reduce(tensor.op)
+                if tensor.op not in scheduled_ops:
+                    traverse_before_reduce(tensor.op)
         elif operator.tag == 'comm_reduce_idx':
             _schedule_reduce(operator, sch, is_idx_reduce=True)
-            for tensor in operator.input_tensors[0].op.input_tensors:
-                traverse_before_reduce(tensor.op)
+            input_tensors = operator.input_tensors[0].op.input_tensors
+            for tensor in input_tensors:
+                if tensor.op not in scheduled_ops:
+                    traverse_before_reduce(tensor.op)
         else:
             raise RuntimeError("Unsupported operator: %s" % operator.tag)
 
+        scheduled_ops.append(operator)
+
     traverse_after_reduce(outs[0].op)
     return sch
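
Note that cuda/reduction.py is the one file where the guard sits at the call sites rather than inside the loop condition: traverse_before_reduce handles its operator unconditionally, so each caller checks tensor.op not in scheduled_ops before recursing, and the comm_reduce_idx branch first hoists operator.input_tensors[0].op.input_tensors into a local before iterating.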

topi/python/topi/cuda/vision.py

@@ -11,6 +11,8 @@ def _default_schedule(outs):
     target = tvm.target.current_target()
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def traverse(op):
         """inline all one-to-one-mapping operators except the last stage (output)"""
         if "nms" in op.tag:
@@ -32,9 +34,11 @@ def _default_schedule(outs):
                 s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
                 s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
         for tensor in op.input_tensors:
-            if tensor.op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                 traverse(tensor.op)
 
+        scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s

topi/python/topi/intel_graphics/conv2d.py

@@ -113,6 +113,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
     """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
    def traverse(op):
        """inline all one-to-one-mapping operators except the last stage (output)"""
@@ -120,12 +121,14 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
                or "1_16" in op.tag:
            _schedule_cl_spatialpack_NCHWc(s, op)
 
+        scheduled_ops.append(op)
+
    traverse(outs[0].op)
    return s

@@ -360,6 +363,7 @@ def schedule_conv2d_nchw(outs):
     """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
    def traverse(op):
        """inline all one-to-one-mapping operators except the last stage (output)"""
@@ -367,12 +371,14 @@ def schedule_conv2d_nchw(outs):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        if "4_5" in op.tag or "4_4" in op.tag or "2_7" in op.tag or "2_14" in op.tag \
                or "1_16" in op.tag:
            _schedule_cl_spatialpack(s, op)
 
+        scheduled_ops.append(op)
+
    traverse(outs[0].op)
    return s

topi/python/topi/mali/conv2d.py

@@ -144,6 +144,7 @@ def schedule_conv2d_nchw(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def traverse(op):
         """inline all one-to-one-mapping operators except the last stage (output)"""
@@ -151,7 +152,7 @@ def schedule_conv2d_nchw(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'im2col_conv_output' in op.tag:
@@ -163,6 +164,8 @@ def schedule_conv2d_nchw(outs):
         if 'winograd_conv_output' in op.tag:
             _schedule_winograd(s, op)
 
+        scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s

topi/python/topi/mali/dense.py

@@ -81,6 +81,8 @@ def schedule_dense(outs):
         # bias = s[outs[0]].op.input_tensors[1]
         # print(tvm.lower(s, [data, weight, bias, outs[0]], simple_mode=True))
 
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -88,7 +90,7 @@ def schedule_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule dense
         elif OP.tag == 'dense':
@@ -97,5 +99,7 @@ def schedule_dense(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/mali/depthwise_conv2d.py

@@ -86,6 +86,8 @@ def schedule_depthwise_conv2d_nchw(outs):
             s[conv].vectorize(xi)
             s[conv].compute_at(s[output], ji)
 
+    scheduled_ops = []
+
     def traverse(op):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -93,7 +95,7 @@ def schedule_depthwise_conv2d_nchw(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule depthwise_conv2d
@@ -105,5 +107,7 @@ def schedule_depthwise_conv2d_nchw(outs):
             conv = op.output(0)
             _schedule(pad_data, kernel, conv)
 
+        scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s

topi/python/topi/opengl/conv2d_nchw.py

@@ -21,6 +21,8 @@ def schedule_conv2d_nchw(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def _schedule(conv2d, data):
         if conv2d.op in s.outputs:
             Out = conv2d
@@ -37,7 +39,7 @@ def schedule_conv2d_nchw(outs):
             if OP not in s.outputs:
                 s[OP].opengl()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule conv2d_nchw
         elif OP.tag.startswith('conv2d_nchw'):
@@ -50,5 +52,7 @@ def schedule_conv2d_nchw(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/opengl/dense.py

@@ -22,6 +22,8 @@ def schedule_dense(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def _schedule(Dense):
         if Dense.op in s.outputs:
             Out = Dense
@@ -37,7 +39,7 @@ def schedule_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule dense
         elif OP.tag == 'dense':
@@ -46,5 +48,7 @@ def schedule_dense(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/opengl/pooling.py

@@ -21,6 +21,8 @@ def schedule_global_pool(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def _schedule(Pool):
         if Pool.op in s.outputs:
             Out = Pool
@@ -36,7 +38,7 @@ def schedule_global_pool(outs):
             if OP not in s.outputs:
                 s[OP].opengl()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule global_pool
         elif OP.tag.startswith('global_pool'):
@@ -45,6 +47,8 @@ def schedule_global_pool(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

@@ -66,6 +70,8 @@ def schedule_pool(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def _schedule(PaddedInput, Pool):
         if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
             s[PaddedInput].opengl()
@@ -82,7 +88,7 @@ def schedule_pool(outs):
         if tag.is_broadcast(OP.tag):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
@@ -93,5 +99,7 @@ def schedule_pool(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/x86/binary_dense.py

@@ -23,6 +23,7 @@ def schedule_binary_dense(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def _schedule(A, B, C):
         s[C].split(s[C].op.reduce_axis[0], factor=8)
@@ -41,7 +42,7 @@ def schedule_binary_dense(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule binary_dense
         elif OP.tag == 'binary_dense':
@@ -52,5 +53,7 @@ def schedule_binary_dense(outs):
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

topi/python/topi/x86/bitserial_conv2d.py

@@ -71,6 +71,7 @@ def _declaration_bitserial_conv2d(data, kernel, stride, padding, activation_bits
 def schedule_bitserial_conv2d(outs):
     """CPU schedule for bitserial convolutions NCHW and NHWC"""
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def traverse(op):
         """Traverse operators from computation graph"""
@@ -79,7 +80,7 @@ def schedule_bitserial_conv2d(outs):
         if tag.is_broadcast(op.tag) or 'elemwise' in op.tag:
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
@@ -111,6 +112,7 @@ def schedule_bitserial_conv2d(outs):
             _schedule_spatial_conv2d_nhwc(s, data, data_q, data_pad, data_vec,
                                           kernel, kernel_q, kernel_vec,
                                           conv_out, output, outs[0])
+        scheduled_ops.append(op)
     traverse(outs[0].op)
     return s

topi/python/topi/x86/conv2d.py

@@ -188,6 +188,7 @@ def schedule_conv2d(outs):
               }
     s = tvm.create_schedule([x.op for x in outs])
     target = tvm.target.current_target(allow_none=False)
+    scheduled_ops = []
 
     def traverse(op):
         """Traverse operators from computation graph"""
@@ -196,7 +197,7 @@ def schedule_conv2d(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_nchw' in op.tag:
@@ -223,6 +224,8 @@ def schedule_conv2d(outs):
             _AVX_SCH_TO_SCH_FUNC[type(sch)](s, data, data_pad, data_vec, kernel,
                                             kernel_vec, conv_out, output, outs[0])
 
+        scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s

@@ -232,6 +235,7 @@ def schedule_conv2d_nhwc(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
     output_op = outs[0].op
+    scheduled_ops = []
 
     def traverse(op):
         """Traverse operators from computation graph"""
@@ -246,7 +250,7 @@ def schedule_conv2d_nhwc(outs):
                 s[op].parallel(fused)
                 s[op].vectorize(c)
         for tensor in op.input_tensors:
-            if tensor.op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                 traverse(tensor.op)
 
         if 'conv2d_nhwc' in op.tag:
@@ -275,6 +279,8 @@ def schedule_conv2d_nhwc(outs):
             fused = s[C].fuse(n, h, w)
             s[C].parallel(fused)
 
+        scheduled_ops.append(op)
+
     traverse(output_op)
     return s

@@ -288,6 +294,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
                            AVXConv1x1Fwd: conv2d_avx_1x1._schedule_conv_NCHWc}
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def traverse(op):
         """Traverse operators from computation graph"""
@@ -296,7 +303,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'conv2d_NCHWc' in op.tag:
@@ -322,5 +329,7 @@ def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding,
             _AVX_SCH_TO_SCH_FUNC[type(sch)](s, wkl, sch, data_vec,
                                             kernel, conv_out, outs[0])
 
+        scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s

topi/python/topi/x86/nn.py

@@ -53,6 +53,7 @@ def schedule_dense(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def traverse(op):
         """Traverse operators from computation graph"""
@@ -61,7 +62,7 @@ def schedule_dense(outs):
             if op not in s.outputs:
                 s[op].compute_inline()
             for tensor in op.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         if 'dense' in op.tag:
@@ -89,5 +90,7 @@ def schedule_dense(outs):
             # Parallelization
             s[C].parallel(yo)
 
+        scheduled_ops.append(op)
+
     traverse(outs[0].op)
     return s

topi/python/topi/x86/pooling.py

@@ -32,6 +32,7 @@ def schedule_pool(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
 
     def _schedule(PaddedInput, Pool):
         if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
@@ -45,7 +46,7 @@ def schedule_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('pool'):
@@ -54,6 +55,9 @@ def schedule_pool(outs):
             _schedule(PaddedInput, Pool)
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s

@@ -75,6 +79,8 @@ def schedule_global_pool(outs):
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+
     def traverse(OP):
         """Internal travserse function"""
         # inline all one-to-one-mapping operators except the last stage (output)
@@ -82,7 +88,7 @@ def schedule_global_pool(outs):
             if OP not in s.outputs:
                 s[OP].compute_inline()
             for tensor in OP.input_tensors:
-                if tensor.op.input_tensors:
+                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                     traverse(tensor.op)
         # schedule pool
         elif OP.tag.startswith('global_pool'):
@@ -90,5 +96,8 @@ def schedule_global_pool(outs):
             _parallel_sch(s[Pool])
         else:
             raise RuntimeError("Unsupported operator: %s" % OP.tag)
 
+        scheduled_ops.append(OP)
+
     traverse(outs[0].op)
     return s