Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
61dad72e
Commit
61dad72e
authored
Jun 06, 2018
by
Liangfu Chen
Committed by
Tianqi Chen
Jun 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add test to irbuilder for gpu execution (#1228)
parent
ce34ae16
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
108 additions
and
12 deletions
+108
-12
tests/python/unittest/test_codegen_extern.py
+31
-12
tests/python/unittest/test_ir_builder.py
+77
-0
No files found.
tests/python/unittest/test_codegen_extern.py
View file @
61dad72e
...
...
@@ -2,36 +2,55 @@ import tvm
import
numpy
as
np
def
test_add_pipeline
():
nn
=
1024
nn
=
64
max_threads
=
4
n
=
tvm
.
convert
(
nn
)
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
def
extern_generator
(
ins
,
outs
):
"""Manually write the IR for the extern function, add pipeline"""
ib
=
tvm
.
ir_builder
.
create
()
with
ib
.
for_range
(
0
,
n
/
2
)
as
i
:
with
ib
.
for_range
(
0
,
(
n
+
1
)
//
2
)
as
i
:
ib
.
emit
(
outs
[
0
]
.
vstore
(
i
*
2
,
ins
[
0
]
.
vload
(
i
*
2
,
"float32x2"
)
+
tvm
.
const
(
1
,
"float32x2"
)))
return
ib
.
get
()
C
=
tvm
.
extern
(
A
.
shape
,
[
A
],
extern_generator
,
name
=
'C'
)
s
=
tvm
.
create_schedule
(
C
.
op
)
print
(
tvm
.
lower
(
s
,
[
A
,
C
],
simple_mode
=
True
))
def
extern_generator_gpu
(
ins
,
outs
):
"""Manually write the IR for the extern function, add pipeline"""
ib
=
tvm
.
ir_builder
.
create
()
bx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
tx
=
tvm
.
thread_axis
(
"threadIdx.x"
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
(
nn
+
max_threads
-
1
)
//
max_threads
)
ib
.
scope_attr
(
tx
,
"thread_extent"
,
max_threads
)
idx
=
bx
.
var
*
max_threads
+
tx
.
var
with
ib
.
if_scope
(
ib
.
likely
(
idx
<
n
)):
ib
.
emit
(
outs
[
0
]
.
vstore
(
idx
*
2
,
ins
[
0
]
.
vload
(
idx
*
2
,
"float32x2"
)
+
tvm
.
const
(
1
,
"float32x2"
)))
return
ib
.
get
()
C_cpu
=
tvm
.
extern
(
A
.
shape
,
[
A
],
extern_generator
,
name
=
'C'
)
C_gpu
=
tvm
.
extern
(
A
.
shape
,
[
A
],
extern_generator_gpu
,
name
=
'C'
)
s_cpu
=
tvm
.
create_schedule
(
C_cpu
.
op
)
s_gpu
=
tvm
.
create_schedule
(
C_gpu
.
op
)
print
(
tvm
.
lower
(
s_cpu
,
[
A
,
C_cpu
],
simple_mode
=
True
))
print
(
tvm
.
lower
(
s_gpu
,
[
A
,
C_gpu
],
simple_mode
=
True
))
def
check_
llvm
(
):
if
not
tvm
.
module
.
enabled
(
"llvm"
):
def
check_
target
(
target
):
if
not
tvm
.
module
.
enabled
(
target
):
return
s
=
s_gpu
if
target
in
[
'opencl'
,
'cuda'
]
else
s_cpu
C
=
C_gpu
if
target
in
[
'opencl'
,
'cuda'
]
else
C_cpu
# build and invoke the kernel.
f
=
tvm
.
build
(
s
,
[
A
,
C
],
"llvm"
)
ctx
=
tvm
.
c
pu
(
0
)
f
=
tvm
.
build
(
s
,
[
A
,
C
],
target
)
ctx
=
tvm
.
c
ontext
(
target
,
0
)
# launch the kernel.
n
=
nn
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
f
(
a
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
1
)
check_llvm
()
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
1
)
check_target
(
"llvm"
)
check_target
(
"opencl"
)
check_target
(
"cuda"
)
def
test_pack_buffer_simple
():
nn
=
1024
...
...
tests/python/unittest/test_ir_builder.py
View file @
61dad72e
import
tvm
import
numpy
as
np
def
test_for
():
ib
=
tvm
.
ir_builder
.
create
()
...
...
@@ -53,8 +54,84 @@ def test_prefetch():
body
=
ib
.
get
()
assert
body
.
body
.
bounds
[
0
]
.
extent
.
value
==
2
def
test_cpu
():
n
=
1024
dtype
=
"float32"
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
def
test_device_ir
(
A
,
B
,
C
):
n
=
A
.
shape
[
0
]
max_threads
=
8
ib
=
tvm
.
ir_builder
.
create
()
Aptr
=
ib
.
buffer_ptr
(
A
)
Bptr
=
ib
.
buffer_ptr
(
B
)
Cptr
=
ib
.
buffer_ptr
(
C
)
with
ib
.
for_range
(
0
,
n
,
name
=
"i"
)
as
i
:
Cptr
[
i
]
=
Aptr
[
i
]
+
Bptr
[
i
]
body
=
ib
.
get
()
return
body
C
=
tvm
.
extern
(
A
.
shape
,
[
A
,
B
],
lambda
ins
,
outs
:
test_device_ir
(
ins
[
0
],
ins
[
1
],
outs
[
0
]),
name
=
"vector_add"
,
dtype
=
dtype
)
s
=
tvm
.
create_schedule
(
C
.
op
)
def
check_target
(
target
):
if
not
tvm
.
module
.
enabled
(
target
):
return
# build and invoke the kernel.
fadd
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
target
)
ctx
=
tvm
.
context
(
target
,
0
)
# launch the kernel.
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
B
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
fadd
(
a
,
b
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
check_target
(
"llvm"
)
def
test_gpu
():
n
=
tvm
.
var
(
'n'
)
dtype
=
"float32"
A
=
tvm
.
placeholder
((
n
,),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,),
name
=
'B'
)
def
test_device_ir
(
A
,
B
,
C
):
n
=
A
.
shape
[
0
]
max_threads
=
32
ib
=
tvm
.
ir_builder
.
create
()
bx
=
tvm
.
thread_axis
(
"blockIdx.x"
)
tx
=
tvm
.
thread_axis
(
"threadIdx.x"
)
ib
.
scope_attr
(
bx
,
"thread_extent"
,
(
n
+
max_threads
-
1
)
//
max_threads
)
ib
.
scope_attr
(
tx
,
"thread_extent"
,
max_threads
)
idx
=
bx
.
var
*
max_threads
+
tx
.
var
Aptr
=
ib
.
buffer_ptr
(
A
)
Bptr
=
ib
.
buffer_ptr
(
B
)
Cptr
=
ib
.
buffer_ptr
(
C
)
with
ib
.
if_scope
(
ib
.
likely
(
idx
<
n
)):
Cptr
[
idx
]
=
Aptr
[
idx
]
+
Bptr
[
idx
]
body
=
ib
.
get
()
return
body
C
=
tvm
.
extern
(
A
.
shape
,
[
A
,
B
],
lambda
ins
,
outs
:
test_device_ir
(
ins
[
0
],
ins
[
1
],
outs
[
0
]),
name
=
"vector_add"
,
dtype
=
dtype
)
s
=
tvm
.
create_schedule
(
C
.
op
)
bounds
=
tvm
.
schedule
.
InferBound
(
s
)
stmt
=
tvm
.
schedule
.
ScheduleOps
(
s
,
bounds
)
def
check_target
(
target
):
n
=
1024
if
not
tvm
.
module
.
enabled
(
target
):
return
# build and invoke the kernel.
fadd
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
target
)
ctx
=
tvm
.
context
(
target
,
0
)
# launch the kernel.
a
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
A
.
dtype
),
ctx
)
b
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
n
)
.
astype
(
B
.
dtype
),
ctx
)
c
=
tvm
.
nd
.
array
(
np
.
zeros
(
n
,
dtype
=
C
.
dtype
),
ctx
)
fadd
(
a
,
b
,
c
)
np
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
check_target
(
"opencl"
)
check_target
(
"cuda"
)
if
__name__
==
"__main__"
:
test_prefetch
()
test_if
()
test_for
()
test_cpu
()
test_gpu
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment