Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5061a6da
Commit
5061a6da
authored
Sep 12, 2017
by
Tianqi Chen
Committed by
GitHub
Sep 12, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RUNTIME] Add function to pack arguments (#452)
parent
769544ad
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
89 additions
and
5 deletions
+89
-5
src/runtime/pack_args.h
+79
-2
src/runtime/rocm/rocm_module.cc
+10
-3
No files found.
src/runtime/pack_args.h
View file @
5061a6da
...
...
@@ -15,6 +15,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <vector>
#include <cstring>
namespace
tvm
{
namespace
runtime
{
...
...
@@ -31,7 +32,7 @@ union ArgUnion {
* \brief Create a packed function from void addr types.
*
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
* \param arg_types The arguments t
hat wish to get from
* \param arg_types The arguments t
ype information.
* \tparam F the function type
*
* \return The wrapped packed function.
...
...
@@ -42,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
* \brief Create a packed function that from function only packs buffer arguments.
*
* \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
* \param arg_types The arguments t
hat wish to get from
* \param arg_types The arguments t
ype information.
* \tparam F the function type
*
* \return The wrapped packed function.
...
...
@@ -50,6 +51,17 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
template
<
typename
F
>
inline
PackedFunc
PackFuncNonBufferArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
);
/*!
* \brief Create a packed function that from function that takes a packed arguments.
*
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
* \param arg_types The arguments that wish to get from
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template
<
typename
F
>
inline
PackedFunc
PackFuncPackedArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
);
/*!
* \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types.
* \return number of buffer arguments
...
...
@@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_(
};
return
PackedFunc
(
ret
);
}
template
<
int
N
,
typename
F
>
inline
PackedFunc
PackFuncPackedArg_
(
F
f
,
const
std
::
vector
<
ArgConvertCode
>&
codes
)
{
int
num_args
=
static_cast
<
int
>
(
codes
.
size
());
auto
ret
=
[
f
,
codes
,
num_args
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TempArray
<
uint64_t
,
N
>
pack_
(
num_args
);
int32_t
*
pack
=
reinterpret_cast
<
int32_t
*>
(
pack_
.
data
());
int32_t
*
ptr
=
pack
;
static_assert
(
sizeof
(
TVMValue
)
==
8
,
"invariant"
);
static_assert
(
sizeof
(
void
*
)
%
sizeof
(
int32_t
)
==
0
,
"invariant"
);
for
(
int
i
=
0
;
i
<
num_args
;
++
i
)
{
switch
(
codes
[
i
])
{
case
HANDLE_TO_HANDLE
:
{
std
::
memcpy
(
ptr
,
&
(
args
.
values
[
i
].
v_handle
),
sizeof
(
void
*
));
ptr
+=
sizeof
(
void
*
)
/
sizeof
(
int32_t
);
break
;
}
case
INT64_TO_INT64
:
case
FLOAT64_TO_FLOAT64
:
{
std
::
memcpy
(
ptr
,
&
args
.
values
[
i
],
sizeof
(
TVMValue
));
ptr
+=
2
;
break
;
}
case
INT64_TO_INT32
:
{
*
ptr
=
static_cast
<
int32_t
>
(
args
.
values
[
i
].
v_int64
);
++
ptr
;
break
;
}
case
INT64_TO_UINT32
:
{
*
reinterpret_cast
<
uint32_t
*>
(
ptr
)
=
static_cast
<
uint32_t
>
(
args
.
values
[
i
].
v_int64
);
++
ptr
;
break
;
}
case
FLOAT64_TO_FLOAT32
:
{
*
reinterpret_cast
<
float
*>
(
ptr
)
=
static_cast
<
float
>
(
args
.
values
[
i
].
v_float64
);
++
ptr
;
break
;
}
default
:
{
LOG
(
FATAL
)
<<
"not reached"
;
break
;
}
}
}
f
(
args
,
ret
,
pack
,
(
ptr
-
pack
)
*
sizeof
(
int32_t
));
};
return
PackedFunc
(
ret
);
}
}
// namespace detail
template
<
typename
F
>
...
...
@@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
return
detail
::
PackFuncNonBufferArg_
<
0
>
(
f
,
base
,
codes
);
}
}
template
<
typename
F
>
inline
PackedFunc
PackFuncPackedArg
(
F
f
,
const
std
::
vector
<
TVMType
>&
arg_types
)
{
std
::
vector
<
detail
::
ArgConvertCode
>
codes
;
for
(
size_t
i
=
0
;
i
<
arg_types
.
size
();
++
i
)
{
codes
.
push_back
(
detail
::
GetArgConvertCode
(
arg_types
[
i
]));
}
size_t
nargs
=
codes
.
size
();
// specialization
if
(
nargs
<=
4
)
{
return
detail
::
PackFuncPackedArg_
<
4
>
(
f
,
codes
);
}
else
{
return
detail
::
PackFuncPackedArg_
<
0
>
(
f
,
codes
);
}
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_
src/runtime/rocm/rocm_module.cc
View file @
5061a6da
...
...
@@ -133,7 +133,8 @@ class ROCMWrappedFunc {
// invoke the function with void arguments
void
operator
()(
TVMArgs
args
,
TVMRetValue
*
rv
,
void
**
void_args
)
const
{
void
*
packed_args
,
size_t
packed_nbytes
)
const
{
int
device_id
;
ROCM_CALL
(
hipGetDevice
(
&
device_id
));
if
(
fcache_
[
device_id
]
==
nullptr
)
{
...
...
@@ -141,6 +142,11 @@ class ROCMWrappedFunc {
}
hipStream_t
strm
=
static_cast
<
hipStream_t
>
(
ROCMThreadEntry
::
ThreadLocal
()
->
stream
);
ThreadWorkLoad
wl
=
thread_axis_cfg_
.
Extract
(
args
);
void
*
config
[]
=
{
HIP_LAUNCH_PARAM_BUFFER_POINTER
,
&
packed_args
,
HIP_LAUNCH_PARAM_BUFFER_SIZE
,
&
packed_nbytes
,
HIP_LAUNCH_PARAM_END
};
// HIP supports only extra_args.
ROCM_DRIVER_CALL
(
hipModuleLaunchKernel
(
fcache_
[
device_id
],
...
...
@@ -150,7 +156,8 @@ class ROCMWrappedFunc {
wl
.
block_dim
(
0
),
wl
.
block_dim
(
1
),
wl
.
block_dim
(
2
),
0
,
strm
,
void_args
,
0
));
0
,
strm
,
nullptr
,
reinterpret_cast
<
void
**>
(
&
config
)));
}
private
:
...
...
@@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction(
const
FunctionInfo
&
info
=
it
->
second
;
ROCMWrappedFunc
f
;
f
.
Init
(
this
,
sptr_to_self
,
name
,
info
.
arg_types
.
size
(),
info
.
thread_axis_tags
);
return
PackFunc
VoidAddr
(
f
,
info
.
arg_types
);
return
PackFunc
PackedArg
(
f
,
info
.
arg_types
);
}
Module
ROCMModuleCreate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment