Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
34e31c44
Commit
34e31c44
authored
Mar 14, 2018
by
Ding
Committed by
Tianqi Chen
Mar 13, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Add VerifyMemory pass and test cases (#410) (#993)
parent
75b93d30
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
289 additions
and
2 deletions
+289
-2
include/tvm/ir_pass.h
+14
-0
python/tvm/build_module.py
+5
-1
src/api/api_pass.cc
+1
-0
src/codegen/build_module.cc
+5
-1
src/pass/verify_memory.cc
+168
-0
tests/python/unittest/test_pass_verify_memory.py
+96
-0
No files found.
include/tvm/ir_pass.h
View file @
34e31c44
...
...
@@ -440,6 +440,20 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
* \return Transformed function.
*/
LoweredFunc
LowerIntrin
(
LoweredFunc
f
,
const
std
::
string
&
target
);
/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal. This pass performs verification for this case.
*
* \param func The function to be verified.
* \param device_type The target device type.
* \return Success of memory verification.
*/
bool
VerifyMemory
(
LoweredFunc
func
,
int
device_type
);
}
// namespace ir
}
// namespace tvm
...
...
python/tvm/build_module.py
View file @
34e31c44
...
...
@@ -424,10 +424,15 @@ def build(sch,
target
=
_target
.
current_target
()
if
target
is
None
else
target
target
=
_target
.
create
(
target
)
if
target
else
_target
.
create
(
"llvm"
)
device_type
=
ndarray
.
context
(
target
.
target_name
,
0
)
.
device_type
fhost
=
[]
fdevice
=
[]
for
func
in
flist
:
if
not
ir_pass
.
VerifyMemory
(
func
,
device_type
):
raise
ValueError
(
"Direct host side access to device memory is detected in
%
s. "
"Did you forget to bind?"
%
func
.
name
)
if
func
.
func_type
==
container
.
LoweredFunc
.
MixedFunc
:
if
BuildConfig
.
current
.
detect_global_barrier
:
func
=
ir_pass
.
ThreadSync
(
func
,
"global"
)
...
...
@@ -449,7 +454,6 @@ def build(sch,
warnings
.
warn
(
"Specified target
%
s, but cannot find device code, did you do bind?"
%
target
)
device_type
=
ndarray
.
context
(
target
.
target_name
,
0
)
.
device_type
fhost
=
[
ir_pass
.
BindDeviceType
(
x
,
device_type
)
for
x
in
fhost
]
fhost
=
[
ir_pass
.
LowerTVMBuiltin
(
x
)
for
x
in
fhost
]
...
...
src/api/api_pass.cc
View file @
34e31c44
...
...
@@ -128,5 +128,6 @@ REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2
(
LowerIntrin
);
REGISTER_PASS1
(
LowerTVMBuiltin
);
REGISTER_PASS1
(
CombineContextCall
);
REGISTER_PASS2
(
VerifyMemory
);
}
// namespace ir
}
// namespace tvm
src/codegen/build_module.cc
View file @
34e31c44
...
...
@@ -269,7 +269,11 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
Array
<
LoweredFunc
>
fhost
;
Array
<
LoweredFunc
>
fdevice
;
for
(
const
auto
&
x
:
funcs
)
{
for
(
const
auto
&
x
:
funcs
)
{
CHECK
(
ir
::
VerifyMemory
(
x
,
target
.
device_type
))
<<
"Direct host side access to device memory is detected in "
<<
x
->
func_name
()
<<
". Did you forget to bind?"
;
if
(
x
->
func_type
==
kMixedFunc
)
{
auto
func
=
x
;
if
(
config
->
detect_global_barrier
)
{
...
...
src/pass/verify_memory.cc
0 → 100644
View file @
34e31c44
/*!
* Copyright (c) 2018 by Contributors
* \file verify_memory.cc
* \brief Pass to check if memory accesses are legal.
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
namespace
tvm
{
namespace
ir
{
namespace
{
/*!
* \brief Verify if memory accesses are legal.
*
* In the case that tgt is cuda, if workload is not bound with
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal.
*
* This pass performs such verification by checking if all Producer/Consumer
* with memory accesses are bound with threads when device type is GPU.
*/
class
MemoryAccessVerifier
final
:
protected
IRVisitor
{
public
:
/// Special member functions
//@{
explicit
MemoryAccessVerifier
(
LoweredFunc
f
,
int
device_type
)
:
func_
(
f
),
dev_type_
(
device_type
)
{}
virtual
~
MemoryAccessVerifier
()
=
default
;
MemoryAccessVerifier
(
const
MemoryAccessVerifier
&
)
=
delete
;
MemoryAccessVerifier
(
MemoryAccessVerifier
&&
)
=
delete
;
MemoryAccessVerifier
&
operator
=
(
const
MemoryAccessVerifier
&
)
=
delete
;
MemoryAccessVerifier
&
operator
=
(
MemoryAccessVerifier
&&
)
=
delete
;
//@}
/// Interface to perform memory access verification
void
Run
()
{
if
(
!
IsGPUDevice
(
dev_type_
))
return
;
IRVisitor
::
Visit
(
func_
->
body
);
}
/// Verification result
bool
Failed
()
const
{
return
failure_
;
}
protected
:
/// Visitor implementation
//@{
void
Visit
(
const
NodeRef
&
n
)
final
{
if
(
Failed
())
return
;
IRVisitor
::
Visit
(
n
);
}
void
Visit_
(
const
LetStmt
*
op
)
final
{
// Book keep definitions
defs_
[
op
->
var
.
get
()]
=
op
->
value
;
return
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
AttrStmt
*
op
)
final
{
if
(
!
InThreadEnv
()
&&
(
op
->
attr_key
==
attr
::
thread_extent
||
op
->
attr_key
==
attr
::
pipeline_exec_scope
))
{
EnterThreadEnv
();
IRVisitor
::
Visit_
(
op
);
ExitThreadEnv
();
}
else
{
IRVisitor
::
Visit_
(
op
);
}
}
void
Visit_
(
const
ProducerConsumer
*
op
)
final
{
EnterProducerConsumer
(
op
);
IRVisitor
::
Visit_
(
op
);
ExitProducerConsumer
();
}
void
Visit_
(
const
Load
*
op
)
final
{
HandleLoadStoreToVariable
(
op
->
buffer_var
);
return
IRVisitor
::
Visit_
(
op
);
}
void
Visit_
(
const
Store
*
op
)
final
{
HandleLoadStoreToVariable
(
op
->
buffer_var
);
return
IRVisitor
::
Visit_
(
op
);
}
//@}
/// Check if the value of a Variable comes from function argument.
bool
IsFromFunctionArgs
(
const
Variable
*
var
)
const
{
const
Variable
*
V
=
var
;
while
(
true
)
{
CHECK
(
V
)
<<
"Invalid Variable
\n
"
;
// Variable is from function args. Return true.
if
(
V
==
func_
->
args
[
0
].
node_
.
get
())
return
true
;
// The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue.
const
auto
&
iter
=
defs_
.
find
(
V
);
if
(
iter
==
defs_
.
end
())
return
false
;
const
Call
*
C
=
iter
->
second
.
as
<
const
Call
>
();
if
(
!
C
||
C
->
name
!=
intrinsic
::
tvm_struct_get
)
return
false
;
V
=
C
->
args
[
0
].
as
<
Variable
>
();
}
return
false
;
}
/// Handle memory access to a Variable
void
HandleLoadStoreToVariable
(
const
VarExpr
&
var
)
{
// We skip the access within thread env.
if
(
InThreadEnv
())
return
;
// We only check access within a producer/consumer.
// Because for load/store out side of producer/consumer,
// they don't have to be in thread env to stay legal (e.g. Load of args).
if
(
!
InProducerConsumer
())
return
;
// We only handle the variable from function argument.
// If it does not come from args, then it could be allocated internally,
// it may possibly be in host or device address space.
// We do not handle this case, and skip it conservatively.
if
(
!
IsFromFunctionArgs
(
var
.
get
()))
return
;
// The verification fails in this case.
SetFailure
();
}
/// Status getter/setter
//@{
bool
InThreadEnv
()
const
{
return
in_thread_env_
;
}
void
EnterThreadEnv
()
{
in_thread_env_
=
true
;
}
void
ExitThreadEnv
()
{
in_thread_env_
=
false
;
}
bool
InProducerConsumer
()
const
{
return
pc_
!=
nullptr
;
}
const
ProducerConsumer
*
GetCurrentProducerConsumer
()
const
{
return
pc_
;
}
void
EnterProducerConsumer
(
const
ProducerConsumer
*
pc
)
{
this
->
pc_
=
pc
;
}
void
ExitProducerConsumer
()
{
pc_
=
nullptr
;
}
void
SetFailure
()
{
failure_
=
true
;
}
//@}
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
static
bool
IsGPUDevice
(
int
dev_type
)
{
return
kDLGPU
==
dev_type
||
kDLOpenCL
==
dev_type
||
kDLVulkan
==
dev_type
||
kDLMetal
==
dev_type
||
kDLROCM
==
dev_type
||
kOpenGL
==
dev_type
;
}
private
:
/// Status of visitor
//@{
bool
in_thread_env_
{
false
};
const
ProducerConsumer
*
pc_
{
nullptr
};
bool
failure_
{
false
};
///< If the verification fails (i.e. has illegal access)
//@}
LoweredFunc
func_
{
nullptr
};
///< Function to be verified.
int
dev_type_
{
kDLCPU
};
///< Device type
std
::
unordered_map
<
const
Variable
*
,
Expr
>
defs_
;
///< Variable definitions
};
}
// namespace
/// Interface of VerifyMemory pass
bool
VerifyMemory
(
LoweredFunc
func
,
int
device_type
)
{
MemoryAccessVerifier
v
(
func
,
device_type
);
v
.
Run
();
return
!
v
.
Failed
();
}
}
// namespace ir
}
// namespace tvm
tests/python/unittest/test_pass_verify_memory.py
0 → 100644
View file @
34e31c44
import
tvm
# The following DLDeviceType/TVMDeviceExtType values
# are originally defined in dlpack.h and c_runtime_api.h.
gpu_devices
=
[
2
,
4
,
7
,
8
,
10
,
11
]
other_devices
=
[
1
,
3
,
9
,
12
]
def
lower
(
sch
,
args
):
binds
=
{}
arg_list
=
[]
for
x
in
args
:
if
isinstance
(
x
,
tvm
.
tensor
.
Tensor
):
buf
=
tvm
.
decl_buffer
(
x
.
shape
,
dtype
=
x
.
dtype
,
name
=
x
.
name
)
assert
x
not
in
binds
binds
[
x
]
=
buf
arg_list
.
append
(
buf
)
else
:
raise
ValueError
(
"args must be Tensor, Buffer or Var"
)
sch
=
sch
.
normalize
()
bounds
=
tvm
.
schedule
.
InferBound
(
sch
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
sch
,
bounds
)
stmt
=
tvm
.
ir_pass
.
LoopPartition
(
stmt
,
False
)
stmt
=
tvm
.
ir_pass
.
StorageFlatten
(
stmt
,
binds
,
64
)
func
=
tvm
.
ir_pass
.
MakeAPI
(
stmt
,
"myadd"
,
arg_list
,
0
,
True
)
return
func
# All computations are bound.
# So VerifyMemory pass is expected to succeed.
#
def
test_verify_memory_all_bind
():
n
=
tvm
.
var
(
"n"
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
i
:
A
[
i
]
+
1.0
,
name
=
"B"
)
# B is bound to threads.
s
=
tvm
.
create_schedule
(
B
.
op
)
bx
,
tx
=
s
[
B
]
.
split
(
B
.
op
.
axis
[
0
],
factor
=
64
)
s
[
B
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
B
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
func
=
lower
(
s
,
[
A
,
B
])
for
dev_type
in
gpu_devices
+
other_devices
:
assert
tvm
.
ir_pass
.
VerifyMemory
(
func
,
dev_type
)
# Computations are not bound.
# So VerifyMemory pass fails when device type is GPU.
#
def
test_verify_memory_not_bind
():
n
=
tvm
.
var
(
"n"
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
i
:
A
[
i
]
+
1.0
,
name
=
"B"
)
# B is not bound to threads.
s
=
tvm
.
create_schedule
(
B
.
op
)
func
=
lower
(
s
,
[
A
,
B
])
for
dev_type
in
gpu_devices
:
assert
not
tvm
.
ir_pass
.
VerifyMemory
(
func
,
dev_type
)
for
dev_type
in
other_devices
:
assert
tvm
.
ir_pass
.
VerifyMemory
(
func
,
dev_type
)
# Computations are partially bound.
# So VerifyMemory pass fails when device type is GPU.
#
def
test_verify_memory_partially_bind
():
n
=
tvm
.
var
(
"n"
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
compute
(
A
.
shape
,
lambda
i
:
A
[
i
]
+
1.0
,
name
=
"B"
)
C
=
tvm
.
compute
(
B
.
shape
,
lambda
i
:
B
[
i
]
+
2.0
,
name
=
"C"
)
D
=
tvm
.
compute
(
C
.
shape
,
lambda
i
:
C
[
i
]
+
2.0
,
name
=
"D"
)
# C is bound to threads, but B and D are not.
s
=
tvm
.
create_schedule
([
B
.
op
,
C
.
op
,
D
.
op
])
bx
,
tx
=
s
[
C
]
.
split
(
C
.
op
.
axis
[
0
],
factor
=
64
)
s
[
C
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
C
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
func
=
lower
(
s
,
[
A
,
B
,
C
,
D
])
for
dev_type
in
gpu_devices
:
assert
not
tvm
.
ir_pass
.
VerifyMemory
(
func
,
dev_type
)
for
dev_type
in
other_devices
:
assert
tvm
.
ir_pass
.
VerifyMemory
(
func
,
dev_type
)
if
__name__
==
"__main__"
:
test_verify_memory_all_bind
()
test_verify_memory_not_bind
()
test_verify_memory_partially_bind
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment