Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
e8afa1b4
Commit
e8afa1b4
authored
Feb 23, 2018
by
xqdan
Committed by
Tianqi Chen
Feb 22, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Support buffer reuse for different types (#891)
[PASS] Support buffer reuse for different types
parent
61cdf903
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
12 deletions
+92
-12
src/pass/storage_rewrite.cc
+18
-10
tests/python/unittest/test_pass_storage_rewrite.py
+74
-2
No files found.
src/pass/storage_rewrite.cc
View file @
e8afa1b4
...
@@ -502,7 +502,6 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -502,7 +502,6 @@ class StoragePlanRewriter : public IRMutator {
}
}
// Remap the index
// Remap the index
Expr
RemapIndex
(
Type
dtype
,
Expr
index
,
StorageEntry
*
e
)
{
Expr
RemapIndex
(
Type
dtype
,
Expr
index
,
StorageEntry
*
e
)
{
CHECK_EQ
(
dtype
.
element_of
(),
e
->
elem_type
);
if
(
e
->
bits_offset
==
0
)
return
index
;
if
(
e
->
bits_offset
==
0
)
return
index
;
uint64_t
elem_bits
=
dtype
.
bits
()
*
dtype
.
lanes
();
uint64_t
elem_bits
=
dtype
.
bits
()
*
dtype
.
lanes
();
CHECK_EQ
(
e
->
bits_offset
%
elem_bits
,
0U
);
CHECK_EQ
(
e
->
bits_offset
%
elem_bits
,
0U
);
...
@@ -564,16 +563,21 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -564,16 +563,21 @@ class StoragePlanRewriter : public IRMutator {
Expr
combo_size
;
Expr
combo_size
;
for
(
const
Allocate
*
op
:
e
->
allocs
)
{
for
(
const
Allocate
*
op
:
e
->
allocs
)
{
Expr
sz
=
arith
::
ComputeReduce
<
Mul
>
(
op
->
extents
,
make_const
(
Int
(
32
),
1
));
Expr
sz
=
arith
::
ComputeReduce
<
Mul
>
(
op
->
extents
,
make_const
(
Int
(
32
),
1
));
if
(
alloc_type
.
lanes
()
!=
op
->
type
.
lanes
())
{
// transform to bits
sz
=
(
sz
*
make_const
(
sz
.
type
(),
op
->
type
.
lanes
())
+
auto
sz_nbits
=
sz
*
(
op
->
type
.
bits
()
*
op
->
type
.
lanes
());
make_const
(
sz
.
type
(),
alloc_type
.
lanes
()
-
1
))
/
make_const
(
sz
.
type
(),
alloc_type
.
lanes
());
}
if
(
combo_size
.
defined
())
{
if
(
combo_size
.
defined
())
{
combo_size
=
max
(
combo_size
,
sz
);
combo_size
=
max
(
combo_size
,
sz
_nbits
);
}
else
{
}
else
{
combo_size
=
sz
;
combo_size
=
sz_nbits
;
}
}
}
// transform to alloc bytes
auto
type_bits
=
alloc_type
.
bits
()
*
alloc_type
.
lanes
();
bool
divided
=
can_prove
(
combo_size
%
type_bits
==
0
);
combo_size
=
combo_size
/
type_bits
;
// round up for can not divided
if
(
!
divided
)
{
combo_size
+=
make_const
(
Int
(
32
),
1
);
}
}
combo_size
=
ir
::
Simplify
(
combo_size
);
combo_size
=
ir
::
Simplify
(
combo_size
);
e
->
new_alloc
=
Allocate
::
make
(
e
->
new_alloc
=
Allocate
::
make
(
...
@@ -784,8 +788,9 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -784,8 +788,9 @@ class StoragePlanRewriter : public IRMutator {
// skip plan for local variable,
// skip plan for local variable,
// compiler can do a better job with register allocation.
// compiler can do a better job with register allocation.
const
uint64_t
match_range
=
16
;
const
uint64_t
match_range
=
16
;
uint64_t
op_elem_bits
=
op
->
type
.
bits
()
*
op
->
type
.
lanes
();
uint64_t
const_nbits
=
static_cast
<
uint64_t
>
(
uint64_t
const_nbits
=
static_cast
<
uint64_t
>
(
op
->
constant_allocation_size
()
*
op
->
type
.
bits
()
*
op
->
type
.
lanes
()
);
op
->
constant_allocation_size
()
*
op
_elem_bits
);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
// This rules only apply if we are using non special memory
if
(
scope
.
tag
.
length
()
==
0
)
{
if
(
scope
.
tag
.
length
()
==
0
)
{
...
@@ -801,15 +806,18 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -801,15 +806,18 @@ class StoragePlanRewriter : public IRMutator {
auto
begin
=
const_free_map_
.
lower_bound
(
const_nbits
/
match_range
);
auto
begin
=
const_free_map_
.
lower_bound
(
const_nbits
/
match_range
);
auto
mid
=
const_free_map_
.
lower_bound
(
const_nbits
);
auto
mid
=
const_free_map_
.
lower_bound
(
const_nbits
);
auto
end
=
const_free_map_
.
upper_bound
(
const_nbits
*
match_range
);
auto
end
=
const_free_map_
.
upper_bound
(
const_nbits
*
match_range
);
// start looking at the buffer that is bigger than the required size first
for
(
auto
it
=
mid
;
it
!=
end
;
++
it
)
{
for
(
auto
it
=
mid
;
it
!=
end
;
++
it
)
{
StorageEntry
*
e
=
it
->
second
;
StorageEntry
*
e
=
it
->
second
;
if
(
e
->
attach_scope_
!=
attach_scope
)
continue
;
if
(
e
->
attach_scope_
!=
attach_scope
)
continue
;
if
(
e
->
scope
!=
scope
)
continue
;
if
(
e
->
scope
!=
scope
)
continue
;
if
(
e
->
elem_type
!=
op
->
type
.
element_of
())
continue
;
// when not divided, no reuse, eg, float4 vs float3
if
(
e
->
bits_offset
%
op_elem_bits
!=
0
)
continue
;
e
->
const_nbits
=
std
::
max
(
const_nbits
,
e
->
const_nbits
);
e
->
const_nbits
=
std
::
max
(
const_nbits
,
e
->
const_nbits
);
const_free_map_
.
erase
(
it
);
const_free_map_
.
erase
(
it
);
return
e
;
return
e
;
}
}
// then start looking at smaller buffers.
for
(
auto
it
=
mid
;
it
!=
begin
;)
{
for
(
auto
it
=
mid
;
it
!=
begin
;)
{
--
it
;
--
it
;
StorageEntry
*
e
=
it
->
second
;
StorageEntry
*
e
=
it
->
second
;
...
...
tests/python/unittest/test_pass_storage_rewrite.py
View file @
e8afa1b4
...
@@ -54,10 +54,27 @@ def test_alloc_different_dtypes():
...
@@ -54,10 +54,27 @@ def test_alloc_different_dtypes():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
base_dtype
=
dtype_list
[
0
]
base_dtype
=
dtype_list
[
0
]
global_a
=
tvm
.
placeholder
((
length
,),
name
=
"global_a"
,
dtype
=
base_dtype
)
global_a
=
tvm
.
placeholder
((
length
,),
name
=
"global_a"
,
dtype
=
base_dtype
)
for
index
,
dtype
in
enumerate
(
dtype_list
):
assert
len
(
dtype_list
)
==
4
with
ib
.
for_range
(
0
,
length
,
name
=
"j"
)
as
j
:
with
ib
.
for_range
(
0
,
length
,
name
=
"j"
)
as
j
:
A
=
ib
.
allocate
(
dtype
,
length
,
name
=
"A_"
+
str
(
index
),
scope
=
"local.L0A"
)
dtype
=
dtype_list
[
0
]
A
=
ib
.
allocate
(
dtype
,
length
,
name
=
"A"
,
scope
=
"local.L0A"
)
A
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
A
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"j"
)
as
j
:
dtype
=
dtype_list
[
1
]
B
=
ib
.
allocate
(
dtype
,
length
,
name
=
"B"
,
scope
=
"local.L0A"
)
B
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"j"
)
as
j
:
dtype
=
dtype_list
[
2
]
C
=
ib
.
allocate
(
dtype
,
length
,
name
=
"C"
,
scope
=
"local.L0A"
)
C
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"j"
)
as
j
:
dtype
=
dtype_list
[
3
]
D
=
ib
.
allocate
(
dtype
,
length
,
name
=
"D"
,
scope
=
"local.L0A"
)
D
[
j
]
=
tvm
.
const
(
1
,
dtype
=
dtype
)
with
ib
.
for_range
(
0
,
length
,
name
=
"j"
)
as
j
:
dtype
=
"int8"
E
=
ib
.
allocate
(
dtype
,
length
,
name
=
"E"
,
scope
=
"local.L0A"
)
E
[
j
]
=
A
[
j
]
.
astype
(
dtype
)
+
B
[
j
]
.
astype
(
dtype
)
+
C
[
j
]
.
astype
(
dtype
)
+
D
[
j
]
.
astype
(
dtype
)
return
ib
.
get
()
return
ib
.
get
()
def
dtype_bit_len
(
dtype
):
def
dtype_bit_len
(
dtype
):
...
@@ -342,6 +359,58 @@ def test_inplace_rule3():
...
@@ -342,6 +359,58 @@ def test_inplace_rule3():
assert
n
.
extents
[
0
]
.
value
==
70
assert
n
.
extents
[
0
]
.
value
==
70
tvm
.
ir_pass
.
PostOrderVisit
(
stmt
,
verify
)
tvm
.
ir_pass
.
PostOrderVisit
(
stmt
,
verify
)
def
test_alloc_seq_type
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
A
=
ib
.
allocate
(
"float32"
,
200
,
name
=
"A"
,
scope
=
"local.L0A"
)
A1
=
ib
.
allocate
(
"float32"
,
200
,
name
=
"A1"
,
scope
=
"local.L0A"
)
A
[
j
]
=
1.2
A1
[
j
]
=
1.3
B
=
ib
.
allocate
(
"int16"
,
200
,
name
=
"B"
,
scope
=
"local.L0A"
)
B
[
j
]
=
tvm
.
const
(
1
,
"int16"
)
C
=
ib
.
allocate
(
"int16"
,
200
,
name
=
"C"
,
scope
=
"local.L0A"
)
C
[
j
]
=
tvm
.
const
(
1
,
"int16"
)
D
=
ib
.
allocate
(
"int16"
,
200
,
name
=
"D"
,
scope
=
"local.L0A"
)
D
[
j
]
=
B
[
j
]
+
C
[
j
]
A2
=
ib
.
allocate
(
"float32"
,
200
,
name
=
"A2"
,
scope
=
"local.L0A"
)
A2
[
j
]
=
A
[
j
]
body
=
ib
.
get
()
body
=
tvm
.
ir_pass
.
StorageRewrite
(
body
)
num_alloc
=
[
0
]
def
verify
(
n
):
if
isinstance
(
n
,
tvm
.
stmt
.
Allocate
):
num_alloc
[
0
]
+=
1
assert
n
.
extents
[
0
]
.
value
==
500
tvm
.
ir_pass
.
PostOrderVisit
(
body
,
verify
)
assert
num_alloc
[
0
]
==
1
def
test_alloc_seq_type2
():
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
A
=
ib
.
allocate
(
"float32"
,
200
,
name
=
"A"
,
scope
=
"local.L0A"
)
A
[
j
]
=
1.2
with
ib
.
for_range
(
0
,
20
,
name
=
"j"
)
as
j
:
B
=
ib
.
allocate
(
"int16"
,
400
,
name
=
"B"
,
scope
=
"local.L0A"
)
B
[
j
]
=
tvm
.
const
(
1
,
"int16"
)
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
C
=
ib
.
allocate
(
"float32"
,
200
,
name
=
"C"
,
scope
=
"local.L0A"
)
C
[
j
]
=
1.2
body
=
ib
.
get
()
body
=
tvm
.
ir_pass
.
StorageRewrite
(
body
)
num_alloc
=
[
0
]
def
verify
(
n
):
if
isinstance
(
n
,
tvm
.
stmt
.
Allocate
):
num_alloc
[
0
]
+=
1
assert
n
.
extents
[
0
]
.
value
==
200
tvm
.
ir_pass
.
PostOrderVisit
(
body
,
verify
)
assert
num_alloc
[
0
]
==
1
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_alloc_seq
()
test_alloc_seq
()
test_alloc_different_dtypes
()
test_alloc_different_dtypes
()
...
@@ -352,3 +421,6 @@ if __name__ == "__main__":
...
@@ -352,3 +421,6 @@ if __name__ == "__main__":
test_storage_share_gpu
()
test_storage_share_gpu
()
test_inplace_rule2
()
test_inplace_rule2
()
test_inplace_rule3
()
test_inplace_rule3
()
test_alloc_seq_type
()
test_alloc_seq_type2
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment