Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
79e482bc
Commit
79e482bc
authored
Aug 13, 2017
by
Tianqi Chen
Committed by
GitHub
Aug 13, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Memory barrier detection, storage access lower. (#317)
parent
afa20869
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
161 additions
and
283 deletions
+161
-283
include/tvm/ir_pass.h
+9
-0
include/tvm/target_info.h
+6
-0
python/tvm/build_module.py
+1
-1
src/api/api_pass.cc
+1
-0
src/pass/coproc_sync.cc
+0
-0
src/pass/lower_tvm_builtin.cc
+7
-76
src/pass/storage_access.cc
+110
-0
src/pass/storage_flatten.cc
+0
-1
src/pass/storage_sync.cc
+4
-197
tests/python/unittest/test_pass_storage_sync.py
+23
-8
No files found.
include/tvm/ir_pass.h
View file @
79e482bc
...
@@ -267,6 +267,15 @@ Stmt CoProcSync(Stmt stmt);
...
@@ -267,6 +267,15 @@ Stmt CoProcSync(Stmt stmt);
Stmt
LiftAttrScope
(
Stmt
stmt
,
std
::
string
attr_key
);
Stmt
LiftAttrScope
(
Stmt
stmt
,
std
::
string
attr_key
);
/*!
/*!
* \brief Lower attached storage access information.
* Do this pass after all storage access analysis finish.
*
* \param stmt The stmt to be trasnformed
* \return Transformed stmt.
*/
Stmt
LowerStorageAccessInfo
(
Stmt
stmt
);
/*!
* \brief Make an user callable API LoweredFunc.
* \brief Make an user callable API LoweredFunc.
*
*
* The main task of this function is to create code to :
* The main task of this function is to create code to :
...
...
include/tvm/target_info.h
View file @
79e482bc
...
@@ -23,11 +23,17 @@ struct MemoryInfoNode : public Node {
...
@@ -23,11 +23,17 @@ struct MemoryInfoNode : public Node {
int
max_num_bits
;
int
max_num_bits
;
/*! \brief maximum number of bits to be used in simd op */
/*! \brief maximum number of bits to be used in simd op */
int
max_simd_bits
;
int
max_simd_bits
;
/*!
* \brief head address of the buffer, if visible to CPU
* This address can be None.
*/
Expr
head_address
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"unit_bits"
,
&
unit_bits
);
v
->
Visit
(
"unit_bits"
,
&
unit_bits
);
v
->
Visit
(
"max_num_bits"
,
&
max_num_bits
);
v
->
Visit
(
"max_num_bits"
,
&
max_num_bits
);
v
->
Visit
(
"max_simd_bits"
,
&
max_simd_bits
);
v
->
Visit
(
"max_simd_bits"
,
&
max_simd_bits
);
v
->
Visit
(
"head_address"
,
&
head_address
);
}
}
static
constexpr
const
char
*
_type_key
=
"MemoryInfo"
;
static
constexpr
const
char
*
_type_key
=
"MemoryInfo"
;
...
...
python/tvm/build_module.py
View file @
79e482bc
...
@@ -197,7 +197,6 @@ def lower(sch,
...
@@ -197,7 +197,6 @@ def lower(sch,
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
VectorizeLoop
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
InjectVirtualThread
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
stmt
=
ir_pass
.
StorageRewrite
(
stmt
)
stmt
=
ir_pass
.
CoProcSync
(
stmt
)
cfg
=
BuildConfig
.
current
cfg
=
BuildConfig
.
current
stmt
=
ir_pass
.
UnrollLoop
(
stmt
=
ir_pass
.
UnrollLoop
(
stmt
,
stmt
,
...
@@ -210,6 +209,7 @@ def lower(sch,
...
@@ -210,6 +209,7 @@ def lower(sch,
stmt
=
ir_pass
.
Simplify
(
stmt
)
stmt
=
ir_pass
.
Simplify
(
stmt
)
if
simple_mode
:
if
simple_mode
:
return
stmt
return
stmt
stmt
=
ir_pass
.
LowerStorageAccessInfo
(
stmt
)
return
ir_pass
.
MakeAPI
(
stmt
,
name
,
arg_list
,
0
,
cfg
.
restricted_func
)
return
ir_pass
.
MakeAPI
(
stmt
,
name
,
arg_list
,
0
,
cfg
.
restricted_func
)
...
...
src/api/api_pass.cc
View file @
79e482bc
...
@@ -95,6 +95,7 @@ REGISTER_PASS2(BindDeviceType);
...
@@ -95,6 +95,7 @@ REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1
(
SplitHostDevice
);
REGISTER_PASS1
(
SplitHostDevice
);
REGISTER_PASS1
(
StorageRewrite
);
REGISTER_PASS1
(
StorageRewrite
);
REGISTER_PASS1
(
CoProcSync
);
REGISTER_PASS1
(
CoProcSync
);
REGISTER_PASS1
(
LowerStorageAccessInfo
);
REGISTER_PASS1
(
InjectVirtualThread
);
REGISTER_PASS1
(
InjectVirtualThread
);
REGISTER_PASS1
(
InjectPrefetch
);
REGISTER_PASS1
(
InjectPrefetch
);
REGISTER_PASS1
(
LoopPartition
);
REGISTER_PASS1
(
LoopPartition
);
...
...
src/pass/coproc_sync.cc
0 → 100644
View file @
79e482bc
This diff is collapsed.
Click to expand it.
src/pass/lower_tvm_builtin.cc
View file @
79e482bc
...
@@ -6,16 +6,13 @@
...
@@ -6,16 +6,13 @@
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/target_info.h>
#include <unordered_set>
#include <unordered_set>
#include "./ir_util.h"
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
using
runtime
::
StorageScope
;
inline
Expr
ConstInt32
(
size_t
index
)
{
inline
Expr
ConstInt32
(
size_t
index
)
{
CHECK_LE
(
index
,
std
::
numeric_limits
<
int
>::
max
());
CHECK_LE
(
index
,
std
::
numeric_limits
<
int
>::
max
());
return
make_const
(
Int
(
32
),
static_cast
<
int
>
(
index
));
return
make_const
(
Int
(
32
),
static_cast
<
int
>
(
index
));
...
@@ -69,14 +66,7 @@ class BuiltinLower : public IRMutator {
...
@@ -69,14 +66,7 @@ class BuiltinLower : public IRMutator {
// Lower allocate to device allocate when needed.
// Lower allocate to device allocate when needed.
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Allocate
>
();
op
=
stmt
.
as
<
Allocate
>
();
// For special memory, remove allocate.
if
(
op
->
new_expr
.
defined
())
return
stmt
;
auto
it
=
storage_info_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
storage_info_
.
end
()
&&
it
->
second
.
scope
.
tag
.
length
()
!=
0
)
{
++
it
->
second
.
alloc_count
;
CHECK_LE
(
it
->
second
.
alloc_count
,
1
)
<<
"Double allocation of "
<<
it
->
second
.
scope
.
to_string
();
return
op
->
body
;
}
// Get constant allocation bound.
// Get constant allocation bound.
int64_t
dev_type
;
int64_t
dev_type
;
int64_t
nbytes
=
GetVectorBytes
(
op
->
type
);
int64_t
nbytes
=
GetVectorBytes
(
op
->
type
);
...
@@ -139,25 +129,12 @@ class BuiltinLower : public IRMutator {
...
@@ -139,25 +129,12 @@ class BuiltinLower : public IRMutator {
CHECK
(
!
device_type_
.
defined
());
CHECK
(
!
device_type_
.
defined
());
device_type_
=
op
->
value
;
device_type_
=
op
->
value
;
return
Mutate
(
op
->
body
);
return
Mutate
(
op
->
body
);
}
else
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
const
Variable
*
buf
=
op
->
node
.
as
<
Variable
>
();
StorageScope
scope
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
StorageEntry
e
;
e
.
scope
=
scope
;
if
(
scope
.
tag
.
length
()
!=
0
)
{
e
.
info
=
GetMemoryInfo
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
CHECK
(
e
.
info
.
defined
())
<<
"Cannot find memory info of "
<<
scope
.
to_string
();
}
storage_info_
[
buf
]
=
e
;
return
IRMutator
::
Mutate_
(
op
,
s
);
}
else
{
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
}
}
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_access_ptr
))
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_call_packed
))
{
return
MakeAccessPtr
(
op
,
e
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_call_packed
))
{
return
MakeCallPacked
(
op
,
e
);
return
MakeCallPacked
(
op
,
e
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_stack_make_shape
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_stack_make_shape
))
{
return
MakeShape
(
op
,
e
);
return
MakeShape
(
op
,
e
);
...
@@ -167,14 +144,6 @@ class BuiltinLower : public IRMutator {
...
@@ -167,14 +144,6 @@ class BuiltinLower : public IRMutator {
return
IRMutator
::
Mutate_
(
op
,
e
);
return
IRMutator
::
Mutate_
(
op
,
e
);
}
}
}
}
Expr
Convert
(
Type
t
,
Expr
e
)
{
if
(
e
.
type
()
!=
t
)
{
return
Cast
::
make
(
t
,
e
);
}
else
{
return
e
;
}
}
// call shape
// call shape
Expr
MakeShape
(
const
Call
*
op
,
const
Expr
&
e
)
{
Expr
MakeShape
(
const
Call
*
op
,
const
Expr
&
e
)
{
size_t
stack_begin
=
run_shape_stack_
;
size_t
stack_begin
=
run_shape_stack_
;
...
@@ -183,7 +152,7 @@ class BuiltinLower : public IRMutator {
...
@@ -183,7 +152,7 @@ class BuiltinLower : public IRMutator {
op
=
expr
.
as
<
Call
>
();
op
=
expr
.
as
<
Call
>
();
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
prep_seq_
.
emplace_back
(
prep_seq_
.
emplace_back
(
Store
::
make
(
stack_shape_
,
Conver
t
(
Int
(
64
),
op
->
args
[
i
]),
Store
::
make
(
stack_shape_
,
cas
t
(
Int
(
64
),
op
->
args
[
i
]),
ConstInt32
(
stack_begin
+
i
),
const_true
(
1
)));
ConstInt32
(
stack_begin
+
i
),
const_true
(
1
)));
}
}
return
AddressOffset
(
stack_shape_
,
Int
(
64
),
stack_begin
);
return
AddressOffset
(
stack_shape_
,
Int
(
64
),
stack_begin
);
...
@@ -224,15 +193,15 @@ class BuiltinLower : public IRMutator {
...
@@ -224,15 +193,15 @@ class BuiltinLower : public IRMutator {
}
}
prep_seq_
.
emplace_back
(
prep_seq_
.
emplace_back
(
TVMStructSet
(
stack_array_
,
idx
,
intrinsic
::
kArrByteOffset
,
TVMStructSet
(
stack_array_
,
idx
,
intrinsic
::
kArrByteOffset
,
Conver
t
(
UInt
(
64
),
byte_offset
)));
cas
t
(
UInt
(
64
),
byte_offset
)));
CHECK
(
device_type_
.
defined
())
<<
"Unknown device type in current IR"
;
CHECK
(
device_type_
.
defined
())
<<
"Unknown device type in current IR"
;
CHECK
(
device_id_
.
defined
())
<<
"Unknown device id in current IR"
;
CHECK
(
device_id_
.
defined
())
<<
"Unknown device id in current IR"
;
prep_seq_
.
emplace_back
(
prep_seq_
.
emplace_back
(
TVMStructSet
(
stack_array_
,
idx
,
intrinsic
::
kArrDeviceId
,
TVMStructSet
(
stack_array_
,
idx
,
intrinsic
::
kArrDeviceId
,
Conver
t
(
Int
(
32
),
device_id_
)));
cas
t
(
Int
(
32
),
device_id_
)));
prep_seq_
.
emplace_back
(
prep_seq_
.
emplace_back
(
TVMStructSet
(
stack_array_
,
idx
,
intrinsic
::
kArrDeviceType
,
TVMStructSet
(
stack_array_
,
idx
,
intrinsic
::
kArrDeviceType
,
Conver
t
(
Int
(
32
),
device_type_
)));
cas
t
(
Int
(
32
),
device_type_
)));
return
TVMStructGet
(
Handle
(),
stack_array_
,
idx
,
intrinsic
::
kArrAddr
);
return
TVMStructGet
(
Handle
(),
stack_array_
,
idx
,
intrinsic
::
kArrAddr
);
}
}
// call packled.
// call packled.
...
@@ -280,33 +249,6 @@ class BuiltinLower : public IRMutator {
...
@@ -280,33 +249,6 @@ class BuiltinLower : public IRMutator {
Int
(
32
),
intrinsic
::
tvm_call_packed_lowered
,
Int
(
32
),
intrinsic
::
tvm_call_packed_lowered
,
packed_args
,
Call
::
Intrinsic
);
packed_args
,
Call
::
Intrinsic
);
}
}
// tvm_access_ptr
Expr
MakeAccessPtr
(
const
Call
*
op
,
const
Expr
&
e
)
{
// Specially handle the buffer packed intrinsic
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Call
>
();
CHECK_EQ
(
op
->
args
.
size
(),
5U
);
Type
dtype
=
op
->
args
[
0
].
type
();
const
Variable
*
buffer
=
op
->
args
[
1
].
as
<
Variable
>
();
Expr
offset
=
op
->
args
[
2
];
auto
it
=
storage_info_
.
find
(
buffer
);
if
(
it
!=
storage_info_
.
end
()
&&
it
->
second
.
scope
.
tag
.
length
()
!=
0
)
{
return
MakeTaggedAccessPtr
(
op
->
type
,
dtype
,
offset
,
it
->
second
.
info
.
defined
()
?
it
->
second
.
info
->
unit_bits
:
8
);
}
CHECK
(
op
->
type
.
is_handle
());
// Change to address_of
return
AddressOffset
(
Var
(
op
->
args
[
1
].
node_
),
dtype
,
offset
);
}
Expr
MakeTaggedAccessPtr
(
Type
ptr_type
,
Type
dtype
,
Expr
offset
,
int
unit_bits
)
{
int
dtype_bits
=
dtype
.
bits
()
*
dtype
.
lanes
();
CHECK_EQ
(
unit_bits
%
dtype_bits
,
0
);
return
Convert
(
ptr_type
,
ir
::
Simplify
(
offset
/
make_const
(
offset
.
type
(),
unit_bits
/
dtype_bits
)));
}
private
:
private
:
bool
IsArrayHandle
(
const
Expr
&
arg
)
{
bool
IsArrayHandle
(
const
Expr
&
arg
)
{
...
@@ -337,17 +279,6 @@ class BuiltinLower : public IRMutator {
...
@@ -337,17 +279,6 @@ class BuiltinLower : public IRMutator {
uint64_t
max_shape_stack_
{
0
};
uint64_t
max_shape_stack_
{
0
};
uint64_t
max_array_stack_
{
0
};
uint64_t
max_array_stack_
{
0
};
uint64_t
max_arg_stack_
{
0
};
uint64_t
max_arg_stack_
{
0
};
// The storage entry.
struct
StorageEntry
{
// Whether it is tagged memory.
StorageScope
scope
;
// The memory info if any.
MemoryInfo
info
;
// Allocation counter
int
alloc_count
{
0
};
};
// The storage scope of each buffer
std
::
unordered_map
<
const
Variable
*
,
StorageEntry
>
storage_info_
;
};
};
LoweredFunc
LowerTVMBuiltin
(
LoweredFunc
f
)
{
LoweredFunc
LowerTVMBuiltin
(
LoweredFunc
f
)
{
...
...
src/pass/storage_access.cc
View file @
79e482bc
...
@@ -2,7 +2,12 @@
...
@@ -2,7 +2,12 @@
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017 by Contributors
* \file storage_access.cc
* \file storage_access.cc
*/
*/
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/target_info.h>
#include "./ir_util.h"
#include "./storage_access.h"
#include "./storage_access.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
...
@@ -191,5 +196,110 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
...
@@ -191,5 +196,110 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
if
(
it
==
storage_scope_
.
end
())
return
s
;
if
(
it
==
storage_scope_
.
end
())
return
s
;
return
it
->
second
;
return
it
->
second
;
}
}
class
StorageAccessInfoLower
:
public
IRMutator
{
public
:
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
// Lower allocate to device allocate when needed.
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Allocate
>
();
// For special memory, remove allocate, or use head expr
auto
it
=
storage_info_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
storage_info_
.
end
()
&&
it
->
second
.
info
.
defined
())
{
const
MemoryInfo
&
info
=
it
->
second
.
info
;
++
it
->
second
.
alloc_count
;
CHECK_LE
(
it
->
second
.
alloc_count
,
1
)
<<
"Double allocation of "
<<
it
->
second
.
scope
.
to_string
();
if
(
info
->
head_address
.
defined
())
{
return
Allocate
::
make
(
op
->
buffer_var
,
op
->
type
,
op
->
extents
,
op
->
condition
,
op
->
body
,
info
->
head_address
,
"nop"
);
}
return
op
->
body
;
}
else
{
return
stmt
;
}
}
Stmt
Mutate_
(
const
AttrStmt
*
op
,
const
Stmt
&
s
)
final
{
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
const
Variable
*
buf
=
op
->
node
.
as
<
Variable
>
();
StorageScope
scope
=
StorageScope
::
make
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
StorageEntry
e
;
e
.
scope
=
scope
;
if
(
scope
.
tag
.
length
()
!=
0
)
{
e
.
info
=
GetMemoryInfo
(
op
->
value
.
as
<
StringImm
>
()
->
value
);
CHECK
(
e
.
info
.
defined
())
<<
"Cannot find memory info of "
<<
scope
.
to_string
();
}
storage_info_
[
buf
]
=
e
;
return
IRMutator
::
Mutate_
(
op
,
s
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
s
);
}
}
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
final
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_access_ptr
))
{
return
MakeAccessPtr
(
op
,
e
);
}
else
{
return
IRMutator
::
Mutate_
(
op
,
e
);
}
}
private
:
// tvm_access_ptr
Expr
MakeAccessPtr
(
const
Call
*
op
,
const
Expr
&
e
)
{
// Specially handle the buffer packed intrinsic
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Call
>
();
CHECK_EQ
(
op
->
args
.
size
(),
5U
);
Type
dtype
=
op
->
args
[
0
].
type
();
const
Variable
*
buffer
=
op
->
args
[
1
].
as
<
Variable
>
();
Var
buffer_var
(
op
->
args
[
1
].
node_
);
Expr
offset
=
op
->
args
[
2
];
auto
it
=
storage_info_
.
find
(
buffer
);
if
(
it
!=
storage_info_
.
end
()
&&
it
->
second
.
info
.
defined
())
{
return
MakeTaggedAccessPtr
(
op
->
type
,
buffer_var
,
dtype
,
offset
,
it
->
second
.
info
);
}
CHECK
(
op
->
type
.
is_handle
());
// Change to address_of
return
AddressOffset
(
buffer_var
,
dtype
,
offset
);
}
Expr
MakeTaggedAccessPtr
(
Type
ptr_type
,
Var
buffer_var
,
Type
dtype
,
Expr
offset
,
const
MemoryInfo
&
info
)
{
if
(
ptr_type
.
is_handle
())
{
CHECK
(
info
->
head_address
.
defined
())
<<
buffer_var
<<
" is not adddressable."
;
return
AddressOffset
(
buffer_var
,
dtype
,
offset
);
}
int
dtype_bits
=
dtype
.
bits
()
*
dtype
.
lanes
();
CHECK_EQ
(
info
->
unit_bits
%
dtype_bits
,
0
);
return
cast
(
ptr_type
,
ir
::
Simplify
(
offset
/
make_const
(
offset
.
type
(),
info
->
unit_bits
/
dtype_bits
)));
}
// The storage entry.
struct
StorageEntry
{
// Whether it is tagged memory.
StorageScope
scope
;
// The memory info if any.
MemoryInfo
info
;
// Allocation counter
int
alloc_count
{
0
};
};
// The storage scope of each buffer
std
::
unordered_map
<
const
Variable
*
,
StorageEntry
>
storage_info_
;
};
Stmt
LowerStorageAccessInfo
(
Stmt
stmt
)
{
return
StorageAccessInfoLower
().
Mutate
(
stmt
);
}
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
src/pass/storage_flatten.cc
View file @
79e482bc
...
@@ -86,7 +86,6 @@ class StorageFlattener : public IRMutator {
...
@@ -86,7 +86,6 @@ class StorageFlattener : public IRMutator {
return
this
->
Mutate
(
op
->
body
);
return
this
->
Mutate
(
op
->
body
);
}
else
{
}
else
{
// create a buffer entry
// create a buffer entry
// TODO(tqchen) allow permutation and inference of index dimension.
BufferEntry
e
;
BufferEntry
e
;
e
.
bounds
=
op
->
bounds
;
e
.
bounds
=
op
->
bounds
;
Array
<
Expr
>
shape
;
Array
<
Expr
>
shape
;
...
...
src/pass/storage_sync.cc
View file @
79e482bc
...
@@ -153,7 +153,6 @@ class ThreadSyncInserter : public IRMutator {
...
@@ -153,7 +153,6 @@ class ThreadSyncInserter : public IRMutator {
Stmt
Mutate
(
Stmt
stmt
)
final
{
Stmt
Mutate
(
Stmt
stmt
)
final
{
if
(
syncs_
.
size
()
==
0
)
return
stmt
;
if
(
syncs_
.
size
()
==
0
)
return
stmt
;
stmt
=
IRMutator
::
Mutate
(
stmt
);
if
(
syncs_
.
count
(
stmt
.
get
()))
{
if
(
syncs_
.
count
(
stmt
.
get
()))
{
Stmt
barrier
;
Stmt
barrier
;
if
(
sync_scope_
.
rank
==
0
)
{
if
(
sync_scope_
.
rank
==
0
)
{
...
@@ -164,7 +163,11 @@ class ThreadSyncInserter : public IRMutator {
...
@@ -164,7 +163,11 @@ class ThreadSyncInserter : public IRMutator {
{
StringImm
::
make
(
sync_scope_
.
to_string
())},
{
StringImm
::
make
(
sync_scope_
.
to_string
())},
Call
::
Intrinsic
));
Call
::
Intrinsic
));
}
}
// Mutate after query, to avoid stmt change.
stmt
=
IRMutator
::
Mutate
(
stmt
);
stmt
=
Block
::
make
(
barrier
,
stmt
);
stmt
=
Block
::
make
(
barrier
,
stmt
);
}
else
{
stmt
=
IRMutator
::
Mutate
(
stmt
);
}
}
return
stmt
;
return
stmt
;
}
}
...
@@ -296,201 +299,5 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
...
@@ -296,201 +299,5 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return
LoweredFunc
(
n
);
return
LoweredFunc
(
n
);
}
}
// Visitor to find touched set by co-processor scope.
class
CoProcTouchedBuffer
:
public
IRVisitor
{
public
:
void
Visit_
(
const
Load
*
op
)
final
{
if
(
in_scope_
)
{
touched_
.
insert
(
op
->
buffer_var
.
get
());
}
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Store
*
op
)
final
{
if
(
in_scope_
)
{
touched_
.
insert
(
op
->
buffer_var
.
get
());
}
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Call
*
op
)
final
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_access_ptr
)
&&
in_scope_
)
{
const
Variable
*
buffer
=
op
->
args
[
1
].
as
<
Variable
>
();
touched_
.
insert
(
buffer
);
}
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
AttrStmt
*
op
)
final
{
if
(
op
->
attr_key
==
attr
::
coproc_scope
&&
!
in_scope_
)
{
in_scope_
=
true
;
IterVar
iv
(
op
->
node
.
node_
);
coproc_
.
insert
(
iv
);
IRVisitor
::
Visit_
(
op
);
in_scope_
=
false
;
}
else
{
IRVisitor
::
Visit_
(
op
);
}
}
std
::
unordered_set
<
const
Variable
*>
touched_
;
std
::
unordered_set
<
IterVar
>
coproc_
;
private
:
bool
in_scope_
{
false
};
};
// Synchronization planning with co-processor.
class
CoProcSyncPlanner
:
public
StorageAccessVisitor
{
public
:
void
Plan
(
const
Stmt
&
stmt
)
{
CoProcTouchedBuffer
visitor
;
visitor
.
Visit
(
stmt
);
touched_
=
std
::
move
(
visitor
.
touched_
);
if
(
!
touched_
.
empty
())
{
this
->
Visit
(
stmt
);
PlanWriteSync
(
scope_
.
back
(),
nullptr
,
true
);
CHECK_EQ
(
visitor
.
coproc_
.
size
(),
1U
);
if
(
write_sync_
.
size
()
==
0
)
{
write_sync_
[
stmt
.
get
()]
=
GetWriteSync
(
(
*
visitor
.
coproc_
.
begin
())
->
var
->
name_hint
+
".coproc_sync"
);
}
}
}
// Write synchronization to be inserted before or after stmt.
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>
write_sync_
;
protected
:
bool
Enabled
(
const
Variable
*
buf
,
const
StorageScope
&
scope
)
const
final
{
return
touched_
.
count
(
buf
)
&&
scope
==
global_scope_
;
}
// Plan the sync
std
::
vector
<
AccessEntry
>
Summarize
(
std
::
vector
<
StmtEntry
>
seq
,
const
For
*
loop
)
final
{
return
PlanWriteSync
(
seq
,
loop
,
false
);
}
private
:
// Plan write synchronization if write is not coherent
std
::
vector
<
AccessEntry
>
PlanWriteSync
(
std
::
vector
<
StmtEntry
>
seq
,
const
For
*
loop
,
bool
force_sync_at_end
)
{
// detect write barriers
// access by the co-processor.
std
::
vector
<
AccessEntry
>
co_access
;
bool
contain_sync
=
false
;
auto
find_conflict
=
[
&
](
const
AccessEntry
&
acc
)
{
for
(
const
AccessEntry
&
x
:
co_access
)
{
if
(
x
.
buffer
.
same_as
(
acc
.
buffer
)
&&
((
acc
.
type
==
kRead
&&
x
.
type
==
kWrite
)
||
acc
.
type
==
kWrite
))
{
return
true
;
}
}
return
false
;
};
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
bool
sync_write
=
false
;
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
threads
.
size
()
==
0
&&
find_conflict
(
acc
))
{
sync_write
=
true
;
break
;
}
if
(
acc
.
type
==
kSync
)
{
co_access
.
clear
();
contain_sync
=
true
;
}
}
if
(
sync_write
)
{
CHECK_NE
(
i
,
0U
);
write_sync_
[
seq
[
i
-
1
].
stmt
]
=
GetWriteSync
(
co_access
);
co_access
.
clear
();
contain_sync
=
true
;
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
threads
.
size
()
!=
0
)
{
co_access
.
push_back
(
acc
);
}
}
}
bool
sync_at_end
=
force_sync_at_end
;
if
(
loop
!=
nullptr
&&
!
sync_at_end
)
{
// loop carray dependency
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
threads
.
size
()
==
0
&&
find_conflict
(
acc
))
{
sync_at_end
=
true
;
break
;
}
}
if
(
write_sync_
.
count
(
s
.
stmt
)
||
sync_at_end
)
break
;
}
}
if
(
sync_at_end
&&
co_access
.
size
()
!=
0
)
{
CHECK_NE
(
seq
.
size
(),
0
);
contain_sync
=
true
;
write_sync_
[
seq
.
back
().
stmt
]
=
GetWriteSync
(
co_access
);
co_access
.
clear
();
}
if
(
contain_sync
)
{
AccessEntry
e
;
e
.
type
=
kSync
;
e
.
scope
=
global_scope_
;
co_access
.
insert
(
co_access
.
begin
(),
e
);
}
return
co_access
;
}
// Add write Synchronization
std
::
vector
<
Stmt
>
GetWriteSync
(
const
std
::
vector
<
AccessEntry
>&
co_access
)
{
// Does not consider memory coherence, need runtime.
CHECK_NE
(
co_access
.
size
(),
0U
);
CHECK_EQ
(
co_access
[
0
].
threads
.
size
(),
1U
);
return
GetWriteSync
(
co_access
[
0
].
threads
[
0
]
->
var
->
name_hint
+
".coproc_sync"
);
}
std
::
vector
<
Stmt
>
GetWriteSync
(
std
::
string
sync_name
)
{
std
::
vector
<
Stmt
>
stmts
;
stmts
.
emplace_back
(
Evaluate
::
make
(
Call
::
make
(
Int
(
32
),
sync_name
,
{},
Call
::
Intrinsic
)));
return
stmts
;
}
std
::
unordered_set
<
const
Variable
*>
touched_
;
StorageScope
global_scope_
=
StorageScope
::
make
(
"global"
);
};
class
CoProcSyncInserter
:
public
IRMutator
{
public
:
explicit
CoProcSyncInserter
(
const
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>&
write_sync
)
:
write_sync_
(
write_sync
)
{}
Stmt
Mutate
(
Stmt
stmt
)
final
{
stmt
=
IRMutator
::
Mutate
(
stmt
);
auto
it
=
write_sync_
.
find
(
stmt
.
get
());
if
(
it
!=
write_sync_
.
end
())
{
stmt
=
Block
::
make
(
stmt
,
MergeSeq
(
it
->
second
));
}
return
stmt
;
}
private
:
const
std
::
unordered_map
<
const
Node
*
,
std
::
vector
<
Stmt
>
>&
write_sync_
;
};
Stmt
CoProcSync
(
Stmt
stmt
)
{
CoProcSyncPlanner
planner
;
planner
.
Plan
(
stmt
);
if
(
planner
.
write_sync_
.
size
()
!=
0
)
{
return
CoProcSyncInserter
(
planner
.
write_sync_
).
Mutate
(
stmt
);
}
else
{
return
stmt
;
}
}
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
tests/python/unittest/test_pass_storage_sync.py
View file @
79e482bc
...
@@ -32,16 +32,31 @@ def test_coproc_sync():
...
@@ -32,16 +32,31 @@ def test_coproc_sync():
ib
=
tvm
.
ir_builder
.
create
()
ib
=
tvm
.
ir_builder
.
create
()
n
=
tvm
.
var
(
"n"
)
n
=
tvm
.
var
(
"n"
)
cp
=
tvm
.
thread_axis
((
0
,
1
),
"cop"
)
cp
=
tvm
.
thread_axis
((
0
,
1
),
"cop"
)
A
=
ib
.
allocate
(
"float32"
,
n
,
name
=
"A"
,
scope
=
"global"
)
@tvm.register_func
(
"tvm.info.mem.global.cache"
)
def
meminfo_cache
():
return
tvm
.
make
.
node
(
"MemoryInfo"
,
unit_bits
=
8
,
max_simd_bits
=
32
,
max_num_bits
=
128
,
head_address
=
tvm
.
call_extern
(
"handle"
,
"global_cache"
))
A
=
ib
.
allocate
(
"float32"
,
128
,
name
=
"A"
,
scope
=
"global.cache"
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
A
[
i
]
=
A
[
i
]
+
1
A
[
i
]
=
A
[
i
]
+
1
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
with
ib
.
for_range
(
0
,
8
,
name
=
"k"
)
as
k
:
ib
.
scope_attr
(
cp
,
"coproc_scope"
,
1
)
with
ib
.
for_range
(
0
,
10
,
name
=
"j"
)
as
j
:
A
[
j
]
=
A
[
j
]
+
2
ib
.
scope_attr
(
cp
,
"coproc_scope"
,
1
)
body
=
ib
.
get
()
A
[
j
]
=
A
[
j
+
k
*
10
]
+
2
body
=
tvm
.
ir_pass
.
CoProcSync
(
body
)
stmt
=
ib
.
get
()
body
=
body
.
body
.
body
.
body
stmt
=
tvm
.
ir_pass
.
CoProcSync
(
stmt
)
assert
(
tvm
.
make
.
stmt_list
(
body
)[
-
1
]
.
value
.
name
==
"cop.coproc_sync"
)
body
=
stmt
.
body
.
body
.
body
blist
=
tvm
.
make
.
stmt_list
(
body
)
assert
(
blist
[
1
]
.
value
.
name
==
"cop.coproc_read_barrier"
)
assert
(
blist
[
1
]
.
value
.
args
[
3
]
.
value
==
80
)
assert
(
blist
[
-
2
]
.
value
.
name
==
"cop.coproc_sync"
)
assert
(
blist
[
-
1
]
.
value
.
name
==
"cop.coproc_write_barrier"
)
assert
(
blist
[
-
1
]
.
value
.
args
[
3
]
.
value
==
10
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment