Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d3277874
Unverified
Commit
d3277874
authored
Apr 21, 2020
by
Tianqi Chen
Committed by
GitHub
Apr 21, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
parent
72f2aea2
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
560 additions
and
583 deletions
+560
-583
include/tvm/target/target.h
+3
-2
python/tvm/autotvm/measure/measure_methods.py
+4
-4
python/tvm/driver/build_module.py
+5
-24
python/tvm/tir/function.py
+16
-0
src/target/target.cc
+2
-2
tests/python/relay/test_pass_fold_constant.py
+5
-3
tests/python/unittest/test_target_codegen_cuda.py
+7
-3
tests/python/unittest/test_target_codegen_llvm.py
+7
-4
tests/python/unittest/test_tir_pass_verify_gpu_code.py
+4
-4
tutorials/dev/low_level_custom_pass.py
+6
-5
vta/python/vta/build_module.py
+29
-27
vta/python/vta/transform.py
+472
-505
No files found.
include/tvm/target/target.h
View file @
d3277874
...
...
@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string>
#include <vector>
...
...
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
bool
partition_const_loop
=
false
;
/*! \brief
Whether to dump the IR of each pass (only when building from python)
*/
std
::
vector
<
std
::
pair
<
int
,
runtime
::
PackedFunc
>
>
add_lower_pass
;
/*! \brief
List of passes to be injected into the low-level pipeline.
*/
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool
dump_pass_ir
=
false
;
...
...
python/tvm/autotvm/measure/measure_methods.py
View file @
d3277874
...
...
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
"""
def
verify_pass
(
stmt
):
valid
=
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
def
verify_pass
(
f
,
*
_
):
valid
=
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
if
not
valid
:
raise
InstantiationError
(
"Skipped because of invalid gpu kernel"
)
return
stmt
return
verify_pass
return
f
return
tvm
.
tir
.
transform
.
prim_func_pass
(
verify_pass
,
opt_level
=
0
)
python/tvm/driver/build_module.py
View file @
d3277874
...
...
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return
tvm
.
IRModule
({
name
:
func
})
def
_wrap_as_prim_func_pass
(
flist
,
name
):
"""Wrap flist as a function pass.
This is a temporary adapter before we fully
migrate to the new pass manager.
"""
def
_transform
(
func
,
*
_
):
stmt
=
func
.
body
for
f
in
flist
:
stmt
=
f
(
stmt
)
# create a new function with updated body.
return
tvm
.
tir
.
PrimFunc
(
func
.
params
,
stmt
,
func
.
ret_type
,
func
.
buffer_map
,
func
.
attrs
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
name
)
def
lower
(
sch
,
args
,
name
=
"main"
,
...
...
@@ -190,15 +171,15 @@ def lower(sch,
else
:
mod
=
sch
pass_list
=
lower_phase0
# Phase 1
pass_list
=
[
_wrap_as_prim_func_pass
(
lower_phase0
,
"Custom-Phase0"
),
pass_list
+=
[
tvm
.
tir
.
transform
.
InjectPrefetch
(),
tvm
.
tir
.
transform
.
StorageFlatten
(
64
,
cfg
.
instrument_bound_checkers
),
tvm
.
tir
.
transform
.
NarrowDataType
(
32
),
tvm
.
tir
.
transform
.
Simplify
(),
_wrap_as_prim_func_pass
(
lower_phase1
,
"Custom-Phase1"
),
]
pass_list
+=
lower_phase1
# Phase 2
if
not
simple_mode
:
...
...
@@ -214,8 +195,8 @@ def lower(sch,
cfg
.
auto_unroll_max_depth
,
cfg
.
auto_unroll_max_extent
,
cfg
.
unroll_explicit
),
_wrap_as_prim_func_pass
(
lower_phase2
,
"Custom-Phase2"
),
]
pass_list
+=
lower_phase2
# Phase 3
pass_list
+=
[
...
...
@@ -225,7 +206,7 @@ def lower(sch,
if
not
cfg
.
disable_select_rewriting
:
pass_list
+=
[
tvm
.
tir
.
transform
.
RewriteUnsafeSelect
()]
pass_list
+=
[
_wrap_as_prim_func_pass
(
lower_phase3
,
"Custom-Phase3"
)]
pass_list
+=
lower_phase3
# Instrument BoundCheckers
if
cfg
.
instrument_bound_checkers
:
...
...
python/tvm/tir/function.py
View file @
d3277874
...
...
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
PrimFunc
,
param_list
,
body
,
ret_type
,
buffer_map
,
attrs
)
def
with_body
(
self
,
new_body
):
"""Create a new PrimFunc with the same signature but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return
PrimFunc
(
self
.
params
,
new_body
,
self
.
ret_type
,
self
.
buffer_map
,
self
.
attrs
)
src/target/target.cc
View file @
d3277874
...
...
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL
(
"target.BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
cfg
=
args
[
0
];
std
::
vector
<
std
::
pair
<
int
,
PackedFunc
>
>
add_lower_pass
;
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
add_lower_pass
.
push_back
(
std
::
make_pair
(
args
[
i
].
operator
int
(),
args
[
i
+
1
].
operator
t
vm
::
runtime
::
PackedFunc
()));
args
[
i
+
1
].
operator
t
ransform
::
Pass
()));
}
cfg
->
add_lower_pass
=
add_lower_pass
;
});
...
...
tests/python/relay/test_pass_fold_constant.py
View file @
d3277874
...
...
@@ -51,11 +51,13 @@ def test_fold_const():
z
=
relay
.
add
(
y
,
relay
.
const
(
c_data
))
return
relay
.
Function
([
x
],
z
)
def
fail
(
x
):
raise
RuntimeError
()
def
FailPass
():
def
_transform
(
m
,
*
args
):
raise
RuntimeError
()
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
)
# the fold constant should work on any context.
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
fail
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
FailPass
()
)]):
with
tvm
.
target
.
create
(
"cuda"
):
zz
=
run_opt_pass
(
before
(),
transform
.
FoldConstant
())
zexpected
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
tests/python/unittest/test_target_codegen_cuda.py
View file @
d3277874
...
...
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch
[
c
]
.
bind
(
xo
,
thrx
)
sch
[
c
]
.
vectorize
(
xi
)
def
my_vectorize
(
stmt
):
def
MyVectorize
(
):
def
vectorizer
(
op
):
if
op
.
for_type
==
tvm
.
tir
.
For
.
Vectorized
:
four
=
tvm
.
tir
.
const
(
4
,
'int32'
)
...
...
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b
=
tvm
.
tir
.
Shuffle
(
bs
,
ids
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"MyVectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
MyVectorize
())]):
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
],
target
=
'cuda'
)
a_
=
np
.
array
(
list
(
range
(
64
)),
dtype
=
'int32'
)
b_
=
np
.
array
((
list
(
range
(
4
))[::
-
1
])
*
16
,
dtype
=
'int32'
)
...
...
tests/python/unittest/test_target_codegen_llvm.py
View file @
d3277874
...
...
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c
=
te
.
compute
((
8
,
),
lambda
x
:
a
[
x
]
+
b
[
7
-
x
])
sch
=
te
.
create_schedule
(
c
.
op
)
def
my_vectorize
(
stmt
):
def
my_vectorize
():
def
vectorizer
(
op
):
store
=
op
.
body
idx
=
tvm
.
tir
.
Ramp
(
tvm
.
tir
.
const
(
0
,
'int32'
),
tvm
.
tir
.
const
(
1
,
'int32'
),
8
)
...
...
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value
=
new_a
+
new_b
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"my_vectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
()
)]):
ir
=
tvm
.
lower
(
sch
,
[
a
,
b
,
c
],
simple_mode
=
True
)
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
])
a_
=
tvm
.
nd
.
array
(
np
.
arange
(
1
,
9
,
dtype
=
'int32'
))
...
...
tests/python/unittest/test_tir_pass_verify_gpu_code.py
View file @
d3277874
...
...
@@ -19,10 +19,10 @@ import tvm
from
tvm
import
te
def
get_verify_pass
(
valid
,
**
kwargs
):
def
verify_pass
(
stmt
):
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
return
stmt
return
verify_pass
def
_fverify
(
f
,
*
_
):
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
return
f
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_fverify
,
opt_level
=
0
)
def
test_shared_memory
():
def
check_shared_memory
(
dtype
):
...
...
tutorials/dev/low_level_custom_pass.py
View file @
d3277874
...
...
@@ -117,19 +117,20 @@ def vectorize8(op):
return
body
return
None
def
vectorize
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
vectorize
(
f
,
mod
,
ctx
):
global
loops
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
stmt
,
find_width8
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
f
.
body
,
find_width8
)
if
not
loops
:
return
s
tmt
return
s
f
# The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorize8
,
[
'For'
])
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorize8
,
[
'For'
]))
return
stmt
#####################################################################
# Glue to Lowering
...
...
vta/python/vta/build_module.py
View file @
d3277874
...
...
@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument
, invalid-name
"""VTA specific buildin for runtime."""
import
tvm
from
.
import
ir_pass
from
.
import
transform
from
.environment
import
get_env
def
lift_coproc_scope
(
x
):
"""Lift coprocessings cope to the """
x
=
ir_pass
.
lift_alloc_to_scope_begin
(
x
)
x
=
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_scope"
,
False
)
return
x
def
early_rewrite
(
stmt
):
def
EarlyRewrite
():
"""Try to do storage rewrite in early pass."""
try
:
return
tvm
.
tir
.
ir_pass
.
StorageRewrite
(
stmt
)
except
tvm
.
error
.
TVMError
:
return
stmt
def
_transform
(
mod
,
ctx
):
try
:
return
tvm
.
tir
.
transform
.
StorageRewrite
()(
mod
)
except
tvm
.
error
.
TVMError
:
return
mod
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
,
name
=
"tir.vta.EarlyRewrite"
)
def
build_config
(
debug_flag
=
0
,
**
kwargs
):
...
...
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
"""
env
=
get_env
()
def
add_debug
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
add_debug
(
f
,
*
_
):
debug
=
tvm
.
tir
.
call_extern
(
"int32"
,
"VTASetDebugMode"
,
env
.
dev
.
command_handle
,
debug_flag
)
return
tvm
.
tir
.
stmt_seq
(
debug
,
stmt
)
pass_list
=
[(
0
,
ir_pass
.
inject_conv2d_transpose_skip
),
(
1
,
ir_pass
.
inject_dma_intrin
),
(
1
,
ir_pass
.
inject_skip_copy
),
(
1
,
ir_pass
.
annotate_alu_coproc_scope
),
(
1
,
lambda
x
:
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_uop_scope"
,
True
)),
(
1
,
lift_coproc_scope
),
(
1
,
ir_pass
.
inject_coproc_sync
),
(
1
,
early_rewrite
)]
return
f
.
with_body
(
tvm
.
tir
.
stmt_seq
(
debug
,
f
.
body
))
pass_list
=
[(
0
,
transform
.
InjectConv2DTransposeSkip
()),
(
1
,
transform
.
InjectDMAIntrin
()),
(
1
,
transform
.
InjectSkipCopy
()),
(
1
,
transform
.
AnnotateALUCoProcScope
()),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_uop_scope"
)),
(
1
,
transform
.
LiftAllocToScopeBegin
()),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_scope"
)),
(
1
,
transform
.
InjectCoProcSync
()),
(
1
,
EarlyRewrite
())]
if
debug_flag
:
pass_list
.
append
((
1
,
add_debug
))
pass_list
.
append
((
2
,
ir_pass
.
inject_alu_intrin
))
pass_list
.
append
((
3
,
tvm
.
tir
.
ir_pass
.
LowerStorageAccessInfo
))
pass_list
.
append
((
3
,
ir_pass
.
fold_uop_loop
))
pass_list
.
append
((
3
,
ir_pass
.
cpu_access_rewrite
))
pass_list
.
append
((
2
,
transform
.
InjectALUIntrin
()
))
pass_list
.
append
((
3
,
tvm
.
tir
.
transform
.
LowerDeviceStorageAccessInfo
()
))
pass_list
.
append
((
3
,
transform
.
FoldUopLoop
()
))
pass_list
.
append
((
3
,
transform
.
CPUAccessRewrite
()
))
return
tvm
.
target
.
build_config
(
add_lower_pass
=
pass_list
,
**
kwargs
)
...
...
vta/python/vta/
ir_pass
.py
→
vta/python/vta/
transform
.py
View file @
d3277874
...
...
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Additional
IR Pass
for VTA"""
# pylint: disable=len-as-condition, no-else-return
"""Additional
Transformation Passes
for VTA"""
# pylint: disable=len-as-condition, no-else-return
, unused-argument, invalid-name
import
tvm
from
tvm
import
te
from
topi
import
util
...
...
@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(
stmt
.
attr_key
==
"pragma_scope"
and
stmt
.
value
.
value
==
key
))
def
fold_uop_loop
(
stmt_in
):
def
FoldUopLoop
(
):
"""Detect and fold uop loop.
VTA supports the uop programming model
...
...
@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detects the loop structure
and extracts it into a uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
Output statement.
fpass : tvm.transform.Pass
The pass
"""
env
=
get_env
()
def
_fold_outermost_loop
(
body
):
stmt
=
body
if
not
isinstance
(
stmt
,
tvm
.
tir
.
For
):
...
...
@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise
ValueError
(
"Failed to fold the GEMM instructions.."
)
def
_do_fold
(
stmt
):
env
=
get_env
()
if
(
stmt
.
attr_key
==
"coproc_uop_scope"
and
isinstance
(
stmt
.
value
,
tvm
.
tir
.
StringImm
)
and
stmt
.
value
.
value
==
env
.
dev
.
vta_push_uop
.
value
):
...
...
@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return
tvm
.
tir
.
AttrStmt
(
stmt
.
node
,
stmt
.
attr_key
,
stmt
.
value
,
body
)
return
None
out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
return
out
def
_ftransform
(
f
,
mod
,
ctx
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.FoldUopLoop"
)
def
cpu_access_rewrite
(
stmt_in
):
def
CPUAccessRewrite
(
):
"""Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that does not
...
...
@@ -148,189 +146,182 @@ def cpu_access_rewrite(stmt_in):
This pass detects CPU access and rewrites it to use the pointer
returned by VTABufferCPUPtr for CPU access.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
env
=
get_env
()
rw_info
=
{}
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
return
None
new_var
=
rw_info
[
buffer_var
]
let_stmt
=
tvm
.
tir
.
LetStmt
(
def
_ftransform
(
f
,
mod
,
ctx
):
rw_info
=
{}
env
=
get_env
()
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
return
None
new_var
=
rw_info
[
buffer_var
]
let_stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
buffer_var
),
op
.
body
)
alloc
=
tvm
.
tir
.
Allocate
(
buffer_var
,
op
.
dtype
,
op
.
extents
,
op
.
condition
,
let_stmt
)
del
rw_info
[
buffer_var
]
return
alloc
if
isinstance
(
op
,
tvm
.
tir
.
Load
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Load
(
op
.
dtype
,
new_var
,
op
.
index
)
if
isinstance
(
op
,
tvm
.
tir
.
Store
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
raise
RuntimeError
(
"not reached"
)
stmt_in
=
f
.
body
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
for
buffer_var
,
new_var
in
rw_info
.
items
():
stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
buffer_var
),
op
.
body
)
alloc
=
tvm
.
tir
.
Allocate
(
buffer_var
,
op
.
dtype
,
op
.
extents
,
op
.
condition
,
let_stmt
)
del
rw_info
[
buffer_var
]
return
alloc
if
isinstance
(
op
,
tvm
.
tir
.
Load
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Load
(
op
.
dtype
,
new_var
,
op
.
index
)
if
isinstance
(
op
,
tvm
.
tir
.
Store
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
raise
RuntimeError
(
"not reached"
)
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
for
buffer_var
,
new_var
in
rw_info
.
items
():
stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
buffer_var
),
stmt
)
return
stmt
buffer_var
),
stmt
)
return
f
.
with_body
(
stmt
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.CPUAccessRewrite"
)
def
lift_alloc_to_scope_begin
(
stmt_in
):
def
LiftAllocToScopeBegin
(
):
"""Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
lift_stmt
=
[[]]
def
_merge_block
(
slist
,
body
):
for
op
in
slist
:
if
op
.
body
==
body
:
body
=
op
elif
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
body
=
tvm
.
tir
.
Allocate
(
op
.
buffer_var
,
op
.
dtype
,
op
.
extents
,
op
.
condition
,
body
)
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
body
=
tvm
.
tir
.
AttrStmt
(
op
.
node
,
op
.
attr_key
,
op
.
value
,
body
)
elif
isinstance
(
op
,
tvm
.
tir
.
For
):
body
=
tvm
.
tir
.
For
(
op
.
loop_var
,
op
.
min
,
op
.
extent
,
op
.
for_type
,
op
.
device_api
,
body
)
else
:
raise
RuntimeError
(
"unexpected op"
)
del
slist
[:]
return
body
def
_pre_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
For
):
lift_stmt
.
append
([])
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"virtual_thread"
:
def
_ftransform
(
f
,
mod
,
ctx
):
lift_stmt
=
[[]]
def
_merge_block
(
slist
,
body
):
for
op
in
slist
:
if
op
.
body
==
body
:
body
=
op
elif
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
body
=
tvm
.
tir
.
Allocate
(
op
.
buffer_var
,
op
.
dtype
,
op
.
extents
,
op
.
condition
,
body
)
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
body
=
tvm
.
tir
.
AttrStmt
(
op
.
node
,
op
.
attr_key
,
op
.
value
,
body
)
elif
isinstance
(
op
,
tvm
.
tir
.
For
):
body
=
tvm
.
tir
.
For
(
op
.
loop_var
,
op
.
min
,
op
.
extent
,
op
.
for_type
,
op
.
device_api
,
body
)
else
:
raise
RuntimeError
(
"unexpected op"
)
del
slist
[:]
return
body
def
_pre_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
For
):
lift_stmt
.
append
([])
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"virtual_thread"
:
lift_stmt
.
append
([])
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
if
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"storage_scope"
:
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
if
op
.
attr_key
==
"virtual_thread"
:
if
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"storage_scope"
:
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
if
op
.
attr_key
==
"virtual_thread"
:
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
op
if
isinstance
(
op
,
tvm
.
tir
.
For
):
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
op
if
isinstance
(
op
,
tvm
.
tir
.
For
):
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
raise
RuntimeError
(
"not reached"
)
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
assert
len
(
lift_stmt
)
==
1
return
_merge_block
(
lift_stmt
[
0
],
stmt
)
raise
RuntimeError
(
"not reached"
)
stmt_in
=
f
.
body
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
assert
len
(
lift_stmt
)
==
1
return
f
.
with_body
(
_merge_block
(
lift_stmt
[
0
],
stmt
))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.LiftAllocToScopeBegin"
)
def
inject_skip_copy
(
stmt_in
):
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters
----------
stmt_in : Stmt
Input statement
def
InjectSkipCopy
():
"""Pass to inject skip copy stmt, used for debug purpose.
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"skip_dma_copy"
):
return
tvm
.
tir
.
Evaluate
(
0
)
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
def
_ftransform
(
f
,
mod
,
ctx
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
def
inject_coproc_sync
(
stmt_in
):
"""Pass to inject skip copy stmt, used in debug.
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectSkipCopy"
)
Parameters
----------
stmt_in : Stmt
Input statement
def
InjectCoProcSync
():
"""Pass to inject coproc sync.
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
success
=
[
False
]
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"coproc_sync"
):
success
[
0
]
=
True
sync
=
tvm
.
tir
.
Call
(
"int32"
,
"vta.coproc_sync"
,
[],
tvm
.
tir
.
Call
.
Intrinsic
,
None
,
0
)
return
tvm
.
tir
.
SeqStmt
([
stmt
.
body
,
tvm
.
tir
.
Evaluate
(
sync
)])
if
_match_pragma
(
stmt
,
"trim_loop"
):
op
=
stmt
.
body
assert
isinstance
(
op
,
tvm
.
tir
.
For
)
return
tvm
.
tir
.
For
(
op
.
loop_var
,
op
.
min
,
2
,
op
.
for_type
,
op
.
device_api
,
op
.
body
)
return
None
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
stmt
=
tvm
.
tir
.
ir_pass
.
CoProcSync
(
stmt
)
return
stmt
def
inject_dma_intrin
(
stmt_in
):
def
_ftransform
(
f
,
*
_
):
success
=
[
False
]
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"coproc_sync"
):
success
[
0
]
=
True
sync
=
tvm
.
tir
.
Call
(
"int32"
,
"vta.coproc_sync"
,
[],
tvm
.
tir
.
Call
.
Intrinsic
,
None
,
0
)
return
tvm
.
tir
.
SeqStmt
([
stmt
.
body
,
tvm
.
tir
.
Evaluate
(
sync
)])
if
_match_pragma
(
stmt
,
"trim_loop"
):
op
=
stmt
.
body
assert
isinstance
(
op
,
tvm
.
tir
.
For
)
return
tvm
.
tir
.
For
(
op
.
loop_var
,
op
.
min
,
2
,
op
.
for_type
,
op
.
device_api
,
op
.
body
)
return
None
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
return
tvm
.
transform
.
Sequential
(
[
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
0
,
"tir.vta.InjectCoProcSync"
),
tvm
.
tir
.
transform
.
CoProcSync
()],
opt_level
=
0
,
name
=
"tir.vta.InjectCoProcSync"
)
def
InjectDMAIntrin
():
"""Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
env
=
get_env
()
idxd
=
tvm
.
tir
.
indexdiv
idxm
=
tvm
.
tir
.
indexmod
...
...
@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def
_inject_copy
(
src
,
dst
,
pad_before
,
pad_after
,
pad_value
):
# FIXME: pad_value is ignored...
env
=
get_env
()
_
=
pad_value
if
dst
.
scope
==
"global"
:
# Store
...
...
@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
else
:
raise
RuntimeError
(
"Do not support copy
%
s->
%
s"
%
(
src
.
scope
,
dst
.
scope
))
return
tvm
.
tir
.
ir_pass
.
InjectCopyIntrin
(
stmt_in
,
"dma_copy"
,
_inject_copy
)
return
tvm
.
tir
.
transform
.
InjectCopyIntrin
(
"dma_copy"
,
_inject_copy
)
def
_get_gemm_intrin_buffer
():
...
...
@@ -619,377 +611,352 @@ def _get_gemm_intrin_buffer():
return
wgt_layout
,
inp_layout
,
out_layout
def
inject_conv2d_transpose_skip
(
stmt_in
):
def
InjectConv2DTransposeSkip
(
):
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
env
=
get_env
()
dwgt
,
dinp
,
dout
=
_get_gemm_intrin_buffer
()
calls
=
[]
selects
=
[]
def
_find_basics
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
BufferLoad
):
calls
.
append
(
op
)
elif
isinstance
(
op
,
tvm
.
tir
.
Select
):
selects
.
append
(
op
)
def
_do_fold
(
op
):
if
_match_pragma
(
op
,
"conv2d_transpose_gemm"
):
is_init
=
".init"
in
str
(
op
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
op
,
_find_basics
)
if
is_init
:
# create inner most block
irb
=
tvm
.
tir
.
ir_builder
.
create
()
dev
=
env
.
dev
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
1
,
dout
.
access_ptr
(
"rw"
,
"int32"
),
0
,
0
,
0
,
0
,
0
))
inner
=
irb
.
get
()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
body
=
op
.
body
.
body
while
isinstance
(
body
,
tvm
.
tir
.
IfThenElse
):
body
=
body
.
then_case
args
=
body
.
indices
res_buffer
=
body
.
buffer
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dout
,
res_buffer
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
else
:
conv_call
,
data_call
,
kernel_call
=
calls
[
-
3
:]
pad_data_tensor
=
data_call
.
buffer
kernel_tensor
=
kernel_call
.
buffer
res_tensor
=
conv_call
.
buffer
if
selects
:
condition
=
selects
[
0
]
.
condition
else
:
condition
=
tvm
.
tir
.
const
(
1
,
'int'
)
# create inner most block
irb
=
tvm
.
tir
.
ir_builder
.
create
()
with
irb
.
if_scope
(
condition
):
def
_ftransform
(
func
,
mod
,
ctx
):
env
=
get_env
()
dwgt
,
dinp
,
dout
=
_get_gemm_intrin_buffer
()
calls
=
[]
selects
=
[]
def
_find_basics
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
BufferLoad
):
calls
.
append
(
op
)
elif
isinstance
(
op
,
tvm
.
tir
.
Select
):
selects
.
append
(
op
)
def
_do_fold
(
op
):
if
_match_pragma
(
op
,
"conv2d_transpose_gemm"
):
is_init
=
".init"
in
str
(
op
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
op
,
_find_basics
)
if
is_init
:
# create inner most block
irb
=
tvm
.
tir
.
ir_builder
.
create
()
dev
=
env
.
dev
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
0
,
0
,
1
,
dout
.
access_ptr
(
"rw"
,
"int32"
),
dinp
.
access_ptr
(
"r"
,
"int32"
),
dwgt
.
access_ptr
(
"r"
,
"int32"
),
0
,
0
,
0
,
0
,
0
))
inner
=
irb
.
get
()
args
=
conv_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dout
,
res_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
args
=
kernel_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
env
.
BLOCK_OUT
,
0
,
env
.
BLOCK_IN
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dwgt
,
kernel_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
args
=
data_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_IN
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dinp
,
pad_data_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
return
None
ret
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
return
ret
inner
=
irb
.
get
()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
body
=
op
.
body
.
body
while
isinstance
(
body
,
tvm
.
tir
.
IfThenElse
):
body
=
body
.
then_case
args
=
body
.
indices
res_buffer
=
body
.
buffer
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dout
,
res_buffer
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
else
:
conv_call
,
data_call
,
kernel_call
=
calls
[
-
3
:]
pad_data_tensor
=
data_call
.
buffer
kernel_tensor
=
kernel_call
.
buffer
res_tensor
=
conv_call
.
buffer
def
annotate_alu_coproc_scope
(
stmt_in
):
if
selects
:
condition
=
selects
[
0
]
.
condition
else
:
condition
=
tvm
.
tir
.
const
(
1
,
'int'
)
# create inner most block
irb
=
tvm
.
tir
.
ir_builder
.
create
()
with
irb
.
if_scope
(
condition
):
dev
=
env
.
dev
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
0
,
dout
.
access_ptr
(
"rw"
,
"int32"
),
dinp
.
access_ptr
(
"r"
,
"int32"
),
dwgt
.
access_ptr
(
"r"
,
"int32"
),
0
,
0
,
0
))
inner
=
irb
.
get
()
args
=
conv_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dout
,
res_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
args
=
kernel_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
env
.
BLOCK_OUT
,
0
,
env
.
BLOCK_IN
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dwgt
,
kernel_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
args
=
data_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_IN
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dinp
,
pad_data_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
return
None
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
func
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectConv2DTrasnposeSkip"
)
def
AnnotateALUCoProcScope
():
"""Pass to insert ALU instruction.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
env
=
get_env
()
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"alu"
):
irb
=
tvm
.
tir
.
ir_builder
.
create
()
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_scope"
,
env
.
dev
.
get_task_qid
(
env
.
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_uop_scope"
,
tvm
.
tir
.
StringImm
(
"VTAPushALUOp"
))
irb
.
emit
(
stmt
)
return
irb
.
get
()
if
_match_pragma
(
stmt
,
"skip_alu"
):
return
tvm
.
tir
.
Evaluate
(
0
)
return
stmt
stmt_out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
return
stmt_out
def
inject_alu_intrin
(
stmt_in
):
def
_ftransform
(
func
,
mod
,
ctx
):
env
=
get_env
()
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"alu"
):
irb
=
tvm
.
tir
.
ir_builder
.
create
()
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_scope"
,
env
.
dev
.
get_task_qid
(
env
.
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_uop_scope"
,
tvm
.
tir
.
StringImm
(
"VTAPushALUOp"
))
irb
.
emit
(
stmt
)
return
irb
.
get
()
if
_match_pragma
(
stmt
,
"skip_alu"
):
return
tvm
.
tir
.
Evaluate
(
0
)
return
stmt
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
func
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.AnnotateALUCoProcScope"
)
def
InjectALUIntrin
():
"""Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
-------
stmt_out : Stmt
T
ransformed statement
fpass : tvm.transform.Pass
T
he pass
"""
env
=
get_env
()
idxm
=
tvm
.
tir
.
indexmod
analyzer
=
tvm
.
arith
.
Analyzer
()
def
_ftransform
(
func
,
mod
,
ctx
):
env
=
get_env
()
idxm
=
tvm
.
tir
.
indexmod
analyzer
=
tvm
.
arith
.
Analyzer
()
def
_do_fold
(
stmt
):
def
_equal
(
x
,
y
):
return
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
x
-
y
),
0
)
def
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
):
src_coeff
=
list
(
src_coeff
)
dst_coeff
=
list
(
dst_coeff
)
extents
=
list
(
extents
)
rev_src_coeff
=
[
src_coeff
.
pop
()]
rev_dst_coeff
=
[
dst_coeff
.
pop
()]
rev_extents
=
[]
assert
src_coeff
vsrc
=
src_coeff
.
pop
()
vdst
=
dst_coeff
.
pop
()
vext
=
extents
.
pop
()
while
src_coeff
:
next_src
=
src_coeff
.
pop
()
next_dst
=
dst_coeff
.
pop
()
next_ext
=
extents
.
pop
()
if
_equal
(
next_src
,
vsrc
*
vext
)
and
_equal
(
next_dst
,
vdst
*
vext
):
vext
=
analyzer
.
simplify
(
vext
*
next_ext
)
else
:
rev_src_coeff
.
append
(
vsrc
)
rev_dst_coeff
.
append
(
vdst
)
rev_extents
.
append
(
vext
)
vsrc
=
next_src
vdst
=
next_dst
vext
=
next_ext
rev_src_coeff
.
append
(
vsrc
)
rev_dst_coeff
.
append
(
vdst
)
rev_extents
.
append
(
vext
)
rev_src_coeff
.
reverse
()
rev_dst_coeff
.
reverse
()
rev_extents
.
reverse
()
return
rev_src_coeff
,
rev_dst_coeff
,
rev_extents
if
_match_pragma
(
stmt
,
"alu"
):
# Get to the innermost loop body
loop_body
=
stmt
.
body
nest_size
=
0
while
isinstance
(
loop_body
,
tvm
.
tir
.
For
):
loop_body
=
loop_body
.
body
nest_size
+=
1
# Get the src/dst arguments
dst_var
=
loop_body
.
buffer_var
dst_idx
=
loop_body
.
index
# Derive loop variables and extents
tmp_body
=
stmt
.
body
indices
=
[]
extents
=
[]
for
_
in
range
(
nest_size
):
indices
.
append
(
tmp_body
.
loop_var
)
extents
.
append
(
tmp_body
.
extent
)
tmp_body
=
tmp_body
.
body
# Derive opcode
if
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Add
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_ADD
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Sub
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SUB
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Mul
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MUL
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Min
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MIN
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Max
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MAX
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Call
):
if
loop_body
.
value
.
name
==
'shift_left'
:
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
rhs
=
analyzer
.
simplify
(
-
loop_body
.
value
.
args
[
1
])
elif
loop_body
.
value
.
name
==
'shift_right'
:
def
_do_fold
(
stmt
):
def
_equal
(
x
,
y
):
return
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
x
-
y
),
0
)
def
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
):
src_coeff
=
list
(
src_coeff
)
dst_coeff
=
list
(
dst_coeff
)
extents
=
list
(
extents
)
rev_src_coeff
=
[
src_coeff
.
pop
()]
rev_dst_coeff
=
[
dst_coeff
.
pop
()]
rev_extents
=
[]
assert
src_coeff
vsrc
=
src_coeff
.
pop
()
vdst
=
dst_coeff
.
pop
()
vext
=
extents
.
pop
()
while
src_coeff
:
next_src
=
src_coeff
.
pop
()
next_dst
=
dst_coeff
.
pop
()
next_ext
=
extents
.
pop
()
if
_equal
(
next_src
,
vsrc
*
vext
)
and
_equal
(
next_dst
,
vdst
*
vext
):
vext
=
analyzer
.
simplify
(
vext
*
next_ext
)
else
:
rev_src_coeff
.
append
(
vsrc
)
rev_dst_coeff
.
append
(
vdst
)
rev_extents
.
append
(
vext
)
vsrc
=
next_src
vdst
=
next_dst
vext
=
next_ext
rev_src_coeff
.
append
(
vsrc
)
rev_dst_coeff
.
append
(
vdst
)
rev_extents
.
append
(
vext
)
rev_src_coeff
.
reverse
()
rev_dst_coeff
.
reverse
()
rev_extents
.
reverse
()
return
rev_src_coeff
,
rev_dst_coeff
,
rev_extents
if
_match_pragma
(
stmt
,
"alu"
):
# Get to the innermost loop body
loop_body
=
stmt
.
body
nest_size
=
0
while
isinstance
(
loop_body
,
tvm
.
tir
.
For
):
loop_body
=
loop_body
.
body
nest_size
+=
1
# Get the src/dst arguments
dst_var
=
loop_body
.
buffer_var
dst_idx
=
loop_body
.
index
# Derive loop variables and extents
tmp_body
=
stmt
.
body
indices
=
[]
extents
=
[]
for
_
in
range
(
nest_size
):
indices
.
append
(
tmp_body
.
loop_var
)
extents
.
append
(
tmp_body
.
extent
)
tmp_body
=
tmp_body
.
body
# Derive opcode
if
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Add
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_ADD
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Sub
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SUB
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Mul
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MUL
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Min
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MIN
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Max
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MAX
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Call
):
if
loop_body
.
value
.
name
==
'shift_left'
:
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
rhs
=
analyzer
.
simplify
(
-
loop_body
.
value
.
args
[
1
])
elif
loop_body
.
value
.
name
==
'shift_right'
:
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
rhs
=
loop_body
.
value
.
args
[
1
]
else
:
raise
RuntimeError
(
"Function call not recognized
%
s"
%
(
loop_body
.
value
.
name
))
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Load
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
rhs
=
loop_body
.
value
.
args
[
1
]
lhs
=
loop_body
.
value
rhs
=
tvm
.
tir
.
const
(
0
,
"int32"
)
else
:
raise
RuntimeError
(
"Function call not recognized
%
s"
%
(
loop_body
.
value
.
name
))
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Load
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
rhs
=
tvm
.
tir
.
const
(
0
,
"int32"
)
else
:
raise
RuntimeError
(
"Expression not recognized
%
s,
%
s,
%
s"
%
(
type
(
loop_body
.
value
),
str
(
loop_body
.
value
),
str
(
stmt
)))
# Derive array index coefficients
dst_coeff
=
tvm
.
arith
.
detect_linear_equation
(
dst_idx
,
indices
)
# Check if lhs/rhs is immediate
use_imm
=
False
imm_val
=
None
if
isinstance
(
rhs
,
tvm
.
tir
.
IntImm
):
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
use_imm
=
True
imm_val
=
rhs
if
isinstance
(
lhs
,
tvm
.
tir
.
IntImm
):
assert
rhs
.
buffer_var
.
same_as
(
dst_var
)
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
use_imm
=
True
imm_val
=
lhs
if
imm_val
is
None
:
imm_val
=
0
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
and
rhs
.
buffer_var
.
same_as
(
dst_var
)
src_lhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
src_rhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
# Determine which side has the same coefficients
lhs_equal
=
True
rhs_equal
=
True
for
i
,
coef
in
enumerate
(
dst_coeff
):
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_lhs_coeff
[
i
]):
lhs_equal
=
False
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_rhs_coeff
[
i
]):
rhs_equal
=
False
# Make sure at least one of the source is identical to the
# destination (in-place computation)
assert
lhs_equal
or
rhs_equal
# Assign the source coefficients
if
lhs_equal
:
src_coeff
=
src_rhs_coeff
"Expression not recognized
%
s,
%
s,
%
s"
%
(
type
(
loop_body
.
value
),
str
(
loop_body
.
value
),
str
(
stmt
)))
# Derive array index coefficients
dst_coeff
=
tvm
.
arith
.
detect_linear_equation
(
dst_idx
,
indices
)
# Check if lhs/rhs is immediate
use_imm
=
False
imm_val
=
None
if
isinstance
(
rhs
,
tvm
.
tir
.
IntImm
):
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
use_imm
=
True
imm_val
=
rhs
if
isinstance
(
lhs
,
tvm
.
tir
.
IntImm
):
assert
rhs
.
buffer_var
.
same_as
(
dst_var
)
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
use_imm
=
True
imm_val
=
lhs
if
imm_val
is
None
:
imm_val
=
0
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
and
rhs
.
buffer_var
.
same_as
(
dst_var
)
src_lhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
src_rhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
# Determine which side has the same coefficients
lhs_equal
=
True
rhs_equal
=
True
for
i
,
coef
in
enumerate
(
dst_coeff
):
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_lhs_coeff
[
i
]):
lhs_equal
=
False
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_rhs_coeff
[
i
]):
rhs_equal
=
False
# Make sure at least one of the source is identical to the
# destination (in-place computation)
assert
lhs_equal
or
rhs_equal
# Assign the source coefficients
if
lhs_equal
:
src_coeff
=
src_rhs_coeff
else
:
src_coeff
=
src_lhs_coeff
# Ensure that we have the proper tensor dimensions in the
# innermost loop (pattern match)
src_coeff
=
list
(
src_coeff
)
dst_coeff
=
list
(
dst_coeff
)
extents
=
list
(
extents
)
assert
len
(
src_coeff
)
>
1
assert
len
(
dst_coeff
)
>
1
assert
len
(
extents
)
!=
0
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
idxm
(
src_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
idxm
(
dst_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
2
],
1
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
2
],
1
)
if
env
.
BATCH
>
1
:
assert
len
(
src_coeff
)
>
2
assert
len
(
dst_coeff
)
>
2
assert
len
(
extents
)
>
1
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
3
],
env
.
BLOCK_OUT
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
3
],
env
.
BLOCK_OUT
)
# Apply tensorization of the loop coefficients
src_offset
=
src_coeff
[
-
1
]
dst_offset
=
dst_coeff
[
-
1
]
if
env
.
BATCH
==
1
:
src_coeff
=
src_coeff
[:
-
2
]
dst_coeff
=
dst_coeff
[:
-
2
]
extents
=
extents
[:
-
1
]
else
:
src_coeff
=
src_lhs_coeff
# Ensure that we have the proper tensor dimensions in the
# innermost loop (pattern match)
src_coeff
=
list
(
src_coeff
)
dst_coeff
=
list
(
dst_coeff
)
extents
=
list
(
extents
)
assert
len
(
src_coeff
)
>
1
assert
len
(
dst_coeff
)
>
1
assert
len
(
extents
)
!=
0
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
idxm
(
src_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
idxm
(
dst_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
2
],
1
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
2
],
1
)
if
env
.
BATCH
>
1
:
assert
len
(
src_coeff
)
>
2
assert
len
(
dst_coeff
)
>
2
assert
len
(
extents
)
>
1
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
3
],
env
.
BLOCK_OUT
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
3
],
env
.
BLOCK_OUT
)
# Apply tensorization of the loop coefficients
src_offset
=
src_coeff
[
-
1
]
dst_offset
=
dst_coeff
[
-
1
]
if
env
.
BATCH
==
1
:
src_coeff
=
src_coeff
[:
-
2
]
dst_coeff
=
dst_coeff
[:
-
2
]
extents
=
extents
[:
-
1
]
else
:
src_coeff
=
src_coeff
[:
-
3
]
dst_coeff
=
dst_coeff
[:
-
3
]
extents
=
extents
[:
-
2
]
src_coeff
.
append
(
src_offset
)
dst_coeff
.
append
(
dst_offset
)
src_coeff
=
[
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
src_coeff
]
dst_coeff
=
[
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
dst_coeff
]
# Flatten the outer loops
if
extents
:
src_coeff
,
dst_coeff
,
extents
=
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
)
# Insert ALU micro-ops
irb
=
tvm
.
tir
.
ir_builder
.
create
()
for
idx
,
extent
in
enumerate
(
extents
):
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopBegin"
,
extent
,
dst_coeff
[
idx
],
src_coeff
[
idx
],
0
))
use_imm
=
int
(
use_imm
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
1
,
0
,
dst_coeff
[
len
(
dst_coeff
)
-
1
],
src_coeff
[
len
(
src_coeff
)
-
1
],
0
,
alu_opcode
,
use_imm
,
imm_val
))
for
extent
in
extents
:
src_coeff
=
src_coeff
[:
-
3
]
dst_coeff
=
dst_coeff
[:
-
3
]
extents
=
extents
[:
-
2
]
src_coeff
.
append
(
src_offset
)
dst_coeff
.
append
(
dst_offset
)
src_coeff
=
[
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
src_coeff
]
dst_coeff
=
[
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
dst_coeff
]
# Flatten the outer loops
if
extents
:
src_coeff
,
dst_coeff
,
extents
=
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
)
# Insert ALU micro-ops
irb
=
tvm
.
tir
.
ir_builder
.
create
()
for
idx
,
extent
in
enumerate
(
extents
):
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopBegin"
,
extent
,
dst_coeff
[
idx
],
src_coeff
[
idx
],
0
))
use_imm
=
int
(
use_imm
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopEnd"
))
return
irb
.
get
()
return
stmt
stmt_out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
return
stmt_out
def
debug_print
(
stmt
):
"""A debug pass that print the stmt
Parameters
----------
stmt : Stmt
The input statement
Returns
-------
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print
(
stmt
)
return
stmt
"int32"
,
"VTAUopPush"
,
1
,
0
,
dst_coeff
[
len
(
dst_coeff
)
-
1
],
src_coeff
[
len
(
src_coeff
)
-
1
],
0
,
alu_opcode
,
use_imm
,
imm_val
))
for
extent
in
extents
:
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopEnd"
))
return
irb
.
get
()
return
stmt
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
func
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectALUIntrin"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment