Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
531bb7c4
Commit
531bb7c4
authored
Jun 24, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Jun 23, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] Add GPU IR verifier (#1296)
parent
f216b25e
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
410 additions
and
3 deletions
+410
-3
include/tvm/ir_pass.h
+24
-0
include/tvm/runtime/device_api.h
+2
-1
python/tvm/_ffi/runtime_ctypes.py
+13
-0
src/api/api_pass.cc
+1
-0
src/pass/verify_gpu_code.cc
+166
-0
src/runtime/cuda/cuda_device_api.cc
+17
-1
src/runtime/metal/metal_device_api.mm
+1
-0
src/runtime/opencl/opencl_device_api.cc
+14
-1
src/runtime/opengl/opengl_device_api.cc
+1
-0
src/runtime/rocm/rocm_device_api.cc
+1
-0
src/runtime/vulkan/vulkan_device_api.cc
+1
-0
tests/python/unittest/test_pass_verify_gpu_code.py
+169
-0
No files found.
include/tvm/ir_pass.h
View file @
531bb7c4
...
@@ -477,6 +477,30 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
...
@@ -477,6 +477,30 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
*/
*/
bool
VerifyMemory
(
LoweredFunc
func
,
int
device_type
);
bool
VerifyMemory
(
LoweredFunc
func
,
int
device_type
);
/*!
* \brief Verify the correctness of a GPU code
* It will check the whether the amount of memory usage or the number of threads
* in a block exceeds the limit
* \param stmt The statement to be checked
* \param constraints The dict to specify constraints to check.
* Possible keys are
*
* "max_local_memory_per_block": Total amount of local memory per block (in bytes).
* "max_shared_memory_per_block": Total amount of shared memory per block (in bytes).
* "max_thread_per_block": Maximum number of threads per block.
* "max_thread_x": Maximum length of threadIdx.x.
* "max_thread_y": Maximum length of threadIdx.y.
* "max_thread_z": Maximum length of threadIdx.z.
*
* If one key is missing in this argument, the pass won't check for that item.
* \return valid Whether it is a valid GPU code
*
*/
bool
VerifyGPUCode
(
Stmt
stmt
,
Map
<
std
::
string
,
Expr
>
constraints
);
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
...
...
include/tvm/runtime/device_api.h
View file @
531bb7c4
...
@@ -23,7 +23,8 @@ enum DeviceAttrKind : int {
...
@@ -23,7 +23,8 @@ enum DeviceAttrKind : int {
kComputeVersion
=
4
,
kComputeVersion
=
4
,
kDeviceName
=
5
,
kDeviceName
=
5
,
kMaxClockRate
=
6
,
kMaxClockRate
=
6
,
kMultiProcessorCount
=
7
kMultiProcessorCount
=
7
,
kMaxThreadDimensions
=
8
};
};
/*! \brief Number of bytes each allocation must align to */
/*! \brief Number of bytes each allocation must align to */
...
...
python/tvm/_ffi/runtime_ctypes.py
View file @
531bb7c4
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
ctypes
import
ctypes
import
json
import
numpy
as
np
import
numpy
as
np
from
.base
import
_LIB
,
check_call
from
.base
import
_LIB
,
check_call
from
..
import
_api_internal
from
..
import
_api_internal
...
@@ -178,6 +179,18 @@ class TVMContext(ctypes.Structure):
...
@@ -178,6 +179,18 @@ class TVMContext(ctypes.Structure):
return
_api_internal
.
_GetDeviceAttr
(
return
_api_internal
.
_GetDeviceAttr
(
self
.
device_type
,
self
.
device_id
,
7
)
self
.
device_type
,
self
.
device_id
,
7
)
@property
def
max_thread_dimensions
(
self
):
"""Return the maximum size of each thread axis
Returns
-------
dims: List of int
The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
"""
return
json
.
loads
(
_api_internal
.
_GetDeviceAttr
(
self
.
device_type
,
self
.
device_id
,
8
))
def
sync
(
self
):
def
sync
(
self
):
"""Synchronize until jobs finished at the context."""
"""Synchronize until jobs finished at the context."""
check_call
(
_LIB
.
TVMSynchronize
(
self
.
device_type
,
self
.
device_id
,
None
))
check_call
(
_LIB
.
TVMSynchronize
(
self
.
device_type
,
self
.
device_id
,
None
))
...
...
src/api/api_pass.cc
View file @
531bb7c4
...
@@ -131,5 +131,6 @@ REGISTER_PASS2(LowerIntrin);
...
@@ -131,5 +131,6 @@ REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1
(
LowerTVMBuiltin
);
REGISTER_PASS1
(
LowerTVMBuiltin
);
REGISTER_PASS1
(
CombineContextCall
);
REGISTER_PASS1
(
CombineContextCall
);
REGISTER_PASS2
(
VerifyMemory
);
REGISTER_PASS2
(
VerifyMemory
);
REGISTER_PASS2
(
VerifyGPUCode
);
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
src/pass/verify_gpu_code.cc
0 → 100644
View file @
531bb7c4
/*!
* Copyright (c) 2018 by Contributors
* \file verify_gpu_code.cc
* \brief Verify the correctness of a GPU IR.
* It will check the whether the amount of memory usage or the number of threads
* in a block exceeds the limit
*/
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
namespace
tvm
{
namespace
ir
{
class
GPUCodeVerifier
:
public
IRVisitor
{
public
:
bool
Verify
(
tvm
::
Stmt
stmt
,
int64_t
max_local_memory_per_block
,
int64_t
max_shared_memory_per_block
,
int64_t
max_thread_per_block
,
int64_t
max_thread_x
,
int64_t
max_thread_y
,
int64_t
max_thread_z
)
{
max_local_memory_per_block_
=
static_cast
<
size_t
>
(
max_local_memory_per_block
);
max_shared_memory_per_block_
=
static_cast
<
size_t
>
(
max_shared_memory_per_block
);
max_thread_per_block_
=
static_cast
<
size_t
>
(
max_thread_per_block
);
max_thread_x_
=
static_cast
<
size_t
>
(
max_thread_x
);
max_thread_y_
=
static_cast
<
size_t
>
(
max_thread_y
);
max_thread_z_
=
static_cast
<
size_t
>
(
max_thread_z
);
Reset_
();
this
->
Visit
(
stmt
);
return
valid_
;
}
void
Visit_
(
const
ProducerConsumer
*
op
)
{
if
(
nest_level_
==
0
)
{
// enter a new kernel, reset statistics
Reset_
();
}
if
(
op
->
is_producer
)
{
nest_level_
++
;
IRVisitor
::
Visit_
(
op
);
nest_level_
--
;
}
else
{
IRVisitor
::
Visit_
(
op
);
}
if
(
nest_level_
==
0
)
{
// exit a kernel, check the validity
valid_
&=
thread_per_block_
<=
max_thread_per_block_
;
valid_
&=
local_memory_per_block_
<=
max_local_memory_per_block_
;
valid_
&=
shared_memory_per_block_
<=
max_shared_memory_per_block_
;
}
}
void
Visit_
(
const
Allocate
*
op
)
{
IRVisitor
::
Visit_
(
op
);
// visit an allocation of a buffer in shared memory, record its size
if
(
visited_local_buffers_
.
count
(
op
->
buffer_var
.
get
())
!=
0
)
{
size_t
size
=
static_cast
<
size_t
>
(
op
->
constant_allocation_size
());
local_memory_per_block_
+=
size
*
op
->
type
.
bytes
();
}
else
if
(
visited_shared_buffers_
.
count
(
op
->
buffer_var
.
get
())
!=
0
)
{
size_t
size
=
static_cast
<
size_t
>
(
op
->
constant_allocation_size
());
shared_memory_per_block_
+=
size
*
op
->
type
.
bytes
();
}
}
void
Visit_
(
const
AttrStmt
*
op
)
{
if
(
op
->
attr_key
==
attr
::
storage_scope
)
{
if
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"local"
)
{
visited_local_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
}
else
if
(
op
->
value
.
as
<
StringImm
>
()
->
value
==
"shared"
)
{
visited_shared_buffers_
.
insert
(
op
->
node
.
as
<
tvm
::
Variable
>
());
}
}
else
if
(
op
->
attr_key
==
attr
::
thread_extent
)
{
VarExpr
var
=
op
->
node
.
as
<
tvm
::
IterVarNode
>
()
->
var
;
const
auto
*
extent
=
op
->
value
.
as
<
IntImm
>
();
CHECK
(
extent
);
// record the number of threads in a block
std
::
string
name
=
var
.
get
()
->
name_hint
;
if
(
name
==
"threadIdx.x"
||
name
==
"threadIdx.y"
||
name
==
"threadIdx.z"
)
{
if
(
!
visited_threads_
.
count
(
name
))
{
visited_threads_
.
insert
(
name
);
size_t
length
=
static_cast
<
size_t
>
(
extent
->
value
);
thread_per_block_
*=
length
;
if
(
name
==
"threadIdx.x"
)
{
valid_
&=
length
<=
max_thread_x_
;
}
else
if
(
name
==
"threadIdx.y"
)
{
valid_
&=
length
<=
max_thread_y_
;
}
else
if
(
name
==
"threadIdx.z"
)
{
valid_
&=
length
<=
max_thread_z_
;
}
}
}
}
IRVisitor
::
Visit_
(
op
);
}
private
:
int
nest_level_
{
0
};
std
::
unordered_set
<
const
tvm
::
Variable
*>
visited_local_buffers_
;
std
::
unordered_set
<
const
tvm
::
Variable
*>
visited_shared_buffers_
;
std
::
unordered_set
<
std
::
string
>
visited_threads_
;
size_t
local_memory_per_block_
;
size_t
shared_memory_per_block_
;
size_t
thread_per_block_
;
size_t
max_local_memory_per_block_
;
size_t
max_shared_memory_per_block_
;
size_t
max_thread_per_block_
;
size_t
max_thread_x_
,
max_thread_y_
,
max_thread_z_
;
bool
valid_
{
true
};
void
Reset_
()
{
visited_local_buffers_
.
clear
();
visited_shared_buffers_
.
clear
();
local_memory_per_block_
=
0
;
shared_memory_per_block_
=
0
;
visited_threads_
.
clear
();
thread_per_block_
=
1
;
}
};
bool
VerifyGPUCode
(
Stmt
stmt
,
Map
<
std
::
string
,
Expr
>
constraints
)
{
GPUCodeVerifier
verifier
;
auto
get_int
=
[
&
constraints
](
std
::
string
key
,
int64_t
def
)
{
auto
iter
=
constraints
.
find
(
key
);
if
(
iter
!=
constraints
.
end
())
{
return
((
*
iter
).
second
).
as
<
IntImm
>
()
->
value
;
}
else
{
return
def
;
}
};
int64_t
max_local_memory_per_block
=
get_int
(
"max_local_memory_per_block"
,
INT64_MAX
);
int64_t
max_shared_memory_per_block
=
get_int
(
"max_shared_memory_per_block"
,
INT64_MAX
);
int64_t
max_thread_per_block
=
get_int
(
"max_thread_per_block"
,
INT64_MAX
);
int64_t
max_thread_x
=
get_int
(
"max_thread_x"
,
INT64_MAX
);
int64_t
max_thread_y
=
get_int
(
"max_thread_y"
,
INT64_MAX
);
int64_t
max_thread_z
=
get_int
(
"max_thread_z"
,
INT64_MAX
);
return
verifier
.
Verify
(
stmt
,
max_local_memory_per_block
,
max_shared_memory_per_block
,
max_thread_per_block
,
max_thread_x
,
max_thread_y
,
max_thread_z
);
}
}
// namespace ir
}
// namespace tvm
src/runtime/cuda/cuda_device_api.cc
View file @
531bb7c4
...
@@ -5,10 +5,12 @@
...
@@ -5,10 +5,12 @@
*/
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/device_api.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <tvm/container.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include "./cuda_common.h"
#include "./cuda_common.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -70,6 +72,20 @@ class CUDADeviceAPI final : public DeviceAPI {
...
@@ -70,6 +72,20 @@ class CUDADeviceAPI final : public DeviceAPI {
&
value
,
cudaDevAttrMultiProcessorCount
,
ctx
.
device_id
));
&
value
,
cudaDevAttrMultiProcessorCount
,
ctx
.
device_id
));
break
;
break
;
}
}
case
kMaxThreadDimensions
:
{
int
dims
[
3
];
CUDA_CALL
(
cudaDeviceGetAttribute
(
&
dims
[
0
],
cudaDevAttrMaxBlockDimX
,
ctx
.
device_id
));
CUDA_CALL
(
cudaDeviceGetAttribute
(
&
dims
[
1
],
cudaDevAttrMaxBlockDimY
,
ctx
.
device_id
));
CUDA_CALL
(
cudaDeviceGetAttribute
(
&
dims
[
2
],
cudaDevAttrMaxBlockDimZ
,
ctx
.
device_id
));
std
::
stringstream
ss
;
// use json string to return multiple int values;
ss
<<
"["
<<
dims
[
0
]
<<
", "
<<
dims
[
1
]
<<
", "
<<
dims
[
2
]
<<
"]"
;
*
rv
=
ss
.
str
();
return
;
}
}
}
*
rv
=
value
;
*
rv
=
value
;
}
}
...
...
src/runtime/metal/metal_device_api.mm
View file @
531bb7c4
...
@@ -42,6 +42,7 @@ void MetalWorkspace::GetAttr(
...
@@ -42,6 +42,7 @@ void MetalWorkspace::GetAttr(
case kDeviceName: return;
case kDeviceName: return;
case kMaxClockRate: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
case kExist: break;
case kExist: break;
}
}
}
}
...
...
src/runtime/opencl/opencl_device_api.cc
View file @
531bb7c4
...
@@ -4,6 +4,9 @@
...
@@ -4,6 +4,9 @@
*/
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h>
#include <dmlc/thread_local.h>
#include <tvm/container.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include "./opencl_common.h"
#include "./opencl_common.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -30,6 +33,7 @@ void OpenCLWorkspace::GetAttr(
...
@@ -30,6 +33,7 @@ void OpenCLWorkspace::GetAttr(
CHECK_LT
(
index
,
devices
.
size
())
CHECK_LT
(
index
,
devices
.
size
())
<<
"Invalid device id "
<<
index
;
<<
"Invalid device id "
<<
index
;
switch
(
kind
)
{
switch
(
kind
)
{
case
kExist
:
break
;
case
kMaxThreadsPerBlock
:
{
case
kMaxThreadsPerBlock
:
{
size_t
value
;
size_t
value
;
OPENCL_CALL
(
clGetDeviceInfo
(
OPENCL_CALL
(
clGetDeviceInfo
(
...
@@ -80,7 +84,16 @@ void OpenCLWorkspace::GetAttr(
...
@@ -80,7 +84,16 @@ void OpenCLWorkspace::GetAttr(
*
rv
=
static_cast
<
int32_t
>
(
value
);
*
rv
=
static_cast
<
int32_t
>
(
value
);
break
;
break
;
}
}
case
kExist
:
break
;
case
kMaxThreadDimensions
:
{
size_t
dims
[
3
];
OPENCL_CALL
(
clGetDeviceInfo
(
devices
[
index
],
CL_DEVICE_MAX_WORK_ITEM_SIZES
,
sizeof
(
dims
),
dims
,
nullptr
));
std
::
stringstream
ss
;
// use json string to return multiple int values;
ss
<<
"["
<<
dims
[
0
]
<<
", "
<<
dims
[
1
]
<<
", "
<<
dims
[
2
]
<<
"]"
;
*
rv
=
ss
.
str
();
break
;
}
}
}
}
}
...
...
src/runtime/opengl/opengl_device_api.cc
View file @
531bb7c4
...
@@ -97,6 +97,7 @@ void OpenGLWorkspace::GetAttr(
...
@@ -97,6 +97,7 @@ void OpenGLWorkspace::GetAttr(
case
kDeviceName
:
return
;
case
kDeviceName
:
return
;
case
kMaxClockRate
:
return
;
case
kMaxClockRate
:
return
;
case
kMultiProcessorCount
:
return
;
case
kMultiProcessorCount
:
return
;
case
kMaxThreadDimensions
:
return
;
}
}
}
}
...
...
src/runtime/rocm/rocm_device_api.cc
View file @
531bb7c4
...
@@ -52,6 +52,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
...
@@ -52,6 +52,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
case
kDeviceName
:
return
;
case
kDeviceName
:
return
;
case
kMaxClockRate
:
return
;
case
kMaxClockRate
:
return
;
case
kMultiProcessorCount
:
return
;
case
kMultiProcessorCount
:
return
;
case
kMaxThreadDimensions
:
return
;
}
}
*
rv
=
value
;
*
rv
=
value
;
}
}
...
...
src/runtime/vulkan/vulkan_device_api.cc
View file @
531bb7c4
...
@@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr(
...
@@ -73,6 +73,7 @@ void VulkanWorkspace::GetAttr(
case
kMaxClockRate
:
return
;
case
kMaxClockRate
:
return
;
case
kMultiProcessorCount
:
return
;
case
kMultiProcessorCount
:
return
;
case
kExist
:
break
;
case
kExist
:
break
;
case
kMaxThreadDimensions
:
break
;
}
}
}
}
...
...
tests/python/unittest/test_pass_verify_gpu_code.py
0 → 100644
View file @
531bb7c4
"""Test gpu code verifier"""
import
tvm
def
get_verify_pass
(
valid
,
**
kwargs
):
def
verify_pass
(
stmt
):
valid
[
0
]
=
tvm
.
ir_pass
.
VerifyGPUCode
(
stmt
,
kwargs
)
return
stmt
return
verify_pass
def
test_shared_memory
():
N
=
1024
M
=
128
A
=
tvm
.
placeholder
((
N
,),
name
=
'A'
,
dtype
=
'float32'
)
B
=
tvm
.
compute
((
N
,
),
lambda
i
:
A
[
i
],
name
=
'B'
)
s
=
tvm
.
create_schedule
([
B
.
op
])
AA
=
s
.
cache_read
(
A
,
"shared"
,
[
B
])
o
,
i
=
s
[
B
]
.
split
(
s
[
B
]
.
op
.
axis
[
0
],
M
)
s
[
AA
]
.
compute_at
(
s
[
B
],
o
)
s
[
B
]
.
bind
(
o
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
B
]
.
bind
(
i
,
tvm
.
thread_axis
(
"threadIdx.x"
))
# shared memory usage: M * 4B
# thread usage: M
for
target
in
[
'opencl'
,
'cuda'
]:
if
not
tvm
.
context
(
target
)
.
exist
:
continue
valid
=
[
None
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
4
*
M
-
1
,
max_thread_per_block
=
M
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
not
valid
[
0
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
4
*
M
,
max_thread_per_block
=
M
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
valid
[
0
]
def
test_local_memory
():
N
=
1024
M
=
128
A
=
tvm
.
placeholder
((
N
,),
name
=
'A'
,
dtype
=
'float32'
)
B
=
tvm
.
compute
((
N
,
),
lambda
i
:
A
[
i
],
name
=
'B'
)
s
=
tvm
.
create_schedule
([
B
.
op
])
AA
=
s
.
cache_read
(
A
,
"local"
,
[
B
])
o
,
i
=
s
[
B
]
.
split
(
s
[
B
]
.
op
.
axis
[
0
],
M
)
s
[
AA
]
.
compute_at
(
s
[
B
],
o
)
s
[
B
]
.
bind
(
o
,
tvm
.
thread_axis
(
"blockIdx.x"
))
# local memory usage: M * 4B
# thread usage: M
for
target
in
[
'opencl'
,
'cuda'
]:
if
not
tvm
.
context
(
target
)
.
exist
:
continue
valid
=
[
None
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_local_memory_per_block
=
4
*
M
-
1
,
max_thread_per_block
=
1
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
not
valid
[
0
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_local_memory_per_block
=
4
*
M
,
max_thread_per_block
=
1
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
valid
[
0
]
def
test_num_thread
():
N
=
1024
M
=
128
A
=
tvm
.
placeholder
((
N
,),
name
=
'A'
,
dtype
=
'float32'
)
B
=
tvm
.
compute
((
N
,
),
lambda
i
:
A
[
i
],
name
=
'B'
)
s
=
tvm
.
create_schedule
([
B
.
op
])
o
,
i
=
s
[
B
]
.
split
(
s
[
B
]
.
op
.
axis
[
0
],
M
)
s
[
B
]
.
bind
(
o
,
tvm
.
thread_axis
(
'threadIdx.x'
))
s
[
B
]
.
bind
(
i
,
tvm
.
thread_axis
(
"threadIdx.y"
))
# shared memory usage: 0
# thread usage: N
for
target
in
[
'opencl'
,
'cuda'
]:
if
not
tvm
.
context
(
target
)
.
exist
:
continue
valid
=
[
None
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
0
,
max_thread_per_block
=
N
-
1
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
not
valid
[
0
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
0
,
max_thread_per_block
=
N
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
valid
[
0
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
0
,
max_thread_per_block
=
N
,
max_thread_y
=
M
-
1
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
not
valid
[
0
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
0
,
max_thread_per_block
=
N
,
max_thread_y
=
M
))]}):
tvm
.
build
(
s
,
[
A
,
B
],
target
)
assert
valid
[
0
]
def
test_multiple_kernels
():
N
=
1024
A
=
tvm
.
placeholder
((
N
,
N
),
name
=
'A'
)
B
=
tvm
.
compute
((
N
,
N
),
lambda
i
,
j
:
A
[
i
,
j
])
C
=
tvm
.
compute
((
N
,
N
),
lambda
i
,
j
:
B
[
i
,
j
])
s
=
tvm
.
create_schedule
([
C
.
op
])
s
[
C
]
.
bind
(
s
[
C
]
.
op
.
axis
[
1
],
tvm
.
thread_axis
(
"threadIdx.x"
))
s
[
B
]
.
bind
(
s
[
B
]
.
op
.
axis
[
1
],
tvm
.
thread_axis
(
"threadIdx.x"
))
# shared memory usage: 0
# thread usage: N
for
target
in
[
'opencl'
,
'cuda'
]:
if
not
tvm
.
context
(
target
)
.
exist
:
continue
valid
=
[
None
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
0
,
max_thread_per_block
=
N
-
1
))]}):
tvm
.
build
(
s
,
[
A
,
C
],
target
)
assert
not
valid
[
0
]
with
tvm
.
build_config
(
**
{
"add_lower_pass"
:
[
(
2
,
get_verify_pass
(
valid
,
max_shared_memory_per_block
=
0
,
max_thread_per_block
=
N
))]}):
tvm
.
build
(
s
,
[
A
,
C
],
target
)
assert
valid
[
0
]
if
__name__
==
"__main__"
:
test_local_memory
()
test_shared_memory
()
test_num_thread
()
test_multiple_kernels
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment