Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
10ae8ee1
Unverified
Commit
10ae8ee1
authored
Sep 14, 2018
by
Tianqi Chen
Committed by
GitHub
Sep 14, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Support TVMContext (#1720)
parent
dd9589ec
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
6 deletions
+50
-6
include/tvm/runtime/packed_func.h
+5
-0
python/tvm/_ffi/_ctypes/function.py
+2
-2
python/tvm/_ffi/_ctypes/types.py
+22
-4
src/api/api_test.cc
+10
-0
tests/python/unittest/test_runtime_packed_func.py
+11
-0
No files found.
include/tvm/runtime/packed_func.h
View file @
10ae8ee1
...
@@ -646,6 +646,11 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -646,6 +646,11 @@ class TVMRetValue : public TVMPODValue_ {
value_
.
v_int64
=
value
;
value_
.
v_int64
=
value
;
return
*
this
;
return
*
this
;
}
}
TVMRetValue
&
operator
=
(
TVMContext
value
)
{
this
->
SwitchToPOD
(
kTVMContext
);
value_
.
v_ctx
=
value
;
return
*
this
;
}
TVMRetValue
&
operator
=
(
TVMType
t
)
{
TVMRetValue
&
operator
=
(
TVMType
t
)
{
this
->
SwitchToPOD
(
kTVMType
);
this
->
SwitchToPOD
(
kTVMType
);
value_
.
v_type
=
t
;
value_
.
v_type
=
t
;
...
...
python/tvm/_ffi/_ctypes/function.py
View file @
10ae8ee1
...
@@ -15,7 +15,7 @@ from . import ndarray as _nd
...
@@ -15,7 +15,7 @@ from . import ndarray as _nd
from
.ndarray
import
NDArrayBase
,
_make_array
from
.ndarray
import
NDArrayBase
,
_make_array
from
.types
import
TVMValue
,
TypeCode
from
.types
import
TVMValue
,
TypeCode
from
.types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
.types
import
TVMPackedCFunc
,
TVMCFuncFinalizer
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
from
.types
import
RETURN_SWITCH
,
C_TO_PY_ARG_SWITCH
,
_wrap_arg_func
,
_ctx_to_int64
from
.node
import
NodeBase
from
.node
import
NodeBase
from
.
import
node
as
_node
from
.
import
node
as
_node
...
@@ -110,7 +110,7 @@ def _make_tvm_args(args, temp_args):
...
@@ -110,7 +110,7 @@ def _make_tvm_args(args, temp_args):
values
[
i
]
.
v_str
=
c_str
(
str
(
arg
))
values
[
i
]
.
v_str
=
c_str
(
str
(
arg
))
type_codes
[
i
]
=
TypeCode
.
STR
type_codes
[
i
]
=
TypeCode
.
STR
elif
isinstance
(
arg
,
TVMContext
):
elif
isinstance
(
arg
,
TVMContext
):
values
[
i
]
.
v_
ctx
=
arg
values
[
i
]
.
v_
int64
=
_ctx_to_int64
(
arg
)
type_codes
[
i
]
=
TypeCode
.
TVM_CONTEXT
type_codes
[
i
]
=
TypeCode
.
TVM_CONTEXT
elif
isinstance
(
arg
,
bytearray
):
elif
isinstance
(
arg
,
bytearray
):
arr
=
TVMByteArray
()
arr
=
TVMByteArray
()
...
...
python/tvm/_ffi/_ctypes/types.py
View file @
10ae8ee1
...
@@ -3,8 +3,9 @@
...
@@ -3,8 +3,9 @@
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
ctypes
import
ctypes
import
struct
from
..base
import
py_str
,
check_call
,
_LIB
from
..base
import
py_str
,
check_call
,
_LIB
from
..runtime_ctypes
import
TVMByteArray
,
TypeCode
from
..runtime_ctypes
import
TVMByteArray
,
TypeCode
,
TVMContext
class
TVMValue
(
ctypes
.
Union
):
class
TVMValue
(
ctypes
.
Union
):
"""TVMValue in C API"""
"""TVMValue in C API"""
...
@@ -36,7 +37,7 @@ def _return_handle(x):
...
@@ -36,7 +37,7 @@ def _return_handle(x):
return
handle
return
handle
def
_return_bytes
(
x
):
def
_return_bytes
(
x
):
"""return
handle
"""
"""return
bytes
"""
handle
=
x
.
v_handle
handle
=
x
.
v_handle
if
not
isinstance
(
handle
,
ctypes
.
c_void_p
):
if
not
isinstance
(
handle
,
ctypes
.
c_void_p
):
handle
=
ctypes
.
c_void_p
(
handle
)
handle
=
ctypes
.
c_void_p
(
handle
)
...
@@ -48,6 +49,15 @@ def _return_bytes(x):
...
@@ -48,6 +49,15 @@ def _return_bytes(x):
raise
RuntimeError
(
'memmove failed'
)
raise
RuntimeError
(
'memmove failed'
)
return
res
return
res
def
_return_context
(
value
):
"""return TVMContext"""
# use bit unpacking from int64 view
# We use this to get around ctypes issue on Union of Structure
data
=
struct
.
pack
(
"=q"
,
value
.
v_int64
)
arr
=
struct
.
unpack
(
"=ii"
,
data
)
return
TVMContext
(
arr
[
0
],
arr
[
1
])
def
_wrap_arg_func
(
return_f
,
type_code
):
def
_wrap_arg_func
(
return_f
,
type_code
):
tcode
=
ctypes
.
c_int
(
type_code
)
tcode
=
ctypes
.
c_int
(
type_code
)
def
_wrap_func
(
x
):
def
_wrap_func
(
x
):
...
@@ -55,13 +65,20 @@ def _wrap_arg_func(return_f, type_code):
...
@@ -55,13 +65,20 @@ def _wrap_arg_func(return_f, type_code):
return
return_f
(
x
)
return
return_f
(
x
)
return
_wrap_func
return
_wrap_func
def
_ctx_to_int64
(
ctx
):
"""Pack context into int64 in native endian"""
data
=
struct
.
pack
(
"=ii"
,
ctx
.
device_type
,
ctx
.
device_id
)
return
struct
.
unpack
(
"=q"
,
data
)[
0
]
RETURN_SWITCH
=
{
RETURN_SWITCH
=
{
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
INT
:
lambda
x
:
x
.
v_int64
,
TypeCode
.
FLOAT
:
lambda
x
:
x
.
v_float64
,
TypeCode
.
FLOAT
:
lambda
x
:
x
.
v_float64
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
),
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
),
TypeCode
.
BYTES
:
_return_bytes
TypeCode
.
BYTES
:
_return_bytes
,
TypeCode
.
TVM_CONTEXT
:
_return_context
}
}
C_TO_PY_ARG_SWITCH
=
{
C_TO_PY_ARG_SWITCH
=
{
...
@@ -70,5 +87,6 @@ C_TO_PY_ARG_SWITCH = {
...
@@ -70,5 +87,6 @@ C_TO_PY_ARG_SWITCH = {
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
HANDLE
:
_return_handle
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
NULL
:
lambda
x
:
None
,
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
),
TypeCode
.
STR
:
lambda
x
:
py_str
(
x
.
v_str
),
TypeCode
.
BYTES
:
_return_bytes
TypeCode
.
BYTES
:
_return_bytes
,
TypeCode
.
TVM_CONTEXT
:
_return_context
}
}
src/api/api_test.cc
View file @
10ae8ee1
...
@@ -35,6 +35,16 @@ TVM_REGISTER_API("_nop")
...
@@ -35,6 +35,16 @@ TVM_REGISTER_API("_nop")
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
});
});
TVM_REGISTER_API
(
"_context_test"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
DLContext
ctx
=
args
[
0
];
int
dtype
=
args
[
1
];
int
did
=
args
[
2
];
CHECK_EQ
(
static_cast
<
int
>
(
ctx
.
device_type
),
dtype
);
CHECK_EQ
(
static_cast
<
int
>
(
ctx
.
device_id
),
did
);
*
ret
=
ctx
;
});
// internal fucntion used for debug and testing purposes
// internal fucntion used for debug and testing purposes
TVM_REGISTER_API
(
"_ndarray_use_count"
)
TVM_REGISTER_API
(
"_ndarray_use_count"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
tests/python/unittest/test_runtime_packed_func.py
View file @
10ae8ee1
...
@@ -70,6 +70,16 @@ def test_empty_array():
...
@@ -70,6 +70,16 @@ def test_empty_array():
tvm
.
convert
(
myfunc
)(
x
)
tvm
.
convert
(
myfunc
)(
x
)
def
test_ctx
():
def
test_ctx_func
(
ctx
):
assert
tvm
.
gpu
(
7
)
==
ctx
return
tvm
.
cpu
(
0
)
x
=
test_ctx_func
(
tvm
.
gpu
(
7
))
assert
x
==
tvm
.
cpu
(
0
)
x
=
tvm
.
opencl
(
10
)
x
=
tvm
.
_api_internal
.
_context_test
(
x
,
x
.
device_type
,
x
.
device_id
)
assert
x
==
tvm
.
opencl
(
10
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_empty_array
()
test_empty_array
()
test_get_global
()
test_get_global
()
...
@@ -77,3 +87,4 @@ if __name__ == "__main__":
...
@@ -77,3 +87,4 @@ if __name__ == "__main__":
test_convert
()
test_convert
()
test_return_func
()
test_return_func
()
test_byte_array
()
test_byte_array
()
test_ctx
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment