Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
46b4a914
Commit
46b4a914
authored
Jun 02, 2017
by
Tianqi Chen
Committed by
GitHub
Jun 02, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Refactor build config, allow implicit unroll pragma (#167)
parent
86e56824
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
171 additions
and
63 deletions
+171
-63
docs/api/python/build.rst
+2
-0
examples/operator/gemm/cuda_gemm_square.py
+5
-5
examples/operator/rnn/lstm.py
+7
-3
examples/operator/rnn/matexp.py
+7
-6
include/tvm/ir_pass.h
+7
-3
include/tvm/schedule.h
+1
-1
python/tvm/__init__.py
+1
-1
python/tvm/build.py
+82
-19
src/api/api_pass.cc
+7
-1
src/codegen/codegen_cuda.cc
+1
-3
src/codegen/codegen_cuda.h
+0
-3
src/pass/unroll_loop.cc
+45
-12
tests/python/integration/test_gemm.py
+1
-3
tests/python/unittest/test_pass_unroll.py
+5
-3
No files found.
docs/api/python/build.rst
View file @
46b4a914
...
@@ -3,3 +3,5 @@ tvm.build
...
@@ -3,3 +3,5 @@ tvm.build
.. autofunction:: tvm.lower
.. autofunction:: tvm.lower
.. autofunction:: tvm.build
.. autofunction:: tvm.build
.. autofunction:: tvm.build_config
examples/operator/gemm/cuda_gemm_square.py
View file @
46b4a914
...
@@ -95,15 +95,12 @@ def test_gemm():
...
@@ -95,15 +95,12 @@ def test_gemm():
s
[
BB
]
.
bind
(
ty
,
thread_y
)
s
[
BB
]
.
bind
(
ty
,
thread_y
)
s
[
BB
]
.
bind
(
tx
,
thread_x
)
s
[
BB
]
.
bind
(
tx
,
thread_x
)
s
[
BB
]
.
vectorize
(
xi
)
s
[
BB
]
.
vectorize
(
xi
)
max_auto_unroll_step
=
8
# correctness
# correctness
def
check_device
(
device
):
def
check_device
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
if
not
tvm
.
module
.
enabled
(
device
):
print
(
"Skip because
%
s is not enabled"
%
device
)
print
(
"Skip because
%
s is not enabled"
%
device
)
return
return
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
device
,
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
device
)
max_auto_unroll_step
=
max_auto_unroll_step
)
ctx
=
tvm
.
gpu
(
0
)
if
device
==
"cuda"
else
tvm
.
cl
(
0
)
ctx
=
tvm
.
gpu
(
0
)
if
device
==
"cuda"
else
tvm
.
cl
(
0
)
# launch the kernel.
# launch the kernel.
n
,
m
,
l
=
nn
,
nn
,
nn
n
,
m
,
l
=
nn
,
nn
,
nn
...
@@ -117,7 +114,10 @@ def test_gemm():
...
@@ -117,7 +114,10 @@ def test_gemm():
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
np
.
dot
(
b_np
.
T
,
a_np
),
rtol
=
1e-5
)
c
.
asnumpy
(),
np
.
dot
(
b_np
.
T
,
a_np
),
rtol
=
1e-5
)
check_device
(
"cuda"
)
with
tvm
.
build_config
(
auto_unroll_max_step
=
32
,
auto_unroll_min_depth
=
0
,
unroll_explicit
=
False
):
check_device
(
"cuda"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_gemm
()
test_gemm
()
examples/operator/rnn/lstm.py
View file @
46b4a914
...
@@ -147,8 +147,7 @@ def lstm():
...
@@ -147,8 +147,7 @@ def lstm():
def
check_device
(
target
):
def
check_device
(
target
):
num_step
=
n_num_step
num_step
=
n_num_step
flstm
=
tvm
.
build
(
s
,
[
Xi2h
,
Wh2h
,
scan_h
,
scan_c
],
flstm
=
tvm
.
build
(
s
,
[
Xi2h
,
Wh2h
,
scan_h
,
scan_c
],
target
,
target
)
detect_global_barrier
=
DETECT_GLOBAL_BARRIER
)
ctx
=
tvm
.
gpu
(
0
)
if
target
==
"cuda"
else
tvm
.
cl
(
0
)
ctx
=
tvm
.
gpu
(
0
)
if
target
==
"cuda"
else
tvm
.
cl
(
0
)
# launch the kernel.
# launch the kernel.
scan_h_np
=
np
.
zeros
(
scan_h_np
=
np
.
zeros
(
...
@@ -172,7 +171,12 @@ def lstm():
...
@@ -172,7 +171,12 @@ def lstm():
tgap
=
time
.
time
()
-
tstart
tgap
=
time
.
time
()
-
tstart
print
(
"Time cost=
%
g"
%
tgap
)
print
(
"Time cost=
%
g"
%
tgap
)
check_device
(
"cuda"
)
# set unroll_explicit for more readable code.
with
tvm
.
build_config
(
detect_global_barrier
=
DETECT_GLOBAL_BARRIER
,
auto_unroll_max_step
=
128
,
unroll_explicit
=
False
):
check_device
(
"cuda"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
lstm
()
lstm
()
examples/operator/rnn/matexp.py
View file @
46b4a914
...
@@ -15,7 +15,7 @@ from tvm.contrib import nvcc_compiler
...
@@ -15,7 +15,7 @@ from tvm.contrib import nvcc_compiler
import
numpy
as
np
import
numpy
as
np
# Quick knobs
# Quick knobs
TASK
=
"
rnn_
matexp"
TASK
=
"matexp"
USE_MANUAL_CODE
=
False
USE_MANUAL_CODE
=
False
PERSIST_KERNEL
=
True
PERSIST_KERNEL
=
True
DETECT_GLOBAL_BARRIER
=
PERSIST_KERNEL
DETECT_GLOBAL_BARRIER
=
PERSIST_KERNEL
...
@@ -44,7 +44,6 @@ def rnn_matexp():
...
@@ -44,7 +44,6 @@ def rnn_matexp():
n_num_step
=
128
n_num_step
=
128
n_num_hidden
=
1152
n_num_hidden
=
1152
n_batch_size
=
4
n_batch_size
=
4
max_auto_unroll_step
=
0
detect_global_barrier
=
DETECT_GLOBAL_BARRIER
detect_global_barrier
=
DETECT_GLOBAL_BARRIER
num_step
=
tvm
.
var
(
"num_step"
)
num_step
=
tvm
.
var
(
"num_step"
)
...
@@ -111,10 +110,12 @@ def rnn_matexp():
...
@@ -111,10 +110,12 @@ def rnn_matexp():
s
[
SS
]
.
bind
(
tx
,
thread_x
)
s
[
SS
]
.
bind
(
tx
,
thread_x
)
def
check_device
(
target
):
def
check_device
(
target
):
f
=
tvm
.
build
(
s
,
[
s_scan
,
Whh
],
with
tvm
.
build_config
(
target
,
detect_global_barrier
=
detect_global_barrier
,
max_auto_unroll_step
=
max_auto_unroll_step
,
auto_unroll_min_depth
=
2
,
detect_global_barrier
=
detect_global_barrier
)
auto_unroll_max_step
=
128
,
unroll_explicit
=
False
):
f
=
tvm
.
build
(
s
,
[
s_scan
,
Whh
],
target
)
ctx
=
tvm
.
gpu
(
0
)
if
target
==
"cuda"
else
tvm
.
cl
(
0
)
ctx
=
tvm
.
gpu
(
0
)
if
target
==
"cuda"
else
tvm
.
cl
(
0
)
# launch the kernel.
# launch the kernel.
res_np
=
np
.
zeros
(
res_np
=
np
.
zeros
(
...
...
include/tvm/ir_pass.h
View file @
46b4a914
...
@@ -144,12 +144,16 @@ Stmt SplitPipeline(Stmt stmt, bool split_load);
...
@@ -144,12 +144,16 @@ Stmt SplitPipeline(Stmt stmt, bool split_load);
Stmt
NarrowChannelAccess
(
Stmt
stmt
);
Stmt
NarrowChannelAccess
(
Stmt
stmt
);
/*!
/*!
* \brief unroll the constant loops
* \brief unroll the constant loop marked by unroll.
* This pass also automatically attach pragma unroll tag to loops which meets the standard.
*
* \param stmt The statment to be unrolled.
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \param auto_max_step The maximum step before stop attach automatic unroll
* \param auto_min_depth The minimum depth before we can start automatic unroll
* \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen.
* \return Transformed stmt.
* \return Transformed stmt.
*/
*/
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
);
Stmt
UnrollLoop
(
Stmt
stmt
,
int
auto_max_step
,
int
auto_min_depth
,
bool
explicit_unroll
);
/*!
/*!
* \brief vectorize the constant loops
* \brief vectorize the constant loops
...
...
include/tvm/schedule.h
View file @
46b4a914
...
@@ -161,7 +161,7 @@ class Stage : public NodeRef {
...
@@ -161,7 +161,7 @@ class Stage : public NodeRef {
Stage
&
vectorize
(
IterVar
var
);
// NOLINT(*)
Stage
&
vectorize
(
IterVar
var
);
// NOLINT(*)
/*!
/*!
* \brief Unroll iteration.
* \brief Unroll iteration.
* \param var The axis to be
vectoriz
ed.
* \param var The axis to be
unroll
ed.
* \return reference to self.
* \return reference to self.
*/
*/
Stage
&
unroll
(
IterVar
var
);
// NOLINT(*)
Stage
&
unroll
(
IterVar
var
);
// NOLINT(*)
...
...
python/tvm/__init__.py
View file @
46b4a914
...
@@ -26,4 +26,4 @@ from .intrin import *
...
@@ -26,4 +26,4 @@ from .intrin import *
from
.node
import
register_node
from
.node
import
register_node
from
.ndarray
import
register_extension
from
.ndarray
import
register_extension
from
.schedule
import
create_schedule
from
.schedule
import
create_schedule
from
.build
import
build
,
lower
from
.build
import
build
,
lower
,
build_config
python/tvm/build.py
View file @
46b4a914
...
@@ -13,6 +13,77 @@ from . import collections
...
@@ -13,6 +13,77 @@ from . import collections
from
.
import
module
from
.
import
module
from
.
import
codegen
from
.
import
codegen
class
BuildConfig
(
object
):
"""Configuration scope to set a build config option.
Parameters
----------
kwargs
Keyword arguments of configurations to set.
"""
current
=
None
defaults
=
{
'auto_unroll_max_step'
:
0
,
'auto_unroll_min_depth'
:
1
,
'unroll_explicit'
:
True
,
'detect_global_barrier'
:
True
}
def
__init__
(
self
,
**
kwargs
):
self
.
_old_scope
=
None
for
k
,
_
in
kwargs
.
items
():
if
k
not
in
BuildConfig
.
defaults
:
raise
ValueError
(
"invalid argument
%
s, candidates are
%
s"
%
(
k
,
BuildConfig
.
defaults
.
keys
()))
self
.
_attr
=
kwargs
def
__getattr__
(
self
,
name
):
if
name
not
in
self
.
_attr
:
return
BuildConfig
.
defaults
[
name
]
return
self
.
_attr
[
name
]
def
__enter__
(
self
):
# pylint: disable=protected-access
self
.
_old_scope
=
BuildConfig
.
current
attr
=
BuildConfig
.
current
.
_attr
.
copy
()
attr
.
update
(
self
.
_attr
)
self
.
_attr
=
attr
BuildConfig
.
current
=
self
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
assert
self
.
_old_scope
BuildConfig
.
current
=
self
.
_old_scope
BuildConfig
.
current
=
BuildConfig
()
def
build_config
(
**
kwargs
):
"""Configure the build behavior by setting config variables.
Parameters
----------
auto_unroll_max_step: int, default=0
Threshold of loop extent to be automatically unrolled.
auto_unroll_min_depth: int, default=1
The minimum loop nest level before the loop can be automatically unrolled.
unroll_explicit: bool, default=True
Whether explicitly unroll the loop, if set false, the unroll hint will
be passed to the CodeGen phase, which may generate pragma unroll hint.
Set this to be true if CodeGen support unroll pragma and
when we want to be more readable.
detect_global_barrier: bool, default=True
Whether detect global barrier.
Returns
-------
config: BuildConfig
The build configuration
"""
return
BuildConfig
(
**
kwargs
)
def
get_binds
(
args
,
binds
=
None
):
def
get_binds
(
args
,
binds
=
None
):
"""Internal function to get binds and arg_list given arguments.
"""Internal function to get binds and arg_list given arguments.
...
@@ -49,12 +120,12 @@ def get_binds(args, binds=None):
...
@@ -49,12 +120,12 @@ def get_binds(args, binds=None):
raise
ValueError
(
"args must be Tensor, Buffer or Var"
)
raise
ValueError
(
"args must be Tensor, Buffer or Var"
)
return
binds
,
arg_list
return
binds
,
arg_list
def
lower
(
sch
,
def
lower
(
sch
,
args
,
args
,
name
=
"default_function"
,
name
=
"default_function"
,
binds
=
None
,
binds
=
None
,
simple_mode
=
False
,
simple_mode
=
False
):
max_auto_unroll_step
=
0
):
"""Lowering step before build into target.
"""Lowering step before build into target.
Parameters
Parameters
...
@@ -76,9 +147,6 @@ def lower(sch,
...
@@ -76,9 +147,6 @@ def lower(sch,
Whether only output simple and compact statement, this will skip
Whether only output simple and compact statement, this will skip
LoopPartition, api wrapper generation and Unrolling.
LoopPartition, api wrapper generation and Unrolling.
max_auto_unroll_step: int, optional
Maximum step to perform automatic unrolling
Returns
Returns
-------
-------
f : LoweredFunc or Stmt
f : LoweredFunc or Stmt
...
@@ -97,8 +165,12 @@ def lower(sch,
...
@@ -97,8 +165,12 @@ def lower(sch,
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
if
not
simple_mode
:
cfg
=
BuildConfig
.
current
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
max_auto_unroll_step
)
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
cfg
.
auto_unroll_max_step
,
cfg
.
auto_unroll_min_depth
,
cfg
.
unroll_explicit
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
if
simple_mode
:
if
simple_mode
:
return
stmt
return
stmt
...
@@ -110,9 +182,7 @@ def build(sch,
...
@@ -110,9 +182,7 @@ def build(sch,
target
=
"llvm"
,
target
=
"llvm"
,
target_host
=
None
,
target_host
=
None
,
name
=
"default_function"
,
name
=
"default_function"
,
binds
=
None
,
binds
=
None
):
max_auto_unroll_step
=
0
,
detect_global_barrier
=
True
):
"""Build a function with arguments as signiture.
"""Build a function with arguments as signiture.
Parameters
Parameters
...
@@ -142,12 +212,6 @@ def build(sch,
...
@@ -142,12 +212,6 @@ def build(sch,
Dictionary that maps the binding of symbolic buffer to Tensor.
Dictionary that maps the binding of symbolic buffer to Tensor.
By default, a new buffer is created for each tensor in the argument.
By default, a new buffer is created for each tensor in the argument.
max_auto_unroll_step: int, optional
Maximum step to perform automatic unrolling
detect_global_barrier: boolean, optional
Whether detect and inser global barrier
Returns
Returns
-------
-------
f : Function, or pair of functions
f : Function, or pair of functions
...
@@ -158,8 +222,7 @@ def build(sch,
...
@@ -158,8 +222,7 @@ def build(sch,
raise
ValueError
(
"args must be given for build from schedule"
)
raise
ValueError
(
"args must be given for build from schedule"
)
fapi
=
lower
(
sch
,
args
,
fapi
=
lower
(
sch
,
args
,
name
=
name
,
name
=
name
,
binds
=
binds
,
binds
=
binds
)
max_auto_unroll_step
=
max_auto_unroll_step
)
elif
isinstance
(
sch
,
collections
.
LoweredFunc
):
elif
isinstance
(
sch
,
collections
.
LoweredFunc
):
if
args
:
if
args
:
raise
ValueError
(
"args must be done when build from LoweredFunc"
)
raise
ValueError
(
"args must be done when build from LoweredFunc"
)
...
@@ -167,7 +230,7 @@ def build(sch,
...
@@ -167,7 +230,7 @@ def build(sch,
else
:
else
:
raise
ValueError
(
"sch have to be Schedule or LoweredFunc"
)
raise
ValueError
(
"sch have to be Schedule or LoweredFunc"
)
# device related lowering
# device related lowering
if
detect_global_barrier
:
if
BuildConfig
.
current
.
detect_global_barrier
:
fapi
=
ir_pass
.
StorageSync
(
fapi
,
"global"
)
fapi
=
ir_pass
.
StorageSync
(
fapi
,
"global"
)
fapi
=
ir_pass
.
StorageSync
(
fapi
,
"shared"
)
fapi
=
ir_pass
.
StorageSync
(
fapi
,
"shared"
)
warp_size
=
32
if
target
==
"cuda"
else
1
warp_size
=
32
if
target
==
"cuda"
else
1
...
...
src/api/api_pass.cc
View file @
46b4a914
...
@@ -51,6 +51,12 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
...
@@ -51,6 +51,12 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
*ret = PassName(args[0], args[1]); \
*ret = PassName(args[0], args[1]); \
}) \
}) \
#define REGISTER_PASS3(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2]); \
}) \
#define REGISTER_PASS4(PassName) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
...
@@ -64,7 +70,7 @@ REGISTER_PASS4(Inline);
...
@@ -64,7 +70,7 @@ REGISTER_PASS4(Inline);
REGISTER_PASS2
(
StorageFlatten
);
REGISTER_PASS2
(
StorageFlatten
);
REGISTER_PASS1
(
VectorizeLoop
);
REGISTER_PASS1
(
VectorizeLoop
);
REGISTER_PASS2
(
ExprUseVar
);
REGISTER_PASS2
(
ExprUseVar
);
REGISTER_PASS
2
(
UnrollLoop
);
REGISTER_PASS
4
(
UnrollLoop
);
REGISTER_PASS2
(
StorageSync
);
REGISTER_PASS2
(
StorageSync
);
REGISTER_PASS4
(
MakeAPI
);
REGISTER_PASS4
(
MakeAPI
);
REGISTER_PASS1
(
SplitHostDevice
);
REGISTER_PASS1
(
SplitHostDevice
);
...
...
src/codegen/codegen_cuda.cc
View file @
46b4a914
...
@@ -27,10 +27,8 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
...
@@ -27,10 +27,8 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
}
}
void
CodeGenCUDA
::
VisitStmt_
(
const
ir
::
For
*
op
)
{
void
CodeGenCUDA
::
VisitStmt_
(
const
ir
::
For
*
op
)
{
int
ext
;
CHECK
(
is_zero
(
op
->
min
));
CHECK
(
is_zero
(
op
->
min
));
if
(
arith
::
GetConstInt
(
op
->
extent
,
&
ext
)
&&
if
(
op
->
for_type
==
ir
::
ForType
::
Unrolled
)
{
ext
<=
max_auto_unroll_
)
{
PrintIndent
();
PrintIndent
();
stream
<<
"#pragma unroll
\n
"
;
stream
<<
"#pragma unroll
\n
"
;
}
}
...
...
src/codegen/codegen_cuda.h
View file @
46b4a914
...
@@ -36,9 +36,6 @@ class CodeGenCUDA final : public CodeGenC {
...
@@ -36,9 +36,6 @@ class CodeGenCUDA final : public CodeGenC {
void
VisitStmt_
(
const
Evaluate
*
op
)
final
;
void
VisitStmt_
(
const
Evaluate
*
op
)
final
;
private
:
private
:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int
max_auto_unroll_
{
32
};
// Whether global barrier is needed.
// Whether global barrier is needed.
bool
need_global_barrier_
{
false
};
bool
need_global_barrier_
{
false
};
// Global barrier state
// Global barrier state
...
...
src/pass/unroll_loop.cc
View file @
46b4a914
...
@@ -16,8 +16,12 @@ namespace ir {
...
@@ -16,8 +16,12 @@ namespace ir {
class
LoopUnroller
:
public
IRMutator
{
class
LoopUnroller
:
public
IRMutator
{
public
:
public
:
explicit
LoopUnroller
(
int
max_auto_step
)
explicit
LoopUnroller
(
int
auto_max_step
,
:
max_auto_step_
(
max_auto_step
)
{
int
auto_min_depth
,
bool
explicit_unroll
)
:
auto_max_step_
(
auto_max_step
),
auto_min_depth_
(
auto_min_depth
),
explicit_unroll_
(
explicit_unroll
)
{
}
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
{
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
{
...
@@ -33,15 +37,16 @@ class LoopUnroller : public IRMutator {
...
@@ -33,15 +37,16 @@ class LoopUnroller : public IRMutator {
if
(
v2
!=
nullptr
)
{
if
(
v2
!=
nullptr
)
{
value
=
static_cast
<
int
>
(
v2
->
value
);
value
=
static_cast
<
int
>
(
v2
->
value
);
}
}
bool
allow_unroll
=
(
op
->
for_type
==
ForType
::
Serial
&&
bool
auto_unroll
=
(
op
->
for_type
==
ForType
::
Serial
&&
value
>=
0
&&
value
<=
max_auto_step_
);
value
>=
0
&&
value
<=
auto_max_step_
&&
loop_depth_
>=
auto_min_depth_
);
if
(
op
->
for_type
==
ForType
::
Unrolled
)
{
if
(
op
->
for_type
==
ForType
::
Unrolled
)
{
CHECK_GE
(
value
,
0
)
CHECK_GE
(
value
,
0
)
<<
"Cannot unroll non-constant loop"
;
<<
"Cannot unroll non-constant loop"
;
a
llow
_unroll
=
true
;
a
uto
_unroll
=
true
;
}
}
if
(
a
llow_unroll
)
{
if
(
a
uto_unroll
&&
explicit_unroll_
)
{
using
arith
::
ComputeExpr
;
using
arith
::
ComputeExpr
;
if
(
value
==
0
)
return
Evaluate
::
make
(
0
);
if
(
value
==
0
)
return
Evaluate
::
make
(
0
);
Stmt
body
=
op
->
body
;
Stmt
body
=
op
->
body
;
...
@@ -59,20 +64,48 @@ class LoopUnroller : public IRMutator {
...
@@ -59,20 +64,48 @@ class LoopUnroller : public IRMutator {
unrolled
=
step
;
unrolled
=
step
;
}
}
}
}
return
this
->
Mutate
(
unrolled
);
++
loop_depth_
;
Stmt
ret
=
this
->
Mutate
(
unrolled
);
--
loop_depth_
;
return
ret
;
}
else
{
}
else
{
return
IRMutator
::
Mutate_
(
op
,
stmt
);
++
loop_depth_
;
Stmt
ret
=
IRMutator
::
Mutate_
(
op
,
stmt
);
if
(
auto_unroll
)
{
op
=
ret
.
as
<
For
>
();
if
(
op
->
for_type
!=
ForType
::
Unrolled
)
{
ret
=
For
::
make
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
ForType
::
Unrolled
,
op
->
device_api
,
op
->
body
);
}
}
--
loop_depth_
;
return
ret
;
}
}
}
}
private
:
private
:
int
max_auto_step_
;
// maximum number of step to perform auto unroll.
int
auto_max_step_
;
int
auto_min_depth_
;
bool
explicit_unroll_
;
int
loop_depth_
{
0
};
};
};
Stmt
UnrollLoop
(
Stmt
stmt
,
int
max_auto_step
)
{
Stmt
UnrollLoop
(
Stmt
stmt
,
Stmt
ret
=
LoopUnroller
(
max_auto_step
).
Mutate
(
stmt
);
int
auto_max_step
,
return
ConvertSSA
(
ret
);
int
auto_min_depth
,
bool
explicit_unroll
)
{
Stmt
ret
=
LoopUnroller
(
auto_max_step
,
auto_min_depth
,
explicit_unroll
).
Mutate
(
stmt
);
if
(
!
ret
.
same_as
(
stmt
))
{
return
ConvertSSA
(
ret
);
}
else
{
return
ret
;
}
}
}
}
// namespace ir
}
// namespace ir
...
...
tests/python/integration/test_gemm.py
View file @
46b4a914
...
@@ -58,7 +58,6 @@ def test_gemm():
...
@@ -58,7 +58,6 @@ def test_gemm():
s
[
BB
]
.
bind
(
ty
,
thread_y
)
s
[
BB
]
.
bind
(
ty
,
thread_y
)
s
[
BB
]
.
bind
(
tx
,
thread_x
)
s
[
BB
]
.
bind
(
tx
,
thread_x
)
max_auto_unroll_step
=
0
# lowering test
# lowering test
s
=
s
.
normalize
()
s
=
s
.
normalize
()
...
@@ -68,8 +67,7 @@ def test_gemm():
...
@@ -68,8 +67,7 @@ def test_gemm():
print
(
"skip because
%
s is not enabled.."
%
device
)
print
(
"skip because
%
s is not enabled.."
%
device
)
return
return
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
device
,
f
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
device
)
max_auto_unroll_step
=
max_auto_unroll_step
)
ctx
=
tvm
.
context
(
device
,
0
)
ctx
=
tvm
.
context
(
device
,
0
)
# launch the kernel.
# launch the kernel.
n
=
nn
n
=
nn
...
...
tests/python/unittest/test_pass_unroll.py
View file @
46b4a914
...
@@ -14,9 +14,11 @@ def test_unroll_loop():
...
@@ -14,9 +14,11 @@ def test_unroll_loop():
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
tvm
.
make
.
Load
(
dtype
,
Ab
.
data
,
i
)
+
1
,
j
+
1
)))
j
+
1
)))
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
assert
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
stmt
=
tvm
.
ir_pass
.
UnrollLoop
(
stmt
,
4
)
ret
=
tvm
.
ir_pass
.
UnrollLoop
(
stmt
,
2
,
0
,
True
)
assert
not
isinstance
(
stmt
,
tvm
.
stmt
.
For
)
assert
not
isinstance
(
ret
,
tvm
.
stmt
.
For
)
print
(
stmt
)
ret
=
tvm
.
ir_pass
.
UnrollLoop
(
stmt
,
4
,
0
,
False
)
assert
isinstance
(
ret
,
tvm
.
stmt
.
For
)
assert
ret
.
for_type
==
tvm
.
stmt
.
For
.
Unrolled
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_unroll_loop
()
test_unroll_loop
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment