Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d3277874
Unverified
Commit
d3277874
authored
Apr 21, 2020
by
Tianqi Chen
Committed by
GitHub
Apr 21, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PTYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
parent
72f2aea2
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
560 additions
and
583 deletions
+560
-583
include/tvm/target/target.h
+3
-2
python/tvm/autotvm/measure/measure_methods.py
+4
-4
python/tvm/driver/build_module.py
+5
-24
python/tvm/tir/function.py
+16
-0
src/target/target.cc
+2
-2
tests/python/relay/test_pass_fold_constant.py
+5
-3
tests/python/unittest/test_target_codegen_cuda.py
+7
-3
tests/python/unittest/test_target_codegen_llvm.py
+7
-4
tests/python/unittest/test_tir_pass_verify_gpu_code.py
+4
-4
tutorials/dev/low_level_custom_pass.py
+6
-5
vta/python/vta/build_module.py
+29
-27
vta/python/vta/transform.py
+472
-505
No files found.
include/tvm/target/target.h
View file @
d3277874
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
...
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
/*! \brief Whether to partition const loop */
bool
partition_const_loop
=
false
;
bool
partition_const_loop
=
false
;
/*! \brief
Whether to dump the IR of each pass (only when building from python)
*/
/*! \brief
List of passes to be injected into the low-level pipeline.
*/
std
::
vector
<
std
::
pair
<
int
,
runtime
::
PackedFunc
>
>
add_lower_pass
;
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool
dump_pass_ir
=
false
;
bool
dump_pass_ir
=
false
;
...
...
python/tvm/autotvm/measure/measure_methods.py
View file @
d3277874
...
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
...
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
This pass will check memory usage and number of threads per block.
"""
"""
def
verify_pass
(
stmt
):
def
verify_pass
(
f
,
*
_
):
valid
=
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
valid
=
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
if
not
valid
:
if
not
valid
:
raise
InstantiationError
(
"Skipped because of invalid gpu kernel"
)
raise
InstantiationError
(
"Skipped because of invalid gpu kernel"
)
return
stmt
return
f
return
verify_pass
return
tvm
.
tir
.
transform
.
prim_func_pass
(
verify_pass
,
opt_level
=
0
)
python/tvm/driver/build_module.py
View file @
d3277874
...
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
...
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return
tvm
.
IRModule
({
name
:
func
})
return
tvm
.
IRModule
({
name
:
func
})
def
_wrap_as_prim_func_pass
(
flist
,
name
):
"""Wrap flist as a function pass.
This is an temporary adapter before we fully
migrate to the new pass manager.
"""
def
_transform
(
func
,
*
_
):
stmt
=
func
.
body
for
f
in
flist
:
stmt
=
f
(
stmt
)
# create a new function with updated body.
return
tvm
.
tir
.
PrimFunc
(
func
.
params
,
stmt
,
func
.
ret_type
,
func
.
buffer_map
,
func
.
attrs
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
name
)
def
lower
(
sch
,
def
lower
(
sch
,
args
,
args
,
name
=
"main"
,
name
=
"main"
,
...
@@ -190,15 +171,15 @@ def lower(sch,
...
@@ -190,15 +171,15 @@ def lower(sch,
else
:
else
:
mod
=
sch
mod
=
sch
pass_list
=
lower_phase0
# Phase 1
# Phase 1
pass_list
=
[
pass_list
+=
[
_wrap_as_prim_func_pass
(
lower_phase0
,
"Custom-Phase0"
),
tvm
.
tir
.
transform
.
InjectPrefetch
(),
tvm
.
tir
.
transform
.
InjectPrefetch
(),
tvm
.
tir
.
transform
.
StorageFlatten
(
64
,
cfg
.
instrument_bound_checkers
),
tvm
.
tir
.
transform
.
StorageFlatten
(
64
,
cfg
.
instrument_bound_checkers
),
tvm
.
tir
.
transform
.
NarrowDataType
(
32
),
tvm
.
tir
.
transform
.
NarrowDataType
(
32
),
tvm
.
tir
.
transform
.
Simplify
(),
tvm
.
tir
.
transform
.
Simplify
(),
_wrap_as_prim_func_pass
(
lower_phase1
,
"Custom-Phase1"
),
]
]
pass_list
+=
lower_phase1
# Phase 2
# Phase 2
if
not
simple_mode
:
if
not
simple_mode
:
...
@@ -214,8 +195,8 @@ def lower(sch,
...
@@ -214,8 +195,8 @@ def lower(sch,
cfg
.
auto_unroll_max_depth
,
cfg
.
auto_unroll_max_depth
,
cfg
.
auto_unroll_max_extent
,
cfg
.
auto_unroll_max_extent
,
cfg
.
unroll_explicit
),
cfg
.
unroll_explicit
),
_wrap_as_prim_func_pass
(
lower_phase2
,
"Custom-Phase2"
),
]
]
pass_list
+=
lower_phase2
# Phase 3
# Phase 3
pass_list
+=
[
pass_list
+=
[
...
@@ -225,7 +206,7 @@ def lower(sch,
...
@@ -225,7 +206,7 @@ def lower(sch,
if
not
cfg
.
disable_select_rewriting
:
if
not
cfg
.
disable_select_rewriting
:
pass_list
+=
[
tvm
.
tir
.
transform
.
RewriteUnsafeSelect
()]
pass_list
+=
[
tvm
.
tir
.
transform
.
RewriteUnsafeSelect
()]
pass_list
+=
[
_wrap_as_prim_func_pass
(
lower_phase3
,
"Custom-Phase3"
)]
pass_list
+=
lower_phase3
# Instrument BoundCheckers
# Instrument BoundCheckers
if
cfg
.
instrument_bound_checkers
:
if
cfg
.
instrument_bound_checkers
:
...
...
python/tvm/tir/function.py
View file @
d3277874
...
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
...
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_ffi_api
.
PrimFunc
,
param_list
,
body
,
ret_type
,
buffer_map
,
attrs
)
_ffi_api
.
PrimFunc
,
param_list
,
body
,
ret_type
,
buffer_map
,
attrs
)
def
with_body
(
self
,
new_body
):
"""Create a new PrimFunc with the same set signatures but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return
PrimFunc
(
self
.
params
,
new_body
,
self
.
ret_type
,
self
.
buffer_map
,
self
.
attrs
)
src/target/target.cc
View file @
d3277874
...
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
...
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL
(
"target.BuildConfigSetAddLowerPass"
)
TVM_REGISTER_GLOBAL
(
"target.BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
cfg
=
args
[
0
];
BuildConfig
cfg
=
args
[
0
];
std
::
vector
<
std
::
pair
<
int
,
PackedFunc
>
>
add_lower_pass
;
std
::
vector
<
std
::
pair
<
int
,
transform
::
Pass
>
>
add_lower_pass
;
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
CHECK_EQ
(
args
.
size
()
%
2
,
1
);
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
for
(
int
i
=
1
;
i
<
args
.
size
();
i
+=
2
)
{
add_lower_pass
.
push_back
(
std
::
make_pair
(
add_lower_pass
.
push_back
(
std
::
make_pair
(
args
[
i
].
operator
int
(),
args
[
i
].
operator
int
(),
args
[
i
+
1
].
operator
t
vm
::
runtime
::
PackedFunc
()));
args
[
i
+
1
].
operator
t
ransform
::
Pass
()));
}
}
cfg
->
add_lower_pass
=
add_lower_pass
;
cfg
->
add_lower_pass
=
add_lower_pass
;
});
});
...
...
tests/python/relay/test_pass_fold_constant.py
View file @
d3277874
...
@@ -51,11 +51,13 @@ def test_fold_const():
...
@@ -51,11 +51,13 @@ def test_fold_const():
z
=
relay
.
add
(
y
,
relay
.
const
(
c_data
))
z
=
relay
.
add
(
y
,
relay
.
const
(
c_data
))
return
relay
.
Function
([
x
],
z
)
return
relay
.
Function
([
x
],
z
)
def
fail
(
x
):
def
FailPass
():
raise
RuntimeError
()
def
_transform
(
m
,
*
args
):
raise
RuntimeError
()
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
)
# the fold constant should work on any context.
# the fold constant should work on any context.
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
fail
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
0
,
FailPass
()
)]):
with
tvm
.
target
.
create
(
"cuda"
):
with
tvm
.
target
.
create
(
"cuda"
):
zz
=
run_opt_pass
(
before
(),
transform
.
FoldConstant
())
zz
=
run_opt_pass
(
before
(),
transform
.
FoldConstant
())
zexpected
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
zexpected
=
run_opt_pass
(
expected
(),
transform
.
InferType
())
...
...
tests/python/unittest/test_target_codegen_cuda.py
View file @
d3277874
...
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
...
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch
[
c
]
.
bind
(
xo
,
thrx
)
sch
[
c
]
.
bind
(
xo
,
thrx
)
sch
[
c
]
.
vectorize
(
xi
)
sch
[
c
]
.
vectorize
(
xi
)
def
my_vectorize
(
stmt
):
def
MyVectorize
(
):
def
vectorizer
(
op
):
def
vectorizer
(
op
):
if
op
.
for_type
==
tvm
.
tir
.
For
.
Vectorized
:
if
op
.
for_type
==
tvm
.
tir
.
For
.
Vectorized
:
four
=
tvm
.
tir
.
const
(
4
,
'int32'
)
four
=
tvm
.
tir
.
const
(
4
,
'int32'
)
...
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
...
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b
=
tvm
.
tir
.
Shuffle
(
bs
,
ids
)
new_b
=
tvm
.
tir
.
Shuffle
(
bs
,
ids
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
None
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"MyVectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
MyVectorize
())]):
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
],
target
=
'cuda'
)
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
],
target
=
'cuda'
)
a_
=
np
.
array
(
list
(
range
(
64
)),
dtype
=
'int32'
)
a_
=
np
.
array
(
list
(
range
(
64
)),
dtype
=
'int32'
)
b_
=
np
.
array
((
list
(
range
(
4
))[::
-
1
])
*
16
,
dtype
=
'int32'
)
b_
=
np
.
array
((
list
(
range
(
4
))[::
-
1
])
*
16
,
dtype
=
'int32'
)
...
...
tests/python/unittest/test_target_codegen_llvm.py
View file @
d3277874
...
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
...
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c
=
te
.
compute
((
8
,
),
lambda
x
:
a
[
x
]
+
b
[
7
-
x
])
c
=
te
.
compute
((
8
,
),
lambda
x
:
a
[
x
]
+
b
[
7
-
x
])
sch
=
te
.
create_schedule
(
c
.
op
)
sch
=
te
.
create_schedule
(
c
.
op
)
def
my_vectorize
(
stmt
):
def
my_vectorize
():
def
vectorizer
(
op
):
def
vectorizer
(
op
):
store
=
op
.
body
store
=
op
.
body
idx
=
tvm
.
tir
.
Ramp
(
tvm
.
tir
.
const
(
0
,
'int32'
),
tvm
.
tir
.
const
(
1
,
'int32'
),
8
)
idx
=
tvm
.
tir
.
Ramp
(
tvm
.
tir
.
const
(
0
,
'int32'
),
tvm
.
tir
.
const
(
1
,
'int32'
),
8
)
...
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
...
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value
=
new_a
+
new_b
value
=
new_a
+
new_b
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
Store
(
store
.
buffer_var
,
new_a
+
new_b
,
idx
,
all_ones
)
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorizer
,
[
'For'
])
def
_transform
(
f
,
*
_
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorizer
,
[
'For'
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_transform
,
opt_level
=
0
,
name
=
"my_vectorize"
)
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
)]):
with
tvm
.
target
.
build_config
(
add_lower_pass
=
[(
1
,
my_vectorize
()
)]):
ir
=
tvm
.
lower
(
sch
,
[
a
,
b
,
c
],
simple_mode
=
True
)
ir
=
tvm
.
lower
(
sch
,
[
a
,
b
,
c
],
simple_mode
=
True
)
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
])
module
=
tvm
.
build
(
sch
,
[
a
,
b
,
c
])
a_
=
tvm
.
nd
.
array
(
np
.
arange
(
1
,
9
,
dtype
=
'int32'
))
a_
=
tvm
.
nd
.
array
(
np
.
arange
(
1
,
9
,
dtype
=
'int32'
))
...
...
tests/python/unittest/test_tir_pass_verify_gpu_code.py
View file @
d3277874
...
@@ -19,10 +19,10 @@ import tvm
...
@@ -19,10 +19,10 @@ import tvm
from
tvm
import
te
from
tvm
import
te
def
get_verify_pass
(
valid
,
**
kwargs
):
def
get_verify_pass
(
valid
,
**
kwargs
):
def
verify_pass
(
stmt
):
def
_fverify
(
f
,
*
_
):
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
valid
[
0
]
=
tvm
.
tir
.
ir_pass
.
VerifyGPUCode
(
f
.
body
,
kwargs
)
return
stmt
return
f
return
verify_pass
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_fverify
,
opt_level
=
0
)
def
test_shared_memory
():
def
test_shared_memory
():
def
check_shared_memory
(
dtype
):
def
check_shared_memory
(
dtype
):
...
...
tutorials/dev/low_level_custom_pass.py
View file @
d3277874
...
@@ -117,19 +117,20 @@ def vectorize8(op):
...
@@ -117,19 +117,20 @@ def vectorize8(op):
return
body
return
body
return
None
return
None
def
vectorize
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
vectorize
(
f
,
mod
,
ctx
):
global
loops
global
loops
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
stmt
,
find_width8
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
f
.
body
,
find_width8
)
if
not
loops
:
if
not
loops
:
return
s
tmt
return
s
f
# The last list arugment indicates what kinds of nodes will be transformed.
# The last list arugment indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
# Thus, in this case only `For` nodes will call `vectorize8`
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
,
None
,
vectorize8
,
[
'For'
])
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
None
,
vectorize8
,
[
'For'
]))
return
stmt
#####################################################################
#####################################################################
# Glue to Lowering
# Glue to Lowering
...
...
vta/python/vta/build_module.py
View file @
d3277874
...
@@ -14,25 +14,22 @@
...
@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# specific language governing permissions and limitations
# under the License.
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument
, invalid-name
"""VTA specific buildin for runtime."""
"""VTA specific buildin for runtime."""
import
tvm
import
tvm
from
.
import
ir_pass
from
.
import
transform
from
.environment
import
get_env
from
.environment
import
get_env
def
lift_coproc_scope
(
x
):
def
EarlyRewrite
():
"""Lift coprocessings cope to the """
x
=
ir_pass
.
lift_alloc_to_scope_begin
(
x
)
x
=
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_scope"
,
False
)
return
x
def
early_rewrite
(
stmt
):
"""Try to do storage rewrite in early pass."""
"""Try to do storage rewrite in early pass."""
try
:
def
_transform
(
mod
,
ctx
):
return
tvm
.
tir
.
ir_pass
.
StorageRewrite
(
stmt
)
try
:
except
tvm
.
error
.
TVMError
:
return
tvm
.
tir
.
transform
.
StorageRewrite
()(
mod
)
return
stmt
except
tvm
.
error
.
TVMError
:
return
mod
return
tvm
.
transform
.
module_pass
(
_transform
,
opt_level
=
0
,
name
=
"tir.vta.EarlyRewrite"
)
def
build_config
(
debug_flag
=
0
,
**
kwargs
):
def
build_config
(
debug_flag
=
0
,
**
kwargs
):
...
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
...
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
vta_module = tvm.build(s, ...)
"""
"""
env
=
get_env
()
env
=
get_env
()
def
add_debug
(
stmt
):
@tvm.tir.transform.prim_func_pass
(
opt_level
=
0
)
def
add_debug
(
f
,
*
_
):
debug
=
tvm
.
tir
.
call_extern
(
debug
=
tvm
.
tir
.
call_extern
(
"int32"
,
"VTASetDebugMode"
,
"int32"
,
"VTASetDebugMode"
,
env
.
dev
.
command_handle
,
env
.
dev
.
command_handle
,
debug_flag
)
debug_flag
)
return
tvm
.
tir
.
stmt_seq
(
debug
,
stmt
)
return
f
.
with_body
(
tvm
.
tir
.
stmt_seq
(
debug
,
f
.
body
))
pass_list
=
[(
0
,
ir_pass
.
inject_conv2d_transpose_skip
),
(
1
,
ir_pass
.
inject_dma_intrin
),
(
1
,
ir_pass
.
inject_skip_copy
),
pass_list
=
[(
0
,
transform
.
InjectConv2DTransposeSkip
()),
(
1
,
ir_pass
.
annotate_alu_coproc_scope
),
(
1
,
transform
.
InjectDMAIntrin
()),
(
1
,
lambda
x
:
tvm
.
tir
.
ir_pass
.
LiftAttrScope
(
x
,
"coproc_uop_scope"
,
True
)),
(
1
,
transform
.
InjectSkipCopy
()),
(
1
,
lift_coproc_scope
),
(
1
,
transform
.
AnnotateALUCoProcScope
()),
(
1
,
ir_pass
.
inject_coproc_sync
),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_uop_scope"
)),
(
1
,
early_rewrite
)]
(
1
,
transform
.
LiftAllocToScopeBegin
()),
(
1
,
tvm
.
tir
.
transform
.
LiftAttrScope
(
"coproc_scope"
)),
(
1
,
transform
.
InjectCoProcSync
()),
(
1
,
EarlyRewrite
())]
if
debug_flag
:
if
debug_flag
:
pass_list
.
append
((
1
,
add_debug
))
pass_list
.
append
((
1
,
add_debug
))
pass_list
.
append
((
2
,
ir_pass
.
inject_alu_intrin
))
pass_list
.
append
((
2
,
transform
.
InjectALUIntrin
()
))
pass_list
.
append
((
3
,
tvm
.
tir
.
ir_pass
.
LowerStorageAccessInfo
))
pass_list
.
append
((
3
,
tvm
.
tir
.
transform
.
LowerDeviceStorageAccessInfo
()
))
pass_list
.
append
((
3
,
ir_pass
.
fold_uop_loop
))
pass_list
.
append
((
3
,
transform
.
FoldUopLoop
()
))
pass_list
.
append
((
3
,
ir_pass
.
cpu_access_rewrite
))
pass_list
.
append
((
3
,
transform
.
CPUAccessRewrite
()
))
return
tvm
.
target
.
build_config
(
add_lower_pass
=
pass_list
,
**
kwargs
)
return
tvm
.
target
.
build_config
(
add_lower_pass
=
pass_list
,
**
kwargs
)
...
...
vta/python/vta/
ir_pass
.py
→
vta/python/vta/
transform
.py
View file @
d3277874
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# specific language governing permissions and limitations
# under the License.
# under the License.
"""Additional
IR Pass
for VTA"""
"""Additional
Transformation Passes.
for VTA"""
# pylint: disable=len-as-condition, no-else-return
# pylint: disable=len-as-condition, no-else-return
, unused-argument, invalid-name
import
tvm
import
tvm
from
tvm
import
te
from
tvm
import
te
from
topi
import
util
from
topi
import
util
...
@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
...
@@ -38,7 +38,7 @@ def _match_pragma(stmt, key):
(
stmt
.
attr_key
==
"pragma_scope"
and
stmt
.
value
.
value
==
key
))
(
stmt
.
attr_key
==
"pragma_scope"
and
stmt
.
value
.
value
==
key
))
def
fold_uop_loop
(
stmt_in
):
def
FoldUopLoop
(
):
"""Detect and fold uop loop.
"""Detect and fold uop loop.
VTA support uop programming model
VTA support uop programming model
...
@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
...
@@ -46,18 +46,11 @@ def fold_uop_loop(stmt_in):
This pass detect the loop structure
This pass detect the loop structure
and extract that into uop loop AST.
and extract that into uop loop AST.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
Output statement.
The pass
"""
"""
env
=
get_env
()
def
_fold_outermost_loop
(
body
):
def
_fold_outermost_loop
(
body
):
stmt
=
body
stmt
=
body
if
not
isinstance
(
stmt
,
tvm
.
tir
.
For
):
if
not
isinstance
(
stmt
,
tvm
.
tir
.
For
):
...
@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
...
@@ -109,6 +102,7 @@ def fold_uop_loop(stmt_in):
raise
ValueError
(
"Failed to fold the GEMM instructions.."
)
raise
ValueError
(
"Failed to fold the GEMM instructions.."
)
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
env
=
get_env
()
if
(
stmt
.
attr_key
==
"coproc_uop_scope"
and
if
(
stmt
.
attr_key
==
"coproc_uop_scope"
and
isinstance
(
stmt
.
value
,
tvm
.
tir
.
StringImm
)
and
isinstance
(
stmt
.
value
,
tvm
.
tir
.
StringImm
)
and
stmt
.
value
.
value
==
env
.
dev
.
vta_push_uop
.
value
):
stmt
.
value
.
value
==
env
.
dev
.
vta_push_uop
.
value
):
...
@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
...
@@ -135,12 +129,16 @@ def fold_uop_loop(stmt_in):
return
tvm
.
tir
.
AttrStmt
(
return
tvm
.
tir
.
AttrStmt
(
stmt
.
node
,
stmt
.
attr_key
,
stmt
.
value
,
body
)
stmt
.
node
,
stmt
.
attr_key
,
stmt
.
value
,
body
)
return
None
return
None
out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
def
_ftransform
(
f
,
mod
,
ctx
):
return
out
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.FoldUopLoop"
)
def
cpu_access_rewrite
(
stmt_in
):
def
CPUAccessRewrite
(
):
"""Detect CPU access to VTA buffer and get address correctly.
"""Detect CPU access to VTA buffer and get address correctly.
VTA's buffer is an opaque handle that do not
VTA's buffer is an opaque handle that do not
...
@@ -148,189 +146,182 @@ def cpu_access_rewrite(stmt_in):
...
@@ -148,189 +146,182 @@ def cpu_access_rewrite(stmt_in):
This pass detect CPU access and rewrite to use pointer
This pass detect CPU access and rewrite to use pointer
returned VTABufferCPUPtr for CPU access.
returned VTABufferCPUPtr for CPU access.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
def
_ftransform
(
f
,
mod
,
ctx
):
rw_info
=
{}
rw_info
=
{}
def
_post_order
(
op
):
env
=
get_env
()
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
def
_post_order
(
op
):
buffer_var
=
op
.
buffer_var
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
if
not
buffer_var
in
rw_info
:
buffer_var
=
op
.
buffer_var
return
None
if
not
buffer_var
in
rw_info
:
new_var
=
rw_info
[
buffer_var
]
return
None
let_stmt
=
tvm
.
tir
.
LetStmt
(
new_var
=
rw_info
[
buffer_var
]
let_stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
buffer_var
),
op
.
body
)
alloc
=
tvm
.
tir
.
Allocate
(
buffer_var
,
op
.
dtype
,
op
.
extents
,
op
.
condition
,
let_stmt
)
del
rw_info
[
buffer_var
]
return
alloc
if
isinstance
(
op
,
tvm
.
tir
.
Load
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Load
(
op
.
dtype
,
new_var
,
op
.
index
)
if
isinstance
(
op
,
tvm
.
tir
.
Store
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
raise
RuntimeError
(
"not reached"
)
stmt_in
=
f
.
body
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
for
buffer_var
,
new_var
in
rw_info
.
items
():
stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
env
.
dev
.
command_handle
,
buffer_var
),
op
.
body
)
buffer_var
),
stmt
)
alloc
=
tvm
.
tir
.
Allocate
(
return
f
.
with_body
(
stmt
)
buffer_var
,
op
.
dtype
,
op
.
extents
,
return
tvm
.
tir
.
transform
.
prim_func_pass
(
op
.
condition
,
let_stmt
)
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.CPUAccessRewrite"
)
del
rw_info
[
buffer_var
]
return
alloc
if
isinstance
(
op
,
tvm
.
tir
.
Load
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Load
(
op
.
dtype
,
new_var
,
op
.
index
)
if
isinstance
(
op
,
tvm
.
tir
.
Store
):
buffer_var
=
op
.
buffer_var
if
not
buffer_var
in
rw_info
:
rw_info
[
buffer_var
]
=
te
.
var
(
buffer_var
.
name
+
"_ptr"
,
"handle"
)
new_var
=
rw_info
[
buffer_var
]
return
tvm
.
tir
.
Store
(
new_var
,
op
.
value
,
op
.
index
)
raise
RuntimeError
(
"not reached"
)
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_post_order
,
[
"Allocate"
,
"Load"
,
"Store"
])
for
buffer_var
,
new_var
in
rw_info
.
items
():
stmt
=
tvm
.
tir
.
LetStmt
(
new_var
,
tvm
.
tir
.
call_extern
(
"handle"
,
"VTABufferCPUPtr"
,
env
.
dev
.
command_handle
,
buffer_var
),
stmt
)
return
stmt
def
lift_alloc_to_scope_begin
(
stmt_in
):
def
LiftAllocToScopeBegin
(
):
"""Lift allocate to beginning of the current scope.
"""Lift allocate to beginning of the current scope.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
lift_stmt
=
[[]]
def
_ftransform
(
f
,
mod
,
ctx
):
def
_merge_block
(
slist
,
body
):
lift_stmt
=
[[]]
for
op
in
slist
:
def
_merge_block
(
slist
,
body
):
if
op
.
body
==
body
:
for
op
in
slist
:
body
=
op
if
op
.
body
==
body
:
elif
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
body
=
op
body
=
tvm
.
tir
.
Allocate
(
elif
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
op
.
buffer_var
,
op
.
dtype
,
body
=
tvm
.
tir
.
Allocate
(
op
.
extents
,
op
.
condition
,
body
)
op
.
buffer_var
,
op
.
dtype
,
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
op
.
extents
,
op
.
condition
,
body
)
body
=
tvm
.
tir
.
AttrStmt
(
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
op
.
node
,
op
.
attr_key
,
op
.
value
,
body
)
body
=
tvm
.
tir
.
AttrStmt
(
elif
isinstance
(
op
,
tvm
.
tir
.
For
):
op
.
node
,
op
.
attr_key
,
op
.
value
,
body
)
body
=
tvm
.
tir
.
For
(
elif
isinstance
(
op
,
tvm
.
tir
.
For
):
op
.
loop_var
,
op
.
min
,
op
.
extent
,
op
.
for_type
,
body
=
tvm
.
tir
.
For
(
op
.
device_api
,
body
)
op
.
loop_var
,
op
.
min
,
op
.
extent
,
op
.
for_type
,
else
:
op
.
device_api
,
body
)
raise
RuntimeError
(
"unexpected op"
)
else
:
del
slist
[:]
raise
RuntimeError
(
"unexpected op"
)
return
body
del
slist
[:]
return
body
def
_pre_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
For
):
def
_pre_order
(
op
):
lift_stmt
.
append
([])
if
isinstance
(
op
,
tvm
.
tir
.
For
):
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"virtual_thread"
:
lift_stmt
.
append
([])
lift_stmt
.
append
([])
elif
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"virtual_thread"
:
lift_stmt
.
append
([])
def
_post_order
(
op
):
def
_post_order
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
if
isinstance
(
op
,
tvm
.
tir
.
Allocate
):
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
if
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"storage_scope"
:
lift_stmt
[
-
1
]
.
append
(
op
)
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
return
op
.
body
if
op
.
attr_key
==
"virtual_thread"
:
if
isinstance
(
op
,
tvm
.
tir
.
AttrStmt
):
if
op
.
attr_key
==
"storage_scope"
:
lift_stmt
[
-
1
]
.
append
(
op
)
return
op
.
body
if
op
.
attr_key
==
"virtual_thread"
:
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
op
if
isinstance
(
op
,
tvm
.
tir
.
For
):
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
return
op
raise
RuntimeError
(
"not reached"
)
if
isinstance
(
op
,
tvm
.
tir
.
For
):
stmt_in
=
f
.
body
return
_merge_block
(
lift_stmt
.
pop
()
+
[
op
],
op
.
body
)
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
raise
RuntimeError
(
"not reached"
)
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
assert
len
(
lift_stmt
)
==
1
stmt_in
,
_pre_order
,
_post_order
,
[
"Allocate"
,
"AttrStmt"
,
"For"
])
return
f
.
with_body
(
_merge_block
(
lift_stmt
[
0
],
stmt
))
assert
len
(
lift_stmt
)
==
1
return
_merge_block
(
lift_stmt
[
0
],
stmt
)
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.LiftAllocToScopeBegin"
)
def
inject_skip_copy
(
stmt_in
):
"""Pass to inject skip copy stmt, used for debug purpose.
Parameters
def
InjectSkipCopy
():
----------
"""Pass to inject skip copy stmt, used for debug purpose.
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
if
_match_pragma
(
stmt
,
"skip_dma_copy"
):
if
_match_pragma
(
stmt
,
"skip_dma_copy"
):
return
tvm
.
tir
.
Evaluate
(
0
)
return
tvm
.
tir
.
Evaluate
(
0
)
return
None
return
None
return
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
def
_ftransform
(
f
,
mod
,
ctx
):
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
f
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
def
inject_coproc_sync
(
stmt_in
):
return
tvm
.
tir
.
transform
.
prim_func_pass
(
"""Pass to inject skip copy stmt, used in debug.
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectSkipCopy"
)
Parameters
----------
def
InjectCoProcSync
():
stmt_in : Stmt
"""Pass inject coproc sync
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
success
=
[
False
]
def
_ftransform
(
f
,
*
_
):
def
_do_fold
(
stmt
):
success
=
[
False
]
if
_match_pragma
(
stmt
,
"coproc_sync"
):
def
_do_fold
(
stmt
):
success
[
0
]
=
True
if
_match_pragma
(
stmt
,
"coproc_sync"
):
sync
=
tvm
.
tir
.
Call
(
success
[
0
]
=
True
"int32"
,
"vta.coproc_sync"
,
[],
tvm
.
tir
.
Call
.
Intrinsic
,
None
,
0
)
sync
=
tvm
.
tir
.
Call
(
return
tvm
.
tir
.
SeqStmt
([
stmt
.
body
,
tvm
.
tir
.
Evaluate
(
sync
)])
"int32"
,
"vta.coproc_sync"
,
[],
tvm
.
tir
.
Call
.
Intrinsic
,
None
,
0
)
if
_match_pragma
(
stmt
,
"trim_loop"
):
return
tvm
.
tir
.
SeqStmt
([
stmt
.
body
,
tvm
.
tir
.
Evaluate
(
sync
)])
op
=
stmt
.
body
if
_match_pragma
(
stmt
,
"trim_loop"
):
assert
isinstance
(
op
,
tvm
.
tir
.
For
)
op
=
stmt
.
body
return
tvm
.
tir
.
For
(
assert
isinstance
(
op
,
tvm
.
tir
.
For
)
op
.
loop_var
,
op
.
min
,
2
,
op
.
for_type
,
return
tvm
.
tir
.
For
(
op
.
device_api
,
op
.
body
)
op
.
loop_var
,
op
.
min
,
2
,
op
.
for_type
,
return
None
op
.
device_api
,
op
.
body
)
stmt
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
return
None
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
return
f
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt
=
tvm
.
tir
.
ir_pass
.
CoProcSync
(
stmt
)
f
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
return
stmt
return
tvm
.
transform
.
Sequential
(
[
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
0
,
"tir.vta.InjectCoProcSync"
),
tvm
.
tir
.
transform
.
CoProcSync
()],
def
inject_dma_intrin
(
stmt_in
):
opt_level
=
0
,
name
=
"tir.vta.InjectCoProcSync"
)
def
InjectDMAIntrin
():
"""Pass to inject DMA copy intrinsics.
"""Pass to inject DMA copy intrinsics.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
idxd
=
tvm
.
tir
.
indexdiv
idxd
=
tvm
.
tir
.
indexdiv
idxm
=
tvm
.
tir
.
indexmod
idxm
=
tvm
.
tir
.
indexmod
...
@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
...
@@ -474,6 +465,7 @@ def inject_dma_intrin(stmt_in):
def
_inject_copy
(
src
,
dst
,
pad_before
,
pad_after
,
pad_value
):
def
_inject_copy
(
src
,
dst
,
pad_before
,
pad_after
,
pad_value
):
# FIXME: pad_value is ignored...
# FIXME: pad_value is ignored...
env
=
get_env
()
_
=
pad_value
_
=
pad_value
if
dst
.
scope
==
"global"
:
if
dst
.
scope
==
"global"
:
# Store
# Store
...
@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
...
@@ -576,7 +568,7 @@ def inject_dma_intrin(stmt_in):
else
:
else
:
raise
RuntimeError
(
"Do not support copy
%
s->
%
s"
%
(
src
.
scope
,
dst
.
scope
))
raise
RuntimeError
(
"Do not support copy
%
s->
%
s"
%
(
src
.
scope
,
dst
.
scope
))
return
tvm
.
tir
.
ir_pass
.
InjectCopyIntrin
(
stmt_in
,
"dma_copy"
,
_inject_copy
)
return
tvm
.
tir
.
transform
.
InjectCopyIntrin
(
"dma_copy"
,
_inject_copy
)
def
_get_gemm_intrin_buffer
():
def
_get_gemm_intrin_buffer
():
...
@@ -619,377 +611,352 @@ def _get_gemm_intrin_buffer():
...
@@ -619,377 +611,352 @@ def _get_gemm_intrin_buffer():
return
wgt_layout
,
inp_layout
,
out_layout
return
wgt_layout
,
inp_layout
,
out_layout
def
inject_conv2d_transpose_skip
(
stmt_in
):
def
InjectConv2DTransposeSkip
(
):
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
"""Pass to skip 0-weights in conv2d transpose with stride > 1.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
def
_ftransform
(
func
,
mod
,
ctx
):
dwgt
,
dinp
,
dout
=
_get_gemm_intrin_buffer
()
env
=
get_env
()
dwgt
,
dinp
,
dout
=
_get_gemm_intrin_buffer
()
calls
=
[]
selects
=
[]
calls
=
[]
selects
=
[]
def
_find_basics
(
op
):
if
isinstance
(
op
,
tvm
.
tir
.
BufferLoad
):
def
_find_basics
(
op
):
calls
.
append
(
op
)
if
isinstance
(
op
,
tvm
.
tir
.
BufferLoad
):
elif
isinstance
(
op
,
tvm
.
tir
.
Select
):
calls
.
append
(
op
)
selects
.
append
(
op
)
elif
isinstance
(
op
,
tvm
.
tir
.
Select
):
selects
.
append
(
op
)
def
_do_fold
(
op
):
if
_match_pragma
(
op
,
"conv2d_transpose_gemm"
):
def
_do_fold
(
op
):
is_init
=
".init"
in
str
(
op
)
if
_match_pragma
(
op
,
"conv2d_transpose_gemm"
):
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
op
,
_find_basics
)
is_init
=
".init"
in
str
(
op
)
tvm
.
tir
.
ir_pass
.
PostOrderVisit
(
op
,
_find_basics
)
if
is_init
:
# create inner most block
if
is_init
:
irb
=
tvm
.
tir
.
ir_builder
.
create
()
# create inner most block
dev
=
env
.
dev
irb
=
tvm
.
tir
.
ir_builder
.
create
()
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
1
,
dout
.
access_ptr
(
"rw"
,
"int32"
),
0
,
0
,
0
,
0
,
0
))
inner
=
irb
.
get
()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
body
=
op
.
body
.
body
while
isinstance
(
body
,
tvm
.
tir
.
IfThenElse
):
body
=
body
.
then_case
args
=
body
.
indices
res_buffer
=
body
.
buffer
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dout
,
res_buffer
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
else
:
conv_call
,
data_call
,
kernel_call
=
calls
[
-
3
:]
pad_data_tensor
=
data_call
.
buffer
kernel_tensor
=
kernel_call
.
buffer
res_tensor
=
conv_call
.
buffer
if
selects
:
condition
=
selects
[
0
]
.
condition
else
:
condition
=
tvm
.
tir
.
const
(
1
,
'int'
)
# create inner most block
irb
=
tvm
.
tir
.
ir_builder
.
create
()
with
irb
.
if_scope
(
condition
):
dev
=
env
.
dev
dev
=
env
.
dev
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
0
,
0
,
1
,
dout
.
access_ptr
(
"rw"
,
"int32"
),
dout
.
access_ptr
(
"rw"
,
"int32"
),
dinp
.
access_ptr
(
"r"
,
"int32"
),
0
,
0
,
dwgt
.
access_ptr
(
"r"
,
"int32"
),
0
,
0
,
0
))
0
,
0
,
0
))
inner
=
irb
.
get
()
inner
=
irb
.
get
()
# TODO(@tmoreau89): This is only a temporary fix, please take a look.
args
=
conv_call
.
indices
body
=
op
.
body
.
body
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
while
isinstance
(
body
,
tvm
.
tir
.
IfThenElse
):
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
body
=
body
.
then_case
inner
=
tvm
.
tir
.
AttrStmt
(
args
=
body
.
indices
[
dout
,
res_tensor
],
'buffer_bind_scope'
,
res_buffer
=
body
.
buffer
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
args
=
kernel_call
.
indices
inner
=
tvm
.
tir
.
AttrStmt
(
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
[
dout
,
res_buffer
],
'buffer_bind_scope'
,
1
,
0
,
env
.
BLOCK_OUT
,
0
,
env
.
BLOCK_IN
)
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
inner
=
tvm
.
tir
.
AttrStmt
(
return
inner
[
dwgt
,
kernel_tensor
],
'buffer_bind_scope'
,
else
:
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
conv_call
,
data_call
,
kernel_call
=
calls
[
-
3
:]
args
=
data_call
.
indices
pad_data_tensor
=
data_call
.
buffer
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
kernel_tensor
=
kernel_call
.
buffer
1
,
0
,
1
,
0
,
env
.
BLOCK_IN
)
res_tensor
=
conv_call
.
buffer
inner
=
tvm
.
tir
.
AttrStmt
(
[
dinp
,
pad_data_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
return
None
ret
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
_do_fold
,
None
,
[
"AttrStmt"
])
return
ret
def
annotate_alu_coproc_scope
(
stmt_in
):
if
selects
:
condition
=
selects
[
0
]
.
condition
else
:
condition
=
tvm
.
tir
.
const
(
1
,
'int'
)
# create inner most block
irb
=
tvm
.
tir
.
ir_builder
.
create
()
with
irb
.
if_scope
(
condition
):
dev
=
env
.
dev
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_scope"
,
dev
.
get_task_qid
(
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
dev
.
vta_axis
,
"coproc_uop_scope"
,
dev
.
vta_push_uop
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
0
,
0
,
dout
.
access_ptr
(
"rw"
,
"int32"
),
dinp
.
access_ptr
(
"r"
,
"int32"
),
dwgt
.
access_ptr
(
"r"
,
"int32"
),
0
,
0
,
0
))
inner
=
irb
.
get
()
args
=
conv_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_OUT
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dout
,
res_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
args
=
kernel_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
env
.
BLOCK_OUT
,
0
,
env
.
BLOCK_IN
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dwgt
,
kernel_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
args
=
data_call
.
indices
tpl
=
(
args
[
0
],
1
,
args
[
1
],
1
,
args
[
2
],
1
,
args
[
3
],
1
,
0
,
1
,
0
,
env
.
BLOCK_IN
)
inner
=
tvm
.
tir
.
AttrStmt
(
[
dinp
,
pad_data_tensor
],
'buffer_bind_scope'
,
tvm
.
tir
.
call_intrin
(
'handle'
,
'tvm_tuple'
,
*
tpl
),
inner
)
return
inner
return
None
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
func
.
body
,
_do_fold
,
None
,
[
"AttrStmt"
]))
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectConv2DTrasnposeSkip"
)
def
AnnotateALUCoProcScope
():
"""Pass to insert ALU instruction.
"""Pass to insert ALU instruction.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
def
_ftransform
(
func
,
mod
,
ctx
):
def
_do_fold
(
stmt
):
env
=
get_env
()
if
_match_pragma
(
stmt
,
"alu"
):
def
_do_fold
(
stmt
):
irb
=
tvm
.
tir
.
ir_builder
.
create
()
if
_match_pragma
(
stmt
,
"alu"
):
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_scope"
,
irb
=
tvm
.
tir
.
ir_builder
.
create
()
env
.
dev
.
get_task_qid
(
env
.
dev
.
QID_COMPUTE
))
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_scope"
,
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_uop_scope"
,
env
.
dev
.
get_task_qid
(
env
.
dev
.
QID_COMPUTE
))
tvm
.
tir
.
StringImm
(
"VTAPushALUOp"
))
irb
.
scope_attr
(
env
.
dev
.
vta_axis
,
"coproc_uop_scope"
,
irb
.
emit
(
stmt
)
tvm
.
tir
.
StringImm
(
"VTAPushALUOp"
))
return
irb
.
get
()
irb
.
emit
(
stmt
)
if
_match_pragma
(
stmt
,
"skip_alu"
):
return
irb
.
get
()
return
tvm
.
tir
.
Evaluate
(
0
)
if
_match_pragma
(
stmt
,
"skip_alu"
):
return
stmt
return
tvm
.
tir
.
Evaluate
(
0
)
return
stmt
stmt_out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
func
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
return
stmt_out
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.AnnotateALUCoProcScope"
)
def
inject_alu_intrin
(
stmt_in
):
def
InjectALUIntrin
():
"""Pass to inject ALU micro-ops.
"""Pass to inject ALU micro-ops.
Parameters
----------
stmt_in : Stmt
Input statement
Returns
Returns
-------
-------
stmt_out : Stmt
fpass : tvm.transform.Pass
T
ransformed statement
T
he pass
"""
"""
env
=
get_env
()
def
_ftransform
(
func
,
mod
,
ctx
):
idxm
=
tvm
.
tir
.
indexmod
env
=
get_env
()
analyzer
=
tvm
.
arith
.
Analyzer
()
idxm
=
tvm
.
tir
.
indexmod
analyzer
=
tvm
.
arith
.
Analyzer
()
def
_do_fold
(
stmt
):
def
_do_fold
(
stmt
):
def
_equal
(
x
,
y
):
def
_equal
(
x
,
y
):
return
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
x
-
y
),
0
)
return
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
x
-
y
),
0
)
def
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
):
def
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
):
src_coeff
=
list
(
src_coeff
)
src_coeff
=
list
(
src_coeff
)
dst_coeff
=
list
(
dst_coeff
)
dst_coeff
=
list
(
dst_coeff
)
extents
=
list
(
extents
)
extents
=
list
(
extents
)
rev_src_coeff
=
[
src_coeff
.
pop
()]
rev_src_coeff
=
[
src_coeff
.
pop
()]
rev_dst_coeff
=
[
dst_coeff
.
pop
()]
rev_dst_coeff
=
[
dst_coeff
.
pop
()]
rev_extents
=
[]
rev_extents
=
[]
assert
src_coeff
assert
src_coeff
vsrc
=
src_coeff
.
pop
()
vsrc
=
src_coeff
.
pop
()
vdst
=
dst_coeff
.
pop
()
vdst
=
dst_coeff
.
pop
()
vext
=
extents
.
pop
()
vext
=
extents
.
pop
()
while
src_coeff
:
while
src_coeff
:
next_src
=
src_coeff
.
pop
()
next_src
=
src_coeff
.
pop
()
next_dst
=
dst_coeff
.
pop
()
next_dst
=
dst_coeff
.
pop
()
next_ext
=
extents
.
pop
()
next_ext
=
extents
.
pop
()
if
_equal
(
next_src
,
vsrc
*
vext
)
and
_equal
(
next_dst
,
vdst
*
vext
):
if
_equal
(
next_src
,
vsrc
*
vext
)
and
_equal
(
next_dst
,
vdst
*
vext
):
vext
=
analyzer
.
simplify
(
vext
*
next_ext
)
vext
=
analyzer
.
simplify
(
vext
*
next_ext
)
else
:
else
:
rev_src_coeff
.
append
(
vsrc
)
rev_src_coeff
.
append
(
vsrc
)
rev_dst_coeff
.
append
(
vdst
)
rev_dst_coeff
.
append
(
vdst
)
rev_extents
.
append
(
vext
)
rev_extents
.
append
(
vext
)
vsrc
=
next_src
vsrc
=
next_src
vdst
=
next_dst
vdst
=
next_dst
vext
=
next_ext
vext
=
next_ext
rev_src_coeff
.
append
(
vsrc
)
rev_src_coeff
.
append
(
vsrc
)
rev_dst_coeff
.
append
(
vdst
)
rev_dst_coeff
.
append
(
vdst
)
rev_extents
.
append
(
vext
)
rev_extents
.
append
(
vext
)
rev_src_coeff
.
reverse
()
rev_src_coeff
.
reverse
()
rev_dst_coeff
.
reverse
()
rev_dst_coeff
.
reverse
()
rev_extents
.
reverse
()
rev_extents
.
reverse
()
return
rev_src_coeff
,
rev_dst_coeff
,
rev_extents
return
rev_src_coeff
,
rev_dst_coeff
,
rev_extents
if
_match_pragma
(
stmt
,
"alu"
):
if
_match_pragma
(
stmt
,
"alu"
):
# Get to the innermost loop body
# Get to the innermost loop body
loop_body
=
stmt
.
body
loop_body
=
stmt
.
body
nest_size
=
0
nest_size
=
0
while
isinstance
(
loop_body
,
tvm
.
tir
.
For
):
while
isinstance
(
loop_body
,
tvm
.
tir
.
For
):
loop_body
=
loop_body
.
body
loop_body
=
loop_body
.
body
nest_size
+=
1
nest_size
+=
1
# Get the src/dst arguments
# Get the src/dst arguments
dst_var
=
loop_body
.
buffer_var
dst_var
=
loop_body
.
buffer_var
dst_idx
=
loop_body
.
index
dst_idx
=
loop_body
.
index
# Derive loop variables and extents
# Derive loop variables and extents
tmp_body
=
stmt
.
body
tmp_body
=
stmt
.
body
indices
=
[]
indices
=
[]
extents
=
[]
extents
=
[]
for
_
in
range
(
nest_size
):
for
_
in
range
(
nest_size
):
indices
.
append
(
tmp_body
.
loop_var
)
indices
.
append
(
tmp_body
.
loop_var
)
extents
.
append
(
tmp_body
.
extent
)
extents
.
append
(
tmp_body
.
extent
)
tmp_body
=
tmp_body
.
body
tmp_body
=
tmp_body
.
body
# Derive opcode
# Derive opcode
if
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Add
):
if
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Add
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_ADD
alu_opcode
=
env
.
dev
.
ALU_OPCODE_ADD
lhs
=
loop_body
.
value
.
a
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Sub
):
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Sub
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SUB
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SUB
lhs
=
loop_body
.
value
.
a
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Mul
):
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Mul
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MUL
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MUL
lhs
=
loop_body
.
value
.
a
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Min
):
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Min
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MIN
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MIN
lhs
=
loop_body
.
value
.
a
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Max
):
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Max
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MAX
alu_opcode
=
env
.
dev
.
ALU_OPCODE_MAX
lhs
=
loop_body
.
value
.
a
lhs
=
loop_body
.
value
.
a
rhs
=
loop_body
.
value
.
b
rhs
=
loop_body
.
value
.
b
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Call
):
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Call
):
if
loop_body
.
value
.
name
==
'shift_left'
:
if
loop_body
.
value
.
name
==
'shift_left'
:
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
lhs
=
loop_body
.
value
.
args
[
0
]
rhs
=
analyzer
.
simplify
(
-
loop_body
.
value
.
args
[
1
])
rhs
=
analyzer
.
simplify
(
-
loop_body
.
value
.
args
[
1
])
elif
loop_body
.
value
.
name
==
'shift_right'
:
elif
loop_body
.
value
.
name
==
'shift_right'
:
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
rhs
=
loop_body
.
value
.
args
[
1
]
else
:
raise
RuntimeError
(
"Function call not recognized
%
s"
%
(
loop_body
.
value
.
name
))
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Load
):
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
.
args
[
0
]
lhs
=
loop_body
.
value
rhs
=
loop_body
.
value
.
args
[
1
]
rhs
=
tvm
.
tir
.
const
(
0
,
"int32"
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
"Function call not recognized
%
s"
%
(
loop_body
.
value
.
name
))
"Expression not recognized
%
s,
%
s,
%
s"
%
(
elif
isinstance
(
loop_body
.
value
,
tvm
.
tir
.
Load
):
type
(
loop_body
.
value
),
str
(
loop_body
.
value
),
str
(
stmt
)))
alu_opcode
=
env
.
dev
.
ALU_OPCODE_SHR
lhs
=
loop_body
.
value
# Derive array index coefficients
rhs
=
tvm
.
tir
.
const
(
0
,
"int32"
)
dst_coeff
=
tvm
.
arith
.
detect_linear_equation
(
dst_idx
,
indices
)
else
:
# Check if lhs/rhs is immediate
raise
RuntimeError
(
use_imm
=
False
"Expression not recognized
%
s,
%
s,
%
s"
%
(
imm_val
=
None
type
(
loop_body
.
value
),
str
(
loop_body
.
value
),
str
(
stmt
)))
if
isinstance
(
rhs
,
tvm
.
tir
.
IntImm
):
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
# Derive array index coefficients
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
dst_coeff
=
tvm
.
arith
.
detect_linear_equation
(
dst_idx
,
indices
)
use_imm
=
True
# Check if lhs/rhs is immediate
imm_val
=
rhs
use_imm
=
False
if
isinstance
(
lhs
,
tvm
.
tir
.
IntImm
):
imm_val
=
None
assert
rhs
.
buffer_var
.
same_as
(
dst_var
)
if
isinstance
(
rhs
,
tvm
.
tir
.
IntImm
):
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
use_imm
=
True
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
imm_val
=
lhs
use_imm
=
True
if
imm_val
is
None
:
imm_val
=
rhs
imm_val
=
0
if
isinstance
(
lhs
,
tvm
.
tir
.
IntImm
):
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
and
rhs
.
buffer_var
.
same_as
(
dst_var
)
assert
rhs
.
buffer_var
.
same_as
(
dst_var
)
src_lhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
src_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
src_rhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
use_imm
=
True
# Determine which side has the same coefficients
imm_val
=
lhs
lhs_equal
=
True
if
imm_val
is
None
:
rhs_equal
=
True
imm_val
=
0
for
i
,
coef
in
enumerate
(
dst_coeff
):
assert
lhs
.
buffer_var
.
same_as
(
dst_var
)
and
rhs
.
buffer_var
.
same_as
(
dst_var
)
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_lhs_coeff
[
i
]):
src_lhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
lhs
.
index
,
indices
)
lhs_equal
=
False
src_rhs_coeff
=
tvm
.
arith
.
detect_linear_equation
(
rhs
.
index
,
indices
)
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_rhs_coeff
[
i
]):
# Determine which side has the same coefficients
rhs_equal
=
False
lhs_equal
=
True
# Make sure at least one of the source is identical to the
rhs_equal
=
True
# destination (in-place computation)
for
i
,
coef
in
enumerate
(
dst_coeff
):
assert
lhs_equal
or
rhs_equal
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_lhs_coeff
[
i
]):
# Assign the source coefficients
lhs_equal
=
False
if
lhs_equal
:
if
not
tvm
.
ir
.
structural_equal
(
coef
,
src_rhs_coeff
[
i
]):
src_coeff
=
src_rhs_coeff
rhs_equal
=
False
else
:
# Make sure at least one of the source is identical to the
src_coeff
=
src_lhs_coeff
# destination (in-place computation)
assert
lhs_equal
or
rhs_equal
# Ensure that we have the proper tensor dimensions in the
# Assign the source coefficients
# innermost loop (pattern match)
if
lhs_equal
:
src_coeff
=
list
(
src_coeff
)
src_coeff
=
src_rhs_coeff
dst_coeff
=
list
(
dst_coeff
)
extents
=
list
(
extents
)
assert
len
(
src_coeff
)
>
1
assert
len
(
dst_coeff
)
>
1
assert
len
(
extents
)
!=
0
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
idxm
(
src_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
idxm
(
dst_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
2
],
1
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
2
],
1
)
if
env
.
BATCH
>
1
:
assert
len
(
src_coeff
)
>
2
assert
len
(
dst_coeff
)
>
2
assert
len
(
extents
)
>
1
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
3
],
env
.
BLOCK_OUT
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
3
],
env
.
BLOCK_OUT
)
# Apply tensorization of the loop coefficients
src_offset
=
src_coeff
[
-
1
]
dst_offset
=
dst_coeff
[
-
1
]
if
env
.
BATCH
==
1
:
src_coeff
=
src_coeff
[:
-
2
]
dst_coeff
=
dst_coeff
[:
-
2
]
extents
=
extents
[:
-
1
]
else
:
else
:
src_coeff
=
src_lhs_coeff
src_coeff
=
src_coeff
[:
-
3
]
dst_coeff
=
dst_coeff
[:
-
3
]
# Ensure that we have the proper tensor dimensions in the
extents
=
extents
[:
-
2
]
# innermost loop (pattern match)
src_coeff
.
append
(
src_offset
)
src_coeff
=
list
(
src_coeff
)
dst_coeff
.
append
(
dst_offset
)
dst_coeff
=
list
(
dst_coeff
)
src_coeff
=
[
extents
=
list
(
extents
)
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
src_coeff
]
assert
len
(
src_coeff
)
>
1
dst_coeff
=
[
assert
len
(
dst_coeff
)
>
1
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
dst_coeff
]
assert
len
(
extents
)
!=
0
assert
tvm
.
ir
.
structural_equal
(
# Flatten the outer loops
analyzer
.
simplify
(
if
extents
:
idxm
(
src_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
src_coeff
,
dst_coeff
,
extents
=
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
)
assert
tvm
.
ir
.
structural_equal
(
analyzer
.
simplify
(
# Insert ALU micro-ops
idxm
(
dst_coeff
[
-
1
],
env
.
BATCH
*
env
.
BLOCK_OUT
)),
0
)
irb
=
tvm
.
tir
.
ir_builder
.
create
()
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
2
],
1
)
for
idx
,
extent
in
enumerate
(
extents
):
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
2
],
1
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
if
env
.
BATCH
>
1
:
"int32"
,
"VTAUopLoopBegin"
,
assert
len
(
src_coeff
)
>
2
extent
,
dst_coeff
[
idx
],
src_coeff
[
idx
],
0
))
assert
len
(
dst_coeff
)
>
2
use_imm
=
int
(
use_imm
)
assert
len
(
extents
)
>
1
assert
tvm
.
ir
.
structural_equal
(
src_coeff
[
-
3
],
env
.
BLOCK_OUT
)
assert
tvm
.
ir
.
structural_equal
(
dst_coeff
[
-
3
],
env
.
BLOCK_OUT
)
# Apply tensorization of the loop coefficients
src_offset
=
src_coeff
[
-
1
]
dst_offset
=
dst_coeff
[
-
1
]
if
env
.
BATCH
==
1
:
src_coeff
=
src_coeff
[:
-
2
]
dst_coeff
=
dst_coeff
[:
-
2
]
extents
=
extents
[:
-
1
]
else
:
src_coeff
=
src_coeff
[:
-
3
]
dst_coeff
=
dst_coeff
[:
-
3
]
extents
=
extents
[:
-
2
]
src_coeff
.
append
(
src_offset
)
dst_coeff
.
append
(
dst_offset
)
src_coeff
=
[
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
src_coeff
]
dst_coeff
=
[
analyzer
.
simplify
(
c
//
(
env
.
BATCH
*
env
.
BLOCK_OUT
))
for
c
in
dst_coeff
]
# Flatten the outer loops
if
extents
:
src_coeff
,
dst_coeff
,
extents
=
_flatten_loop
(
src_coeff
,
dst_coeff
,
extents
)
# Insert ALU micro-ops
irb
=
tvm
.
tir
.
ir_builder
.
create
()
for
idx
,
extent
in
enumerate
(
extents
):
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopBegin"
,
extent
,
dst_coeff
[
idx
],
src_coeff
[
idx
],
0
))
use_imm
=
int
(
use_imm
)
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopPush"
,
1
,
0
,
dst_coeff
[
len
(
dst_coeff
)
-
1
],
src_coeff
[
len
(
src_coeff
)
-
1
],
0
,
alu_opcode
,
use_imm
,
imm_val
))
for
extent
in
extents
:
irb
.
emit
(
tvm
.
tir
.
call_extern
(
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopEnd"
))
"int32"
,
"VTAUopPush"
,
return
irb
.
get
()
1
,
0
,
return
stmt
dst_coeff
[
len
(
dst_coeff
)
-
1
],
src_coeff
[
len
(
src_coeff
)
-
1
],
stmt_out
=
tvm
.
tir
.
ir_pass
.
IRTransform
(
0
,
stmt_in
,
None
,
_do_fold
,
[
"AttrStmt"
])
alu_opcode
,
use_imm
,
imm_val
))
return
stmt_out
for
extent
in
extents
:
irb
.
emit
(
tvm
.
tir
.
call_extern
(
"int32"
,
"VTAUopLoopEnd"
))
def
debug_print
(
stmt
):
return
irb
.
get
()
"""A debug pass that print the stmt
return
stmt
Parameters
return
func
.
with_body
(
tvm
.
tir
.
ir_pass
.
IRTransform
(
----------
func
.
body
,
None
,
_do_fold
,
[
"AttrStmt"
]))
stmt : Stmt
The input statement
return
tvm
.
tir
.
transform
.
prim_func_pass
(
_ftransform
,
opt_level
=
0
,
name
=
"tir.vta.InjectALUIntrin"
)
Returns
-------
stmt : Stmt
The
"""
# pylint: disable=superfluous-parens
print
(
stmt
)
return
stmt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment