Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9082a93b
Commit
9082a93b
authored
Apr 07, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 07, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Make target and build module more pythonic (#1089)
parent
819728db
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
12 deletions
+25
-12
include/tvm/build_module.h
+3
-0
python/tvm/build_module.py
+11
-4
python/tvm/contrib/rpc/server.py
+3
-2
src/codegen/build_module.cc
+4
-6
tests/python/unittest/test_lang_target.py
+4
-0
No files found.
include/tvm/build_module.h
View file @
9082a93b
...
...
@@ -24,6 +24,8 @@ class TargetNode : public Node {
public
:
/*! \brief The name of the target device */
std
::
string
target_name
;
/*! \brief The name of the target device */
std
::
string
device_name
;
/*! \brief The type of the target device */
int
device_type
;
/*! \brief The maximum threads that a schedule should use for this device */
...
...
@@ -42,6 +44,7 @@ class TargetNode : public Node {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"target_name"
,
&
target_name
);
v
->
Visit
(
"device_name"
,
&
device_name
);
v
->
Visit
(
"device_type"
,
&
device_type
);
v
->
Visit
(
"max_num_threads"
,
&
max_num_threads
);
v
->
Visit
(
"thread_warp_size"
,
&
thread_warp_size
);
...
...
python/tvm/build_module.py
View file @
9082a93b
...
...
@@ -150,6 +150,13 @@ class BuildConfig(NodeBase):
result
+=
[(
phase
,
func
)]
return
result
@add_lower_pass.setter
def
add_lower_pass
(
self
,
value
):
add_lower_pass_args
=
[]
for
x
in
value
:
add_lower_pass_args
+=
[
x
[
0
],
x
[
1
]]
_api_internal
.
_BuildConfigSetAddLowerPass
(
self
,
*
add_lower_pass_args
)
def
__enter__
(
self
):
# pylint: disable=protected-access
_api_internal
.
_EnterBuildConfigScope
(
self
)
...
...
@@ -168,9 +175,12 @@ class BuildConfig(NodeBase):
"'
%
s' object cannot set attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
return
super
(
BuildConfig
,
self
)
.
__setattr__
(
name
,
value
)
def
current_build_config
():
"""Get the current build configuration."""
return
_api_internal
.
_GetCurrentBuildConfig
()
def
build_config
(
**
kwargs
):
"""Configure the build behavior by setting config variables.
...
...
@@ -230,10 +240,7 @@ def build_config(**kwargs):
config
=
make
.
node
(
"BuildConfig"
,
**
node_args
)
if
"add_lower_pass"
in
kwargs
:
add_lower_pass_args
=
[]
for
x
in
kwargs
[
"add_lower_pass"
]:
add_lower_pass_args
+=
[
x
[
0
],
x
[
1
]]
_api_internal
.
_BuildConfigSetAddLowerPass
(
config
,
*
add_lower_pass_args
)
config
.
add_lower_pass
=
kwargs
[
"add_lower_pass"
]
return
config
...
...
python/tvm/contrib/rpc/server.py
View file @
9082a93b
...
...
@@ -166,8 +166,9 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
except
(
socket
.
error
,
IOError
):
# retry when tracker is dropped
tracker_conn
.
close
()
tracker_conn
=
None
if
tracker_conn
:
tracker_conn
.
close
()
tracker_conn
=
None
continue
# step 3: serving
...
...
src/codegen/build_module.cc
View file @
9082a93b
...
...
@@ -37,8 +37,6 @@ Target CreateTarget(const std::string& target_name,
t
->
target_name
=
target_name
;
std
::
string
device_name
=
""
;
std
::
string
libs_flag
=
"-libs="
;
std
::
string
device_flag
=
"-device="
;
for
(
auto
&
item
:
options
)
{
...
...
@@ -51,12 +49,12 @@ Target CreateTarget(const std::string& target_name,
t
->
libs_array
.
push_back
(
ir
::
StringImm
::
make
(
lib_item
));
}
}
else
if
(
item
.
find
(
device_flag
)
==
0
)
{
device_name
=
item
.
substr
(
device_flag
.
length
());
t
->
device_name
=
item
.
substr
(
device_flag
.
length
());
}
}
if
(
device_name
.
length
()
>
0
)
{
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
device_name
));
if
(
t
->
device_name
.
length
()
>
0
)
{
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
t
->
device_name
));
}
t
->
device_type
=
kDLCPU
;
t
->
thread_warp_size
=
1
;
...
...
@@ -78,7 +76,7 @@ Target CreateTarget(const std::string& target_name,
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"rocm"
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"gpu"
));
t
->
max_num_threads
=
256
;
if
(
device_name
==
"intel_gpu"
)
{
if
(
t
->
device_name
==
"intel_gpu"
)
{
t
->
thread_warp_size
=
16
;
}
}
else
if
(
target_name
==
"metal"
||
target_name
==
"vulkan"
)
{
...
...
tests/python/unittest/test_lang_target.py
View file @
9082a93b
...
...
@@ -36,6 +36,7 @@ def test_target_dispatch():
assert
tvm
.
target
.
current_target
()
==
None
def
test_target_string_parse
():
target
=
tvm
.
target
.
create
(
"cuda -libs=cublas,cudnn"
)
...
...
@@ -45,6 +46,9 @@ def test_target_string_parse():
assert
target
.
libs
==
[
'cublas'
,
'cudnn'
]
assert
str
(
target
)
==
str
(
tvm
.
target
.
cuda
(
"-libs=cublas,cudnn"
))
assert
tvm
.
target
.
intel_gpu
()
.
device_name
==
"intel_gpu"
if
__name__
==
"__main__"
:
test_target_dispatch
()
test_target_string_parse
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment