Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5061a6da
Commit
5061a6da
authored
Sep 12, 2017
by
Tianqi Chen
Committed by
GitHub
Sep 12, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Add function to pack arguments (#452)
parent
769544ad
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
5 deletions
+89
-5
src/runtime/pack_args.h
+79
-2
src/runtime/rocm/rocm_module.cc
+10
-3
No files found.
src/runtime/pack_args.h
View file @
5061a6da
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <vector>
#include <vector>
#include <cstring>
namespace
tvm
{
namespace
tvm
{
namespace
runtime
{
namespace
runtime
{
...
@@ -31,7 +32,7 @@ union ArgUnion {
...
@@ -31,7 +32,7 @@ union ArgUnion {
* \brief Create a packed function from void addr types.
* \brief Create a packed function from void addr types.
*
*
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
* \param arg_types The arguments t
hat wish to get from
* \param arg_types The arguments t
ype information.
* \tparam F the function type
* \tparam F the function type
*
*
* \return The wrapped packed function.
* \return The wrapped packed function.
...
@@ -42,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
...
@@ -42,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
* \brief Create a packed function that from function only packs buffer arguments.
* \brief Create a packed function that from function only packs buffer arguments.
*
*
* \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
* \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
* \param arg_types The arguments t
hat wish to get from
* \param arg_types The arguments t
ype information.
* \tparam F the function type
* \tparam F the function type
*
*
* \return The wrapped packed function.
* \return The wrapped packed function.
...
@@ -50,6 +51,17 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
...
@@ -50,6 +51,17 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
template
<
typename
F
>
template
<
typename
F
>
inline
PackedFunc
PackFuncNonBufferArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
);
inline
PackedFunc
PackFuncNonBufferArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
);
/*!
/*!
* \brief Create a packed function that from function that takes a packed arguments.
*
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
* \param arg_types The arguments that wish to get from
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template
<
typename
F
>
inline
PackedFunc
PackFuncPackedArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
);
/*!
* \brief Extract number of buffer argument from the argument types.
* \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types.
* \param arg_types The argument types.
* \return number of buffer arguments
* \return number of buffer arguments
...
@@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_(
...
@@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_(
};
};
return
PackedFunc
(
ret
);
return
PackedFunc
(
ret
);
}
}
template
<
int
N
,
typename
F
>
inline
PackedFunc
PackFuncPackedArg_
(
F
f
,
const
std
::
vector
<
ArgConvertCode
>&
codes
)
{
int
num_args
=
static_cast
<
int
>
(
codes
.
size
());
auto
ret
=
[
f
,
codes
,
num_args
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TempArray
<
uint64_t
,
N
>
pack_
(
num_args
);
int32_t
*
pack
=
reinterpret_cast
<
int32_t
*>
(
pack_
.
data
());
int32_t
*
ptr
=
pack
;
static_assert
(
sizeof
(
TVMValue
)
==
8
,
"invariant"
);
static_assert
(
sizeof
(
void
*
)
%
sizeof
(
int32_t
)
==
0
,
"invariant"
);
for
(
int
i
=
0
;
i
<
num_args
;
++
i
)
{
switch
(
codes
[
i
])
{
case
HANDLE_TO_HANDLE
:
{
std
::
memcpy
(
ptr
,
&
(
args
.
values
[
i
].
v_handle
),
sizeof
(
void
*
));
ptr
+=
sizeof
(
void
*
)
/
sizeof
(
int32_t
);
break
;
}
case
INT64_TO_INT64
:
case
FLOAT64_TO_FLOAT64
:
{
std
::
memcpy
(
ptr
,
&
args
.
values
[
i
],
sizeof
(
TVMValue
));
ptr
+=
2
;
break
;
}
case
INT64_TO_INT32
:
{
*
ptr
=
static_cast
<
int32_t
>
(
args
.
values
[
i
].
v_int64
);
++
ptr
;
break
;
}
case
INT64_TO_UINT32
:
{
*
reinterpret_cast
<
uint32_t
*>
(
ptr
)
=
static_cast
<
uint32_t
>
(
args
.
values
[
i
].
v_int64
);
++
ptr
;
break
;
}
case
FLOAT64_TO_FLOAT32
:
{
*
reinterpret_cast
<
float
*>
(
ptr
)
=
static_cast
<
float
>
(
args
.
values
[
i
].
v_float64
);
++
ptr
;
break
;
}
default
:
{
LOG
(
FATAL
)
<<
"not reached"
;
break
;
}
}
}
f
(
args
,
ret
,
pack
,
(
ptr
-
pack
)
*
sizeof
(
int32_t
));
};
return
PackedFunc
(
ret
);
}
}
// namespace detail
}
// namespace detail
template
<
typename
F
>
template
<
typename
F
>
...
@@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
...
@@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
return
detail
::
PackFuncNonBufferArg_
<
0
>
(
f
,
base
,
codes
);
return
detail
::
PackFuncNonBufferArg_
<
0
>
(
f
,
base
,
codes
);
}
}
}
}
template
<
typename
F
>
inline
PackedFunc
PackFuncPackedArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
)
{
std
::
vector
<
detail
::
ArgConvertCode
>
codes
;
for
(
size_t
i
=
0
;
i
<
arg_types
.
size
();
++
i
)
{
codes
.
push_back
(
detail
::
GetArgConvertCode
(
arg_types
[
i
]));
}
size_t
nargs
=
codes
.
size
();
// specialization
if
(
nargs
<=
4
)
{
return
detail
::
PackFuncPackedArg_
<
4
>
(
f
,
codes
);
}
else
{
return
detail
::
PackFuncPackedArg_
<
0
>
(
f
,
codes
);
}
}
}
// namespace runtime
}
// namespace runtime
}
// namespace tvm
}
// namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_
#endif // TVM_RUNTIME_PACK_ARGS_H_
src/runtime/rocm/rocm_module.cc
View file @
5061a6da
...
@@ -133,7 +133,8 @@ class ROCMWrappedFunc {
...
@@ -133,7 +133,8 @@ class ROCMWrappedFunc {
// invoke the function with void arguments
// invoke the function with void arguments
void
operator
()(
TVMArgs
args
,
void
operator
()(
TVMArgs
args
,
TVMRetValue
*
rv
,
TVMRetValue
*
rv
,
void
**
void_args
)
const
{
void
*
packed_args
,
size_t
packed_nbytes
)
const
{
int
device_id
;
int
device_id
;
ROCM_CALL
(
hipGetDevice
(
&
device_id
));
ROCM_CALL
(
hipGetDevice
(
&
device_id
));
if
(
fcache_
[
device_id
]
==
nullptr
)
{
if
(
fcache_
[
device_id
]
==
nullptr
)
{
...
@@ -141,6 +142,11 @@ class ROCMWrappedFunc {
...
@@ -141,6 +142,11 @@ class ROCMWrappedFunc {
}
}
hipStream_t
strm
=
static_cast
<
hipStream_t
>
(
ROCMThreadEntry
::
ThreadLocal
()
->
stream
);
hipStream_t
strm
=
static_cast
<
hipStream_t
>
(
ROCMThreadEntry
::
ThreadLocal
()
->
stream
);
ThreadWorkLoad
wl
=
thread_axis_cfg_
.
Extract
(
args
);
ThreadWorkLoad
wl
=
thread_axis_cfg_
.
Extract
(
args
);
void
*
config
[]
=
{
HIP_LAUNCH_PARAM_BUFFER_POINTER
,
&
packed_args
,
HIP_LAUNCH_PARAM_BUFFER_SIZE
,
&
packed_nbytes
,
HIP_LAUNCH_PARAM_END
};
// HIP supports only extra_args.
// HIP supports only extra_args.
ROCM_DRIVER_CALL
(
hipModuleLaunchKernel
(
ROCM_DRIVER_CALL
(
hipModuleLaunchKernel
(
fcache_
[
device_id
],
fcache_
[
device_id
],
...
@@ -150,7 +156,8 @@ class ROCMWrappedFunc {
...
@@ -150,7 +156,8 @@ class ROCMWrappedFunc {
wl
.
block_dim
(
0
),
wl
.
block_dim
(
0
),
wl
.
block_dim
(
1
),
wl
.
block_dim
(
1
),
wl
.
block_dim
(
2
),
wl
.
block_dim
(
2
),
0
,
strm
,
void_args
,
0
));
0
,
strm
,
nullptr
,
reinterpret_cast
<
void
**>
(
&
config
)));
}
}
private
:
private
:
...
@@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction(
...
@@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction(
const
FunctionInfo
&
info
=
it
->
second
;
const
FunctionInfo
&
info
=
it
->
second
;
ROCMWrappedFunc
f
;
ROCMWrappedFunc
f
;
f
.
Init
(
this
,
sptr_to_self
,
name
,
info
.
arg_types
.
size
(),
info
.
thread_axis_tags
);
f
.
Init
(
this
,
sptr_to_self
,
name
,
info
.
arg_types
.
size
(),
info
.
thread_axis_tags
);
return
PackFunc
VoidAddr
(
f
,
info
.
arg_types
);
return
PackFunc
PackedArg
(
f
,
info
.
arg_types
);
}
}
Module
ROCMModuleCreate
(
Module
ROCMModuleCreate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment