Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d3277874
Unverified
Commit
d3277874
authored
Apr 21, 2020
by
Tianqi Chen
Committed by
GitHub
Apr 21, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
parent
72f2aea2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
88 additions
and
78 deletions
+88
-78
include/tvm/target/target.h
+3
-2
python/tvm/autotvm/measure/measure_methods.py
+4
-4
python/tvm/driver/build_module.py
+5
-24
python/tvm/tir/function.py
+16
-0
src/target/target.cc
+2
-2
tests/python/relay/test_pass_fold_constant.py
+5
-3
tests/python/unittest/test_target_codegen_cuda.py
+7
-3
tests/python/unittest/test_target_codegen_llvm.py
+7
-4
tests/python/unittest/test_tir_pass_verify_gpu_code.py
+4
-4
tutorials/dev/low_level_custom_pass.py
+6
-5
vta/python/vta/build_module.py
+29
-27
vta/python/vta/transform.py
+0
-0
No files found.
include/tvm/target/target.h
View file @
d3277874
...
...
@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string>
#include <vector>
...
...
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
bool
partition_const_loop
=
false
;
/*! \brief
Whether to dump the IR of each pass (only when building from python)
*/
std
::
vector
<
std
::
pair
<
int
,
runtime
::
PackedFunc
>
>
add_lower_pass
;
/*! \brief
List of passes to be injected into the low-level pipeline.
*/
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool
dump_pass_ir
=
false
;
...
...
python/tvm/autotvm/measure/measure_methods.py
View file @
d3277874
...
...
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
"""
def
verify_pass
(
stmt
):
valid
=
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
def
verify_pass
(
f
,
*
_
):
valid
=
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
if
not
valid
:
raise
InstantiationError
(
"Skipped because of invalid gpu kernel"
)
return
stmt
return
verify_pass
return
f
return
tvm
.
tir
.
transform
.
prim_func_pass
(
verify_pass
,
opt_level
=
0
)
python/tvm/driver/build_module.py
View file @
d3277874
...
...
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return
tvm
.
IRModule
({
name
:
func
})
def
_wrap_as_prim_func_pass
(
flist
,
name
):
"""Wrap flist as a function pass.
This is a temporary adapter before we fully
migrate to the new pass manager.
"""
def
_transform
(
func
,
*
_
):
stmt
=
func
.
body
for
f
in
flist
:
stmt
=
f
(
stmt
)
# create a new function with updated body.
return
tvm
.
tir
.
PrimFunc
(
func
.
params
,
stmt
,
func
.
ret_type
,
func
.
buffer_map
,
func
.
attrs
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
name
)
def
lower
(
sch
,
args
,
name
=
"main"
,
...
...
@@ -190,15 +171,15 @@ def lower(sch,
else
:
mod
=
sch
pass_list
=
lower_phase0
# Phase 1
pass_list
=
[
_wrap_as_prim_func_pass
(
lower_phase0
,
"Custom-Phase0"
),
pass_list
+=
[
tvm
.
tir
.
transform
.
InjectPrefetch
(),
tvm
.
tir
.
transform
.
StorageFlatten
(
64
,
cfg
.
instrument_bound_checkers
),
tvm
.
tir
.
transform
.
NarrowDataType
(
32
),
tvm
.
tir
.
transform
.
Simplify
(),
_wrap_as_prim_func_pass
(
lower_phase1
,
"Custom-Phase1"
),
]
pass_list
+=
lower_phase1
# Phase 2
if
not
simple_mode
:
...
...
@@ -214,8 +195,8 @@ def lower(sch,
cfg
.
auto_unroll_max_depth
,
cfg
.
auto_unroll_max_extent
,
cfg
.
unroll_explicit
),
_wrap_as_prim_func_pass
(
lower_phase2
,
"Custom-Phase2"
),
]
pass_list
+=
lower_phase2
# Phase 3
pass_list
+=
[
...
...
@@ -225,7 +206,7 @@ def lower(sch,
if
not
cfg
.
disable_select_rewriting
:
pass_list
+=
[
tvm
.
tir
.
transform
.
RewriteUnsafeSelect
()]
pass_list
+=
[
_wrap_as_prim_func_pass
(
lower_phase3
,
"Custom-Phase3"
)]
pass_list
+=
lower_phase3
# Instrument BoundCheckers
if
cfg
.
instrument_bound_checkers
:
...
...
python/tvm/tir/function.py
View file @
d3277874
...
...
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
PrimFunc
,
param_list
,
body
,
ret_type
,
buffer_map
,
attrs
)
def
with_body
(
self
,
new_body
):
"""Create a new PrimFunc with the same signature but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return
PrimFunc
(
self
.
params
,
new_body
,
self
.
ret_type
,
self
.
buffer_map
,
self
.
attrs
)
src/target/target.cc
View file @
d3277874
...
...
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL
(
"target.BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
cfg
=
args
[
0
];
std
::
vector
<
std
::
pair
<
int
,
PackedFunc
>
>
add_lower_pass
;
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
add_lower_pass
.
push_back
(
std
::
make_pair
(
args
[
i
].
operator
int
(),
args
[
i
+
1
].
operator
t
vm
::
runtime
::
PackedFunc
()));
args
[
i
+
1
].
operator
t
ransform
::
Pass
()));
}
cfg
->
add_lower_pass
=
add_lower_pass
;
});
...
...
tests/python/relay/test_pass_fold_constant.py
View file @
d3277874
...
...
@@ -51,11 +51,13 @@ def test_fold_const():
z
=
relay
.
add
(
y
,
relay
.
const
(
c_data
))
return
relay
.
Function
([
x
],
z
)
def
fail
(
x
):
raise
RuntimeError
()
def
FailPass
():
def
_transform
(
m
,
*
args
):
raise
RuntimeError
()
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
)
# the fold constant should work on any context.
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
fail
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
FailPass
()
)]):
with
tvm
.
target
.
create
(
"cuda"
):
zz
=
run_opt_pass
(
before
(),
transform
.
FoldConstant
())
zexpected
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
tests/python/unittest/test_target_codegen_cuda.py
View file @
d3277874
...
...
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch
[
c
]
.
bind
(
xo
,
thrx
)
sch
[
c
]
.
vectorize
(
xi
)
def
my_vectorize
(
stmt
):
def
MyVectorize
(
):
def
vectorizer
(
op
):
if
op
.
for_type
==
tvm
.
tir
.
For
.
Vectorized
:
four
=
tvm
.
tir
.
const
(
4
,
'int32'
)
...
...
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b
=
tvm
.
tir
.
Shuffle
(
bs
,
ids
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"MyVectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
MyVectorize
())]):
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
],
target
=
'cuda'
)
a_
=
np
.
array
(
list
(
range
(
64
)),
dtype
=
'int32'
)
b_
=
np
.
array
((
list
(
range
(
4
))[::
-
1
])
*
16
,
dtype
=
'int32'
)
...
...
tests/python/unittest/test_target_codegen_llvm.py
View file @
d3277874
...
...
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c
=
te
.
compute
((
8
,
),
lambda
x
:
a
[
x
]
+
b
[
7
-
x
])
sch
=
te
.
create_schedule
(
c
.
op
)
def
my_vectorize
(
stmt
):
def
my_vectorize
():
def
vectorizer
(
op
):
store
=
op
.
body
idx
=
tvm
.
tir
.
Ramp
(
tvm
.
tir
.
const
(
0
,
'int32'
),
tvm
.
tir
.
const
(
1
,
'int32'
),
8
)
...
...
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value
=
new_a
+
new_b
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"my_vectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
()
)]):
ir
=
tvm
.
lower
(
sch
,
[
a
,
b
,
c
],
simple_mode
=
True
)
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
])
a_
=
tvm
.
nd
.
array
(
np
.
arange
(
1
,
9
,
dtype
=
'int32'
))
...
...
tests/python/unittest/test_tir_pass_verify_gpu_code.py
View file @
d3277874
...
...
@@ -19,10 +19,10 @@ import tvm
from
tvm
import
te
def
get_verify_pass
(
valid
,
**
kwargs
):
def
verify_pass
(
stmt
):
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
return
stmt
return
verify_pass
def
_fverify
(
f
,
*
_
):
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
return
f
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_fverify
,
opt_level
=
0
)
def
test_shared_memory
():
def
check_shared_memory
(
dtype
):
...
...
tutorials/dev/low_level_custom_pass.py
View file @
d3277874
...
...
@@ -117,19 +117,20 @@ def vectorize8(op):
return
body
return
None
def
vectorize
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
vectorize
(
f
,
mod
,
ctx
):
global
loops
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
stmt
,
find_width8
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
f
.
body
,
find_width8
)
if
not
loops
:
return
s
tmt
return
s
f
# The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorize8
,
[
'For'
])
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorize8
,
[
'For'
]))
return
stmt
#####################################################################
# Glue to Lowering
...
...
vta/python/vta/build_module.py
View file @
d3277874
...
...
@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument
, invalid-name
"""VTA specific buildin for runtime."""
import
tvm
from
.
import
ir_pass
from
.
import
transform
from
.environment
import
get_env
def
lift_coproc_scope
(
x
):
"""Lift the coprocessing scope."""
x
=
ir_pass
.
lift_alloc_to_scope_begin
(
x
)
x
=
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_scope"
,
False
)
return
x
def
early_rewrite
(
stmt
):
def
EarlyRewrite
():
"""Try to do storage rewrite in early pass."""
try
:
return
tvm
.
tir
.
ir_pass
.
StorageRewrite
(
stmt
)
except
tvm
.
error
.
TVMError
:
return
stmt
def
_transform
(
mod
,
ctx
):
try
:
return
tvm
.
tir
.
transform
.
StorageRewrite
()(
mod
)
except
tvm
.
error
.
TVMError
:
return
mod
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
,
name
=
"tir.vta.EarlyRewrite"
)
def
build_config
(
debug_flag
=
0
,
**
kwargs
):
...
...
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
"""
env
=
get_env
()
def
add_debug
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
add_debug
(
f
,
*
_
):
debug
=
tvm
.
tir
.
call_extern
(
"int32"
,
"VTASetDebugMode"
,
env
.
dev
.
command_handle
,
debug_flag
)
return
tvm
.
tir
.
stmt_seq
(
debug
,
stmt
)
pass_list
=
[(
0
,
ir_pass
.
inject_conv2d_transpose_skip
),
(
1
,
ir_pass
.
inject_dma_intrin
),
(
1
,
ir_pass
.
inject_skip_copy
),
(
1
,
ir_pass
.
annotate_alu_coproc_scope
),
(
1
,
lambda
x
:
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_uop_scope"
,
True
)),
(
1
,
lift_coproc_scope
),
(
1
,
ir_pass
.
inject_coproc_sync
),
(
1
,
early_rewrite
)]
return
f
.
with_body
(
tvm
.
tir
.
stmt_seq
(
debug
,
f
.
body
))
pass_list
=
[(
0
,
transform
.
InjectConv2DTransposeSkip
()),
(
1
,
transform
.
InjectDMAIntrin
()),
(
1
,
transform
.
InjectSkipCopy
()),
(
1
,
transform
.
AnnotateALUCoProcScope
()),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_uop_scope"
)),
(
1
,
transform
.
LiftAllocToScopeBegin
()),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_scope"
)),
(
1
,
transform
.
InjectCoProcSync
()),
(
1
,
EarlyRewrite
())]
if
debug_flag
:
pass_list
.
append
((
1
,
add_debug
))
pass_list
.
append
((
2
,
ir_pass
.
inject_alu_intrin
))
pass_list
.
append
((
3
,
tvm
.
tir
.
ir_pass
.
LowerStorageAccessInfo
))
pass_list
.
append
((
3
,
ir_pass
.
fold_uop_loop
))
pass_list
.
append
((
3
,
ir_pass
.
cpu_access_rewrite
))
pass_list
.
append
((
2
,
transform
.
InjectALUIntrin
()
))
pass_list
.
append
((
3
,
tvm
.
tir
.
transform
.
LowerDeviceStorageAccessInfo
()
))
pass_list
.
append
((
3
,
transform
.
FoldUopLoop
()
))
pass_list
.
append
((
3
,
transform
.
CPUAccessRewrite
()
))
return
tvm
.
target
.
build_config
(
add_lower_pass
=
pass_list
,
**
kwargs
)
...
...
vta/python/vta/
ir_pass
.py
→
vta/python/vta/
transform
.py
View file @
d3277874
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment