Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ec3a4251
Commit
ec3a4251
authored
Feb 12, 2019
by
Marina Kolpakova
Committed by
Tianqi Chen
Feb 12, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
A couple of fixes for GEN (#2593)
parent
77718f8e
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
14 deletions
+22
-14
python/tvm/build_module.py
+1
-1
python/tvm/contrib/debugger/debug_runtime.py
+4
-3
src/runtime/thread_storage_scope.h
+1
-1
tutorials/dev/low_level_custom_pass.py
+1
-1
tutorials/get_started.py
+15
-8
No files found.
python/tvm/build_module.py
View file @
ec3a4251
...
@@ -302,7 +302,7 @@ def lower(sch,
...
@@ -302,7 +302,7 @@ def lower(sch,
Parameters
Parameters
----------
----------
sch : tvm.schedule.Schedule
sch : tvm.schedule.Schedule
The schedule to be buil
ded
The schedule to be buil
t
args : list of Buffer or Tensor or Var
args : list of Buffer or Tensor or Var
The argument lists to the function.
The argument lists to the function.
...
...
python/tvm/contrib/debugger/debug_runtime.py
View file @
ec3a4251
...
@@ -159,11 +159,12 @@ class GraphModuleDebug(graph_runtime.GraphModule):
...
@@ -159,11 +159,12 @@ class GraphModuleDebug(graph_runtime.GraphModule):
self
.
debug_datum
=
debug_result
.
DebugResult
(
graph_json
,
self
.
_dump_path
)
self
.
debug_datum
=
debug_result
.
DebugResult
(
graph_json
,
self
.
_dump_path
)
def
_run_debug
(
self
):
def
_run_debug
(
self
):
"""Execute the node spcified with index will be executed.
"""Execute the node sp
e
cified with index will be executed.
Each debug output will be copied to the buffer
Each debug output will be copied to the buffer
Time consumed for each execuion will be set as debug output.
Time consumed for each execu
t
ion will be set as debug output.
"""
"""
self
.
debug_datum
.
_time_list
=
[]
for
i
,
node
in
enumerate
(
self
.
debug_datum
.
get_graph_nodes
()):
for
i
,
node
in
enumerate
(
self
.
debug_datum
.
get_graph_nodes
()):
start_time
=
datetime
.
now
()
.
time
()
start_time
=
datetime
.
now
()
.
time
()
...
@@ -177,7 +178,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
...
@@ -177,7 +178,7 @@ class GraphModuleDebug(graph_runtime.GraphModule):
self
.
debug_datum
.
_output_tensor_list
.
append
(
out_tensor
)
self
.
debug_datum
.
_output_tensor_list
.
append
(
out_tensor
)
def
debug_get_output
(
self
,
node
,
out
):
def
debug_get_output
(
self
,
node
,
out
):
"""Run graph upto node and get the output to out
"""Run graph up
to node and get the output to out
Parameters
Parameters
----------
----------
...
...
src/runtime/thread_storage_scope.h
View file @
ec3a4251
...
@@ -130,7 +130,7 @@ struct ThreadScope {
...
@@ -130,7 +130,7 @@ struct ThreadScope {
};
};
/*! \brief workload spec
c
ification */
/*! \brief workload specification */
struct
ThreadWorkLoad
{
struct
ThreadWorkLoad
{
// array, first three are thread configuration.
// array, first three are thread configuration.
size_t
work_size
[
6
];
size_t
work_size
[
6
];
...
...
tutorials/dev/low_level_custom_pass.py
View file @
ec3a4251
...
@@ -31,7 +31,7 @@ import numpy as np
...
@@ -31,7 +31,7 @@ import numpy as np
######################################################################
######################################################################
# We first write a very simple vector add and build it with the default schedule. Then, we use
# We first write a very simple vector add and build it with the default schedule. Then, we use
# our customized lowering pass to manipulate the IR directly instead of using schedule pr
e
mitives.
# our customized lowering pass to manipulate the IR directly instead of using schedule pr
i
mitives.
#
#
n
=
tvm
.
const
(
128
,
"int32"
)
n
=
tvm
.
const
(
128
,
"int32"
)
...
...
tutorials/get_started.py
View file @
ec3a4251
...
@@ -94,7 +94,7 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
...
@@ -94,7 +94,7 @@ bx, tx = s[C].split(C.op.axis[0], factor=64)
# compute grid. These are GPU specific constructs that allows us
# compute grid. These are GPU specific constructs that allows us
# to generate code that runs on GPU.
# to generate code that runs on GPU.
#
#
if
tgt
==
"cuda"
:
if
tgt
==
"cuda"
or
tgt
.
startswith
(
'opencl'
)
:
s
[
C
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
C
]
.
bind
(
bx
,
tvm
.
thread_axis
(
"blockIdx.x"
))
s
[
C
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
s
[
C
]
.
bind
(
tx
,
tvm
.
thread_axis
(
"threadIdx.x"
))
...
@@ -149,7 +149,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
...
@@ -149,7 +149,7 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
#
#
# The following code fetches the device module and prints the content code.
# The following code fetches the device module and prints the content code.
#
#
if
tgt
==
"cuda"
:
if
tgt
==
"cuda"
or
tgt
.
startswith
(
'opencl'
)
:
dev_module
=
fadd
.
imported_modules
[
0
]
dev_module
=
fadd
.
imported_modules
[
0
]
print
(
"-----GPU code-----"
)
print
(
"-----GPU code-----"
)
print
(
dev_module
.
get_source
())
print
(
dev_module
.
get_source
())
...
@@ -193,6 +193,8 @@ temp = util.tempdir()
...
@@ -193,6 +193,8 @@ temp = util.tempdir()
fadd
.
save
(
temp
.
relpath
(
"myadd.o"
))
fadd
.
save
(
temp
.
relpath
(
"myadd.o"
))
if
tgt
==
"cuda"
:
if
tgt
==
"cuda"
:
fadd
.
imported_modules
[
0
]
.
save
(
temp
.
relpath
(
"myadd.ptx"
))
fadd
.
imported_modules
[
0
]
.
save
(
temp
.
relpath
(
"myadd.ptx"
))
if
tgt
.
startswith
(
'opencl'
):
fadd
.
imported_modules
[
0
]
.
save
(
temp
.
relpath
(
"myadd.cl"
))
cc
.
create_shared
(
temp
.
relpath
(
"myadd.so"
),
[
temp
.
relpath
(
"myadd.o"
)])
cc
.
create_shared
(
temp
.
relpath
(
"myadd.so"
),
[
temp
.
relpath
(
"myadd.o"
)])
print
(
temp
.
listdir
())
print
(
temp
.
listdir
())
...
@@ -200,29 +202,34 @@ print(temp.listdir())
...
@@ -200,29 +202,34 @@ print(temp.listdir())
# .. note:: Module Storage Format
# .. note:: Module Storage Format
#
#
# The CPU(host) module is directly saved as a shared library(so).
# The CPU(host) module is directly saved as a shared library(so).
# There can be multiple customed format on the device code.
# There can be multiple custom
iz
ed format on the device code.
# In our example, device code is stored in ptx, as well as a meta
# In our example, device code is stored in ptx, as well as a meta
# data json file. They can be loaded and linked sep
erated
ly via import.
# data json file. They can be loaded and linked sep
arate
ly via import.
#
#
######################################################################
######################################################################
# Load Compiled Module
# Load Compiled Module
# --------------------
# --------------------
# We can load the compiled module from the file system and run the code.
# We can load the compiled module from the file system and run the code.
# The following code load the host and device module sep
erated
ly and
# The following code load the host and device module sep
arate
ly and
# re-link them together. We can verify that the newly loaded function works.
# re-link them together. We can verify that the newly loaded function works.
#
#
fadd1
=
tvm
.
module
.
load
(
temp
.
relpath
(
"myadd.so"
))
fadd1
=
tvm
.
module
.
load
(
temp
.
relpath
(
"myadd.so"
))
if
tgt
==
"cuda"
:
if
tgt
==
"cuda"
:
fadd1_dev
=
tvm
.
module
.
load
(
temp
.
relpath
(
"myadd.ptx"
))
fadd1_dev
=
tvm
.
module
.
load
(
temp
.
relpath
(
"myadd.ptx"
))
fadd1
.
import_module
(
fadd1_dev
)
fadd1
.
import_module
(
fadd1_dev
)
if
tgt
.
startswith
(
'opencl'
):
fadd1_dev
=
tvm
.
module
.
load
(
temp
.
relpath
(
"myadd.cl"
))
fadd1
.
import_module
(
fadd1_dev
)
fadd1
(
a
,
b
,
c
)
fadd1
(
a
,
b
,
c
)
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
tvm
.
testing
.
assert_allclose
(
c
.
asnumpy
(),
a
.
asnumpy
()
+
b
.
asnumpy
())
######################################################################
######################################################################
# Pack Everything into One Library
# Pack Everything into One Library
# --------------------------------
# --------------------------------
# In the above example, we store the device and host code sep
erated
ly.
# In the above example, we store the device and host code sep
arate
ly.
# TVM also supports export everything as one shared library.
# TVM also supports export everything as one shared library.
# Under the hood, we pack the device modules into binary blobs and link
# Under the hood, we pack the device modules into binary blobs and link
# them together with the host code.
# them together with the host code.
...
@@ -254,8 +261,8 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
...
@@ -254,8 +261,8 @@ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
# The following codeblocks generate opencl code, creates array on opencl
# The following codeblocks generate opencl code, creates array on opencl
# device, and verifies the correctness of the code.
# device, and verifies the correctness of the code.
#
#
if
tgt
==
"opencl"
:
if
tgt
.
startswith
(
'opencl'
)
:
fadd_cl
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
"opencl"
,
name
=
"myadd"
)
fadd_cl
=
tvm
.
build
(
s
,
[
A
,
B
,
C
],
tgt
,
name
=
"myadd"
)
print
(
"------opencl code------"
)
print
(
"------opencl code------"
)
print
(
fadd_cl
.
imported_modules
[
0
]
.
get_source
())
print
(
fadd_cl
.
imported_modules
[
0
]
.
get_source
())
ctx
=
tvm
.
cl
(
0
)
ctx
=
tvm
.
cl
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment