Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d3277874
Unverified
Commit
d3277874
authored
Apr 21, 2020
by
Tianqi Chen
Committed by
GitHub
Apr 21, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
parent
72f2aea2
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
170 additions
and
193 deletions
+170
-193
include/tvm/target/target.h
+3
-2
python/tvm/autotvm/measure/measure_methods.py
+4
-4
python/tvm/driver/build_module.py
+5
-24
python/tvm/tir/function.py
+16
-0
src/target/target.cc
+2
-2
tests/python/relay/test_pass_fold_constant.py
+4
-2
tests/python/unittest/test_target_codegen_cuda.py
+7
-3
tests/python/unittest/test_target_codegen_llvm.py
+7
-4
tests/python/unittest/test_tir_pass_verify_gpu_code.py
+4
-4
tutorials/dev/low_level_custom_pass.py
+6
-5
vta/python/vta/build_module.py
+27
-25
vta/python/vta/transform.py
+85
-118
No files found.
include/tvm/target/target.h
View file @
d3277874
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
...
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
/*! \brief Whether to partition const loop */
bool
partition_const_loop
=
false
;
bool
partition_const_loop
=
false
;
/*! \brief
Whether to dump the IR of each pass (only when building from python)
*/
/*! \brief
List of passes to be injected into the low-level pipeline.
*/
std
::
vector
<
std
::
pair
<
int
,
runtime
::
PackedFunc
>
>
add_lower_pass
;
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool
dump_pass_ir
=
false
;
bool
dump_pass_ir
=
false
;
...
...
python/tvm/autotvm/measure/measure_methods.py
View file @
d3277874
...
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
...
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
This pass will check memory usage and number of threads per block.
"""
"""
def
verify_pass
(
stmt
):
def
verify_pass
(
f
,
*
_
):
valid
=
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
valid
=
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
if
not
valid
:
if
not
valid
:
raise
InstantiationError
(
"Skipped because of invalid gpu kernel"
)
raise
InstantiationError
(
"Skipped because of invalid gpu kernel"
)
return
stmt
return
f
return
verify_pass
return
tvm
.
tir
.
transform
.
prim_func_pass
(
verify_pass
,
opt_level
=
0
)
python/tvm/driver/build_module.py
View file @
d3277874
...
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
...
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return
tvm
.
IRModule
({
name
:
func
})
return
tvm
.
IRModule
({
name
:
func
})
def
_wrap_as_prim_func_pass
(
flist
,
name
):
"""Wrap flist as a function pass.
This is an temporary adapter before we fully
migrate to the new pass manager.
"""
def
_transform
(
func
,
*
_
):
stmt
=
func
.
body
for
f
in
flist
:
stmt
=
f
(
stmt
)
# create a new function with updated body.
return
tvm
.
tir
.
PrimFunc
(
func
.
params
,
stmt
,
func
.
ret_type
,
func
.
buffer_map
,
func
.
attrs
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
name
)
def
lower
(
sch
,
def
lower
(
sch
,
args
,
args
,
name
=
"main"
,
name
=
"main"
,
...
@@ -190,15 +171,15 @@ def lower(sch,
...
@@ -190,15 +171,15 @@ def lower(sch,
else
:
else
:
mod
=
sch
mod
=
sch
pass_list
=
lower_phase0
# Phase 1
# Phase 1
pass_list
=
[
pass_list
+=
[
_wrap_as_prim_func_pass
(
lower_phase0
,
"Custom-Phase0"
),
tvm
.
tir
.
transform
.
InjectPrefetch
(),
tvm
.
tir
.
transform
.
InjectPrefetch
(),
tvm
.
tir
.
transform
.
StorageFlatten
(
64
,
cfg
.
instrument_bound_checkers
),
tvm
.
tir
.
transform
.
StorageFlatten
(
64
,
cfg
.
instrument_bound_checkers
),
tvm
.
tir
.
transform
.
NarrowDataType
(
32
),
tvm
.
tir
.
transform
.
NarrowDataType
(
32
),
tvm
.
tir
.
transform
.
Simplify
(),
tvm
.
tir
.
transform
.
Simplify
(),
_wrap_as_prim_func_pass
(
lower_phase1
,
"Custom-Phase1"
),
]
]
pass_list
+=
lower_phase1
# Phase 2
# Phase 2
if
not
simple_mode
:
if
not
simple_mode
:
...
@@ -214,8 +195,8 @@ def lower(sch,
...
@@ -214,8 +195,8 @@ def lower(sch,
cfg
.
auto_unroll_max_depth
,
cfg
.
auto_unroll_max_depth
,
cfg
.
auto_unroll_max_extent
,
cfg
.
auto_unroll_max_extent
,
cfg
.
unroll_explicit
),
cfg
.
unroll_explicit
),
_wrap_as_prim_func_pass
(
lower_phase2
,
"Custom-Phase2"
),
]
]
pass_list
+=
lower_phase2
# Phase 3
# Phase 3
pass_list
+=
[
pass_list
+=
[
...
@@ -225,7 +206,7 @@ def lower(sch,
...
@@ -225,7 +206,7 @@ def lower(sch,
if
not
cfg
.
disable_select_rewriting
:
if
not
cfg
.
disable_select_rewriting
:
pass_list
+=
[
tvm
.
tir
.
transform
.
RewriteUnsafeSelect
()]
pass_list
+=
[
tvm
.
tir
.
transform
.
RewriteUnsafeSelect
()]
pass_list
+=
[
_wrap_as_prim_func_pass
(
lower_phase3
,
"Custom-Phase3"
)]
pass_list
+=
lower_phase3
# Instrument BoundCheckers
# Instrument BoundCheckers
if
cfg
.
instrument_bound_checkers
:
if
cfg
.
instrument_bound_checkers
:
...
...
python/tvm/tir/function.py
View file @
d3277874
...
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
...
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_ffi_api
.
PrimFunc
,
param_list
,
body
,
ret_type
,
buffer_map
,
attrs
)
_ffi_api
.
PrimFunc
,
param_list
,
body
,
ret_type
,
buffer_map
,
attrs
)
def
with_body
(
self
,
new_body
):
"""Create a new PrimFunc with the same set signatures but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return
PrimFunc
(
self
.
params
,
new_body
,
self
.
ret_type
,
self
.
buffer_map
,
self
.
attrs
)
src/target/target.cc
View file @
d3277874
...
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
...
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL
(
"target.BuildConfigSetAddLowerPass"
)
TVM_REGISTER_GLOBAL
(
"target.BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
cfg
=
args
[
0
];
BuildConfig
cfg
=
args
[
0
];
std
::
vector
<
std
::
pair
<
int
,
PackedFunc
>
>
add_lower_pass
;
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
add_lower_pass
.
push_back
(
std
::
make_pair
(
add_lower_pass
.
push_back
(
std
::
make_pair
(
args
[
i
].
operator
int
(),
args
[
i
].
operator
int
(),
args
[
i
+
1
].
operator
t
vm
::
runtime
::
PackedFunc
()));
args
[
i
+
1
].
operator
t
ransform
::
Pass
()));
}
}
cfg
->
add_lower_pass
=
add_lower_pass
;
cfg
->
add_lower_pass
=
add_lower_pass
;
});
});
...
...
tests/python/relay/test_pass_fold_constant.py
View file @
d3277874
...
@@ -51,11 +51,13 @@ def test_fold_const():
...
@@ -51,11 +51,13 @@ def test_fold_const():
z
=
relay
.
add
(
y
,
relay
.
const
(
c_data
))
z
=
relay
.
add
(
y
,
relay
.
const
(
c_data
))
return
relay
.
Function
([
x
],
z
)
return
relay
.
Function
([
x
],
z
)
def
fail
(
x
):
def
FailPass
():
def
_transform
(
m
,
*
args
):
raise
RuntimeError
()
raise
RuntimeError
()
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
)
# the fold constant should work on any context.
# the fold constant should work on any context.
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
fail
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
FailPass
()
)]):
with
tvm
.
target
.
create
(
"cuda"
):
with
tvm
.
target
.
create
(
"cuda"
):
zz
=
run_opt_pass
(
before
(),
transform
.
FoldConstant
())
zz
=
run_opt_pass
(
before
(),
transform
.
FoldConstant
())
zexpected
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
zexpected
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
tests/python/unittest/test_target_codegen_cuda.py
View file @
d3277874
...
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
...
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch
[
c
]
.
bind
(
xo
,
thrx
)
sch
[
c
]
.
bind
(
xo
,
thrx
)
sch
[
c
]
.
vectorize
(
xi
)
sch
[
c
]
.
vectorize
(
xi
)
def
my_vectorize
(
stmt
):
def
MyVectorize
(
):
def
vectorizer
(
op
):
def
vectorizer
(
op
):
if
op
.
for_type
==
tvm
.
tir
.
For
.
Vectorized
:
if
op
.
for_type
==
tvm
.
tir
.
For
.
Vectorized
:
four
=
tvm
.
tir
.
const
(
4
,
'int32'
)
four
=
tvm
.
tir
.
const
(
4
,
'int32'
)
...
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
...
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b
=
tvm
.
tir
.
Shuffle
(
bs
,
ids
)
new_b
=
tvm
.
tir
.
Shuffle
(
bs
,
ids
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
None
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"MyVectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
MyVectorize
())]):
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
],
target
=
'cuda'
)
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
],
target
=
'cuda'
)
a_
=
np
.
array
(
list
(
range
(
64
)),
dtype
=
'int32'
)
a_
=
np
.
array
(
list
(
range
(
64
)),
dtype
=
'int32'
)
b_
=
np
.
array
((
list
(
range
(
4
))[::
-
1
])
*
16
,
dtype
=
'int32'
)
b_
=
np
.
array
((
list
(
range
(
4
))[::
-
1
])
*
16
,
dtype
=
'int32'
)
...
...
tests/python/unittest/test_target_codegen_llvm.py
View file @
d3277874
...
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
...
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c
=
te
.
compute
((
8
,
),
lambda
x
:
a
[
x
]
+
b
[
7
-
x
])
c
=
te
.
compute
((
8
,
),
lambda
x
:
a
[
x
]
+
b
[
7
-
x
])
sch
=
te
.
create_schedule
(
c
.
op
)
sch
=
te
.
create_schedule
(
c
.
op
)
def
my_vectorize
(
stmt
):
def
my_vectorize
():
def
vectorizer
(
op
):
def
vectorizer
(
op
):
store
=
op
.
body
store
=
op
.
body
idx
=
tvm
.
tir
.
Ramp
(
tvm
.
tir
.
const
(
0
,
'int32'
),
tvm
.
tir
.
const
(
1
,
'int32'
),
8
)
idx
=
tvm
.
tir
.
Ramp
(
tvm
.
tir
.
const
(
0
,
'int32'
),
tvm
.
tir
.
const
(
1
,
'int32'
),
8
)
...
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
...
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value
=
new_a
+
new_b
value
=
new_a
+
new_b
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"my_vectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
()
)]):
ir
=
tvm
.
lower
(
sch
,
[
a
,
b
,
c
],
simple_mode
=
True
)
ir
=
tvm
.
lower
(
sch
,
[
a
,
b
,
c
],
simple_mode
=
True
)
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
])
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
])
a_
=
tvm
.
nd
.
array
(
np
.
arange
(
1
,
9
,
dtype
=
'int32'
))
a_
=
tvm
.
nd
.
array
(
np
.
arange
(
1
,
9
,
dtype
=
'int32'
))
...
...
tests/python/unittest/test_tir_pass_verify_gpu_code.py
View file @
d3277874
...
@@ -19,10 +19,10 @@ import tvm
...
@@ -19,10 +19,10 @@ import tvm
from
tvm
import
te
from
tvm
import
te
def
get_verify_pass
(
valid
,
**
kwargs
):
def
get_verify_pass
(
valid
,
**
kwargs
):
def
verify_pass
(
stmt
):
def
_fverify
(
f
,
*
_
):
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
return
stmt
return
f
return
verify_pass
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_fverify
,
opt_level
=
0
)
def
test_shared_memory
():
def
test_shared_memory
():
def
check_shared_memory
(
dtype
):
def
check_shared_memory
(
dtype
):
...
...
tutorials/dev/low_level_custom_pass.py
View file @
d3277874
...
@@ -117,19 +117,20 @@ def vectorize8(op):
...
@@ -117,19 +117,20 @@ def vectorize8(op):
return
body
return
body
return
None
return
None
def
vectorize
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
vectorize
(
f
,
mod
,
ctx
):
global
loops
global
loops
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
stmt
,
find_width8
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
f
.
body
,
find_width8
)
if
not
loops
:
if
not
loops
:
return
s
tmt
return
s
f
# The last list arugment indicates what kinds of nodes will be transformed.
# The last list arugment indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
# Thus, in this case only `For` nodes will call `vectorize8`
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorize8
,
[
'For'
])
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorize8
,
[
'For'
]))
return
stmt
#####################################################################
#####################################################################
# Glue to Lowering
# Glue to Lowering
...
...
vta/python/vta/build_module.py
View file @
d3277874
...
@@ -14,25 +14,22 @@
...
@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# specific language governing permissions and limitations
# under the License.
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument
, invalid-name
"""VTA specific buildin for runtime."""
"""VTA specific buildin for runtime."""
import
tvm
import
tvm
from
.
import
ir_pass
from
.
import
transform
from
.environment
import
get_env
from
.environment
import
get_env
def
lift_coproc_scope
(
x
):
def
EarlyRewrite
():
"""Lift coprocessings cope to the """
x
=
ir_pass
.
lift_alloc_to_scope_begin
(
x
)
x
=
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_scope"
,
False
)
return
x
def
early_rewrite
(
stmt
):
"""Try to do storage rewrite in early pass."""
"""Try to do storage rewrite in early pass."""
def
_transform
(
mod
,
ctx
):
try
:
try
:
return
tvm
.
tir
.
ir_pass
.
StorageRewrite
(
stmt
)
return
tvm
.
tir
.
transform
.
StorageRewrite
()(
mod
)
except
tvm
.
error
.
TVMError
:
except
tvm
.
error
.
TVMError
:
return
stmt
return
mod
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
,
name
=
"tir.vta.EarlyRewrite"
)
def
build_config
(
debug_flag
=
0
,
**
kwargs
):
def
build_config
(
debug_flag
=
0
,
**
kwargs
):
...
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
...
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
vta_module = tvm.build(s, ...)
"""
"""
env
=
get_env
()
env
=
get_env
()
def
add_debug
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
add_debug
(
f
,
*
_
):
debug
=
tvm
.
tir
.
call_extern
(
debug
=
tvm
.
tir
.
call_extern
(
"int32"
,
"VTASetDebugMode"
,
"int32"
,
"VTASetDebugMode"
,
env
.
dev
.
command_handle
,
env
.
dev
.
command_handle
,
debug_flag
)
debug_flag
)
return
tvm
.
tir
.
stmt_seq
(
debug
,
stmt
)
return
f
.
with_body
(
tvm
.
tir
.
stmt_seq
(
debug
,
f
.
body
))
pass_list
=
[(
0
,
ir_pass
.
inject_conv2d_transpose_skip
),
(
1
,
ir_pass
.
inject_dma_intrin
),
(
1
,
ir_pass
.
inject_skip_copy
),
pass_list
=
[(
0
,
transform
.
InjectConv2DTransposeSkip
()),
(
1
,
ir_pass
.
annotate_alu_coproc_scope
),
(
1
,
transform
.
InjectDMAIntrin
()),
(
1
,
lambda
x
:
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_uop_scope"
,
True
)),
(
1
,
transform
.
InjectSkipCopy
()),
(
1
,
lift_coproc_scope
),
(
1
,
transform
.
AnnotateALUCoProcScope
()),
(
1
,
ir_pass
.
inject_coproc_sync
),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_uop_scope"
)),
(
1
,
early_rewrite
)]
(
1
,
transform
.
LiftAllocToScopeBegin
()),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_scope"
)),
(
1
,
transform
.
InjectCoProcSync
()),
(
1
,
EarlyRewrite
())]
if
debug_flag
:
if
debug_flag
:
pass_list
.
append
((
1
,
add_debug
))
pass_list
.
append
((
1
,
add_debug
))
pass_list
.
append
((
2
,
ir_pass
.
inject_alu_intrin
))
pass_list
.
append
((
2
,
transform
.
InjectALUIntrin
()
))
pass_list
.
append
((
3
,
tvm
.
tir
.
ir_pass
.
LowerStorageAccessInfo
))
pass_list
.
append
((
3
,
tvm
.
tir
.
transform
.
LowerDeviceStorageAccessInfo
()
))
pass_list
.
append
((
3
,
ir_pass
.
fold_uop_loop
))
pass_list
.
append
((
3
,
transform
.
FoldUopLoop
()
))
pass_list
.
append
((
3
,
ir_pass
.
cpu_access_rewrite
))
pass_list
.
append
((
3
,
transform
.
CPUAccessRewrite
()
))
return
tvm
.
target
.
build_config
(
add_lower_pass
=
pass_list
,
**
kwargs
)
return
tvm
.
target
.
build_config
(
add_lower_pass
=
pass_list
,
**
kwargs
)
...
...
vta/python/vta/
ir_pass
.py
→
vta/python/vta/
transform
.py
View file @
d3277874
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# specific language governing permissions and limitations
# under the License.
# under the License.
"""Additional
IR Pass
for VTA"""
"""Additional
Transformation Passes.
for VTA"""
# pylint: disable=len-as-condition, no-else-return
# pylint: disable=len-as-condition, no-else-return
, unused-argument, invalid-name
import
tvm
import
tvm
from
tvm
import
te
from
tvm
import
te
from
topi
import
util
from
topi
import
util
...
@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
...
@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(
stmt
.
attr_key
==
"pragma_scope"
and
stmt
.
value
.
value
==
key
))
(
stmt
.
attr_key
==
"pragma_scope"
and
stmt
.
value
.
value
==
key
))
def
fold_uop_loop
(
stmt_in
):
def
FoldUopLoop
(
):
"""Detect and fold uop loop.
"""Detect and fold uop loop.
VTA support uop programming model
VTA support uop programming model
...
@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
...
@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detect the loop structure
This pass detect the loop structure
and extract that into uop loop AST.
and extract that into uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
Output statement.
The pass
"""
"""
env
=
get_env
()
def
_fold_outermost_loop
(
body
):
def
_fold_outermost_loop
(
body
):
stmt
=
body
stmt
=
body
if
not
isinstance
(
stmt
,
tvm
.
tir
.
For
):
if
not
isinstance
(
stmt
,
tvm
.
tir
.
For
):
...
@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
...
@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise
ValueError
(
"Failed to fold the GEMM instructions.."
)
raise
ValueError
(
"Failed to fold the GEMM instructions.."
)
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
env
=
get_env
()
if
(
stmt
.
attr_key
==
"coproc_uop_scope"
and
if
(
stmt
.
attr_key
==
"coproc_uop_scope"
and
isinstance
(
stmt
.
value
,
tvm
.
tir
.
StringImm
)
and
isinstance
(
stmt
.
value
,
tvm
.
tir
.
StringImm
)
and
stmt
.
value
.
value
==
env
.
dev
.
vta_push_uop
.
value
):
stmt
.
value
.
value
==
env
.
dev
.
vta_push_uop
.
value
):
...
@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
...
@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return
tvm
.
tir
.
AttrStmt
(
return
tvm
.
tir
.
AttrStmt
(
stmt
.
node
,
stmt
.
attr_key
,
stmt
.
value
,
body
)
stmt
.
node
,
stmt
.
attr_key
,
stmt
.
value
,
body
)
return
None
return
None
out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
def
_ftransform
(
f
,
mod
,
ctx
):
return
out
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.FoldUopLoop"
)
def
cpu_access_rewrite
(
stmt_in
):
def
CPUAccessRewrite
(
):
"""Detect CPU access to VTA buffer and get address correctly.
"""Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that do not
VTA's buffer is an opaque handle that do not
...
@@ -148,18 +146,14 @@ def cpu_access_rewrite(stmt_in):
...
@@ -148,18 +146,14 @@ def cpu_access_rewrite(stmt_in):
This pass detect CPU access and rewrite to use pointer
This pass detect CPU access and rewrite to use pointer
returned VTABufferCPUPtr for CPU access.
returned VTABufferCPUPtr for CPU access.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
def
_ftransform
(
f
,
mod
,
ctx
):
rw_info
=
{}
rw_info
=
{}
env
=
get_env
()
def
_post_order
(
op
):
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
buffer_var
=
op
.
buffer_var
buffer_var
=
op
.
buffer_var
...
@@ -191,30 +185,31 @@ def cpu_access_rewrite(stmt_in):
...
@@ -191,30 +185,31 @@ def cpu_access_rewrite(stmt_in):
new_var
=
rw_info
[
buffer_var
]
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
return
tvm
.
tir
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
raise
RuntimeError
(
"not reached"
)
raise
RuntimeError
(
"not reached"
)
stmt_in
=
f
.
body
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
for
buffer_var
,
new_var
in
rw_info
.
items
():
for
buffer_var
,
new_var
in
rw_info
.
items
():
stmt
=
tvm
.
tir
.
LetStmt
(
stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
env
.
dev
.
command_handle
,
buffer_var
),
stmt
)
buffer_var
),
stmt
)
return
stmt
return
f
.
with_body
(
stmt
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.CPUAccessRewrite"
)
def
lift_alloc_to_scope_begin
(
stmt_in
):
def
LiftAllocToScopeBegin
(
):
"""Lift allocate to beginning of the current scope.
"""Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_ftransform
(
f
,
mod
,
ctx
):
lift_stmt
=
[[]]
lift_stmt
=
[[]]
def
_merge_block
(
slist
,
body
):
def
_merge_block
(
slist
,
body
):
for
op
in
slist
:
for
op
in
slist
:
...
@@ -257,46 +252,46 @@ def lift_alloc_to_scope_begin(stmt_in):
...
@@ -257,46 +252,46 @@ def lift_alloc_to_scope_begin(stmt_in):
if
isinstance
(
op
,
tvm
.
tir
.
For
):
if
isinstance
(
op
,
tvm
.
tir
.
For
):
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
raise
RuntimeError
(
"not reached"
)
raise
RuntimeError
(
"not reached"
)
stmt_in
=
f
.
body
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
assert
len
(
lift_stmt
)
==
1
assert
len
(
lift_stmt
)
==
1
return
_merge_block
(
lift_stmt
[
0
],
stmt
)
return
f
.
with_body
(
_merge_block
(
lift_stmt
[
0
],
stmt
)
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.LiftAllocToScopeBegin"
)
def
inject_skip_copy
(
stmt_in
):
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters
def
InjectSkipCopy
():
----------
"""Pass to inject skip copy stmt, used for debug purpose.
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"skip_dma_copy"
):
if
_match_pragma
(
stmt
,
"skip_dma_copy"
):
return
tvm
.
tir
.
Evaluate
(
0
)
return
tvm
.
tir
.
Evaluate
(
0
)
return
None
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
def
_ftransform
(
f
,
mod
,
ctx
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
def
inject_coproc_sync
(
stmt_in
):
return
tvm
.
tir
.
transform
.
prim_func_pass
(
"""Pass to inject skip copy stmt, used in debug.
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectSkipCopy"
)
Parameters
----------
def
InjectCoProcSync
():
stmt_in : Stmt
"""Pass inject coproc sync
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_ftransform
(
f
,
*
_
):
success
=
[
False
]
success
=
[
False
]
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"coproc_sync"
):
if
_match_pragma
(
stmt
,
"coproc_sync"
):
...
@@ -311,26 +306,22 @@ def inject_coproc_sync(stmt_in):
...
@@ -311,26 +306,22 @@ def inject_coproc_sync(stmt_in):
op
.
loop_var
,
op
.
min
,
2
,
op
.
for_type
,
op
.
loop_var
,
op
.
min
,
2
,
op
.
for_type
,
op
.
device_api
,
op
.
body
)
op
.
device_api
,
op
.
body
)
return
None
return
None
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
f
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
stmt
=
tvm
.
tir
.
ir_pass
.
CoProcSync
(
stmt
)
return
tvm
.
transform
.
Sequential
(
return
stmt
[
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
0
,
"tir.vta.InjectCoProcSync"
),
tvm
.
tir
.
transform
.
CoProcSync
()],
opt_level
=
0
,
name
=
"tir.vta.InjectCoProcSync"
)
def
inject_dma_intrin
(
stmt_in
):
def
InjectDMAIntrin
(
):
"""Pass to inject DMA copy intrinsics.
"""Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
idxd
=
tvm
.
tir
.
indexdiv
idxd
=
tvm
.
tir
.
indexdiv
idxm
=
tvm
.
tir
.
indexmod
idxm
=
tvm
.
tir
.
indexmod
...
@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
...
@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def
_inject_copy
(
src
,
dst
,
pad_before
,
pad_after
,
pad_value
):
def
_inject_copy
(
src
,
dst
,
pad_before
,
pad_after
,
pad_value
):
# FIXME: pad_value is ignored...
# FIXME: pad_value is ignored...
env
=
get_env
()
_
=
pad_value
_
=
pad_value
if
dst
.
scope
==
"global"
:
if
dst
.
scope
==
"global"
:
# Store
# Store
...
@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
...
@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
else
:
else
:
raise
RuntimeError
(
"Do not support copy
%
s->
%
s"
%
(
src
.
scope
,
dst
.
scope
))
raise
RuntimeError
(
"Do not support copy
%
s->
%
s"
%
(
src
.
scope
,
dst
.
scope
))
return
tvm
.
tir
.
ir_pass
.
InjectCopyIntrin
(
stmt_in
,
"dma_copy"
,
_inject_copy
)
return
tvm
.
tir
.
transform
.
InjectCopyIntrin
(
"dma_copy"
,
_inject_copy
)
def
_get_gemm_intrin_buffer
():
def
_get_gemm_intrin_buffer
():
...
@@ -619,19 +611,15 @@ def _get_gemm_intrin_buffer():
...
@@ -619,19 +611,15 @@ def _get_gemm_intrin_buffer():
return
wgt_layout
,
inp_layout
,
out_layout
return
wgt_layout
,
inp_layout
,
out_layout
def
inject_conv2d_transpose_skip
(
stmt_in
):
def
InjectConv2DTransposeSkip
(
):
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_ftransform
(
func
,
mod
,
ctx
):
env
=
get_env
()
env
=
get_env
()
dwgt
,
dinp
,
dout
=
_get_gemm_intrin_buffer
()
dwgt
,
dinp
,
dout
=
_get_gemm_intrin_buffer
()
...
@@ -687,7 +675,8 @@ def inject_conv2d_transpose_skip(stmt_in):
...
@@ -687,7 +675,8 @@ def inject_conv2d_transpose_skip(stmt_in):
irb
=
tvm
.
tir
.
ir_builder
.
create
()
irb
=
tvm
.
tir
.
ir_builder
.
create
()
with
irb
.
if_scope
(
condition
):
with
irb
.
if_scope
(
condition
):
dev
=
env
.
dev
dev
=
env
.
dev
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
0
,
0
,
0
,
...
@@ -717,24 +706,22 @@ def inject_conv2d_transpose_skip(stmt_in):
...
@@ -717,24 +706,22 @@ def inject_conv2d_transpose_skip(stmt_in):
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
return
inner
return
None
return
None
ret
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
return
ret
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
func
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectConv2DTrasnposeSkip"
)
def
annotate_alu_coproc_scope
(
stmt_in
):
"""Pass to insert ALU instruction.
Parameters
def
AnnotateALUCoProcScope
():
----------
"""Pass to insert ALU instruction.
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_ftransform
(
func
,
mod
,
ctx
):
env
=
get_env
()
env
=
get_env
()
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"alu"
):
if
_match_pragma
(
stmt
,
"alu"
):
...
@@ -749,25 +736,21 @@ def annotate_alu_coproc_scope(stmt_in):
...
@@ -749,25 +736,21 @@ def annotate_alu_coproc_scope(stmt_in):
return
tvm
.
tir
.
Evaluate
(
0
)
return
tvm
.
tir
.
Evaluate
(
0
)
return
stmt
return
stmt
stmt_out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
]
)
func
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
])
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
return
stmt_out
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.AnnotateALUCoProcScope"
)
def
inject_alu_intrin
(
stmt_in
):
def
InjectALUIntrin
(
):
"""Pass to inject ALU micro-ops.
"""Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_ftransform
(
func
,
mod
,
ctx
):
env
=
get_env
()
env
=
get_env
()
idxm
=
tvm
.
tir
.
indexmod
idxm
=
tvm
.
tir
.
indexmod
analyzer
=
tvm
.
arith
.
Analyzer
()
analyzer
=
tvm
.
arith
.
Analyzer
()
...
@@ -972,24 +955,8 @@ def inject_alu_intrin(stmt_in):
...
@@ -972,24 +955,8 @@ def inject_alu_intrin(stmt_in):
return
irb
.
get
()
return
irb
.
get
()
return
stmt
return
stmt
stmt_out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
func
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
return
stmt_out
def
debug_print
(
stmt
):
"""A debug pass that print the stmt
Parameters
----------
stmt : Stmt
The input statement
Returns
return
tvm
.
tir
.
transform
.
prim_func_pass
(
-------
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectALUIntrin"
)
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print
(
stmt
)
return
stmt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment