Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9082a93b
Commit
9082a93b
authored
Apr 07, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 07, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Make target and build module more pythonic (#1089)
parent
819728db
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
12 deletions
+25
-12
include/tvm/build_module.h
+3
-0
python/tvm/build_module.py
+11
-4
python/tvm/contrib/rpc/server.py
+3
-2
src/codegen/build_module.cc
+4
-6
tests/python/unittest/test_lang_target.py
+4
-0
No files found.
include/tvm/build_module.h
View file @
9082a93b
...
@@ -24,6 +24,8 @@ class TargetNode : public Node {
...
@@ -24,6 +24,8 @@ class TargetNode : public Node {
public
:
public
:
/*! \brief The name of the target device */
/*! \brief The name of the target device */
std
::
string
target_name
;
std
::
string
target_name
;
/*! \brief The name of the target device */
std
::
string
device_name
;
/*! \brief The type of the target device */
/*! \brief The type of the target device */
int
device_type
;
int
device_type
;
/*! \brief The maximum threads that a schedule should use for this device */
/*! \brief The maximum threads that a schedule should use for this device */
...
@@ -42,6 +44,7 @@ class TargetNode : public Node {
...
@@ -42,6 +44,7 @@ class TargetNode : public Node {
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"target_name"
,
&
target_name
);
v
->
Visit
(
"target_name"
,
&
target_name
);
v
->
Visit
(
"device_name"
,
&
device_name
);
v
->
Visit
(
"device_type"
,
&
device_type
);
v
->
Visit
(
"device_type"
,
&
device_type
);
v
->
Visit
(
"max_num_threads"
,
&
max_num_threads
);
v
->
Visit
(
"max_num_threads"
,
&
max_num_threads
);
v
->
Visit
(
"thread_warp_size"
,
&
thread_warp_size
);
v
->
Visit
(
"thread_warp_size"
,
&
thread_warp_size
);
...
...
python/tvm/build_module.py
View file @
9082a93b
...
@@ -150,6 +150,13 @@ class BuildConfig(NodeBase):
...
@@ -150,6 +150,13 @@ class BuildConfig(NodeBase):
result
+=
[(
phase
,
func
)]
result
+=
[(
phase
,
func
)]
return
result
return
result
@add_lower_pass.setter
def
add_lower_pass
(
self
,
value
):
add_lower_pass_args
=
[]
for
x
in
value
:
add_lower_pass_args
+=
[
x
[
0
],
x
[
1
]]
_api_internal
.
_BuildConfigSetAddLowerPass
(
self
,
*
add_lower_pass_args
)
def
__enter__
(
self
):
def
__enter__
(
self
):
# pylint: disable=protected-access
# pylint: disable=protected-access
_api_internal
.
_EnterBuildConfigScope
(
self
)
_api_internal
.
_EnterBuildConfigScope
(
self
)
...
@@ -168,9 +175,12 @@ class BuildConfig(NodeBase):
...
@@ -168,9 +175,12 @@ class BuildConfig(NodeBase):
"'
%
s' object cannot set attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
"'
%
s' object cannot set attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
return
super
(
BuildConfig
,
self
)
.
__setattr__
(
name
,
value
)
return
super
(
BuildConfig
,
self
)
.
__setattr__
(
name
,
value
)
def
current_build_config
():
def
current_build_config
():
"""Get the current build configuration."""
return
_api_internal
.
_GetCurrentBuildConfig
()
return
_api_internal
.
_GetCurrentBuildConfig
()
def
build_config
(
**
kwargs
):
def
build_config
(
**
kwargs
):
"""Configure the build behavior by setting config variables.
"""Configure the build behavior by setting config variables.
...
@@ -230,10 +240,7 @@ def build_config(**kwargs):
...
@@ -230,10 +240,7 @@ def build_config(**kwargs):
config
=
make
.
node
(
"BuildConfig"
,
**
node_args
)
config
=
make
.
node
(
"BuildConfig"
,
**
node_args
)
if
"add_lower_pass"
in
kwargs
:
if
"add_lower_pass"
in
kwargs
:
add_lower_pass_args
=
[]
config
.
add_lower_pass
=
kwargs
[
"add_lower_pass"
]
for
x
in
kwargs
[
"add_lower_pass"
]:
add_lower_pass_args
+=
[
x
[
0
],
x
[
1
]]
_api_internal
.
_BuildConfigSetAddLowerPass
(
config
,
*
add_lower_pass_args
)
return
config
return
config
...
...
python/tvm/contrib/rpc/server.py
View file @
9082a93b
...
@@ -166,8 +166,9 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -166,8 +166,9 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
except
(
socket
.
error
,
IOError
):
except
(
socket
.
error
,
IOError
):
# retry when tracker is dropped
# retry when tracker is dropped
tracker_conn
.
close
()
if
tracker_conn
:
tracker_conn
=
None
tracker_conn
.
close
()
tracker_conn
=
None
continue
continue
# step 3: serving
# step 3: serving
...
...
src/codegen/build_module.cc
View file @
9082a93b
...
@@ -37,8 +37,6 @@ Target CreateTarget(const std::string& target_name,
...
@@ -37,8 +37,6 @@ Target CreateTarget(const std::string& target_name,
t
->
target_name
=
target_name
;
t
->
target_name
=
target_name
;
std
::
string
device_name
=
""
;
std
::
string
libs_flag
=
"-libs="
;
std
::
string
libs_flag
=
"-libs="
;
std
::
string
device_flag
=
"-device="
;
std
::
string
device_flag
=
"-device="
;
for
(
auto
&
item
:
options
)
{
for
(
auto
&
item
:
options
)
{
...
@@ -51,12 +49,12 @@ Target CreateTarget(const std::string& target_name,
...
@@ -51,12 +49,12 @@ Target CreateTarget(const std::string& target_name,
t
->
libs_array
.
push_back
(
ir
::
StringImm
::
make
(
lib_item
));
t
->
libs_array
.
push_back
(
ir
::
StringImm
::
make
(
lib_item
));
}
}
}
else
if
(
item
.
find
(
device_flag
)
==
0
)
{
}
else
if
(
item
.
find
(
device_flag
)
==
0
)
{
device_name
=
item
.
substr
(
device_flag
.
length
());
t
->
device_name
=
item
.
substr
(
device_flag
.
length
());
}
}
}
}
if
(
device_name
.
length
()
>
0
)
{
if
(
t
->
device_name
.
length
()
>
0
)
{
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
device_name
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
t
->
device_name
));
}
}
t
->
device_type
=
kDLCPU
;
t
->
device_type
=
kDLCPU
;
t
->
thread_warp_size
=
1
;
t
->
thread_warp_size
=
1
;
...
@@ -78,7 +76,7 @@ Target CreateTarget(const std::string& target_name,
...
@@ -78,7 +76,7 @@ Target CreateTarget(const std::string& target_name,
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"rocm"
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"rocm"
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"gpu"
));
t
->
keys_array
.
push_back
(
ir
::
StringImm
::
make
(
"gpu"
));
t
->
max_num_threads
=
256
;
t
->
max_num_threads
=
256
;
if
(
device_name
==
"intel_gpu"
)
{
if
(
t
->
device_name
==
"intel_gpu"
)
{
t
->
thread_warp_size
=
16
;
t
->
thread_warp_size
=
16
;
}
}
}
else
if
(
target_name
==
"metal"
||
target_name
==
"vulkan"
)
{
}
else
if
(
target_name
==
"metal"
||
target_name
==
"vulkan"
)
{
...
...
tests/python/unittest/test_lang_target.py
View file @
9082a93b
...
@@ -36,6 +36,7 @@ def test_target_dispatch():
...
@@ -36,6 +36,7 @@ def test_target_dispatch():
assert
tvm
.
target
.
current_target
()
==
None
assert
tvm
.
target
.
current_target
()
==
None
def
test_target_string_parse
():
def
test_target_string_parse
():
target
=
tvm
.
target
.
create
(
"cuda -libs=cublas,cudnn"
)
target
=
tvm
.
target
.
create
(
"cuda -libs=cublas,cudnn"
)
...
@@ -45,6 +46,9 @@ def test_target_string_parse():
...
@@ -45,6 +46,9 @@ def test_target_string_parse():
assert
target
.
libs
==
[
'cublas'
,
'cudnn'
]
assert
target
.
libs
==
[
'cublas'
,
'cudnn'
]
assert
str
(
target
)
==
str
(
tvm
.
target
.
cuda
(
"-libs=cublas,cudnn"
))
assert
str
(
target
)
==
str
(
tvm
.
target
.
cuda
(
"-libs=cublas,cudnn"
))
assert
tvm
.
target
.
intel_gpu
()
.
device_name
==
"intel_gpu"
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_target_dispatch
()
test_target_dispatch
()
test_target_string_parse
()
test_target_string_parse
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment