Commit 5061a6da by Tianqi Chen Committed by GitHub

[RUNTIME] Add function to pack arguments (#452)

parent 769544ad
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <vector> #include <vector>
#include <cstring>
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -31,7 +32,7 @@ union ArgUnion { ...@@ -31,7 +32,7 @@ union ArgUnion {
* \brief Create a packed function from void addr types. * \brief Create a packed function from void addr types.
* *
* \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args) * \param f with signiture (TVMArgs args, TVMRetValue* rv, void* void_args)
* \param arg_types The arguments that wish to get from * \param arg_types The arguments type information.
* \tparam F the function type * \tparam F the function type
* *
* \return The wrapped packed function. * \return The wrapped packed function.
...@@ -42,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types); ...@@ -42,7 +43,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
* \brief Create a packed function that from function only packs buffer arguments. * \brief Create a packed function that from function only packs buffer arguments.
* *
* \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args) * \param f with signiture (TVMArgs args, TVMRetValue* rv, ArgUnion* pack_args)
* \param arg_types The arguments that wish to get from * \param arg_types The arguments type information.
* \tparam F the function type * \tparam F the function type
* *
* \return The wrapped packed function. * \return The wrapped packed function.
...@@ -50,6 +51,17 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types); ...@@ -50,6 +51,17 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
template<typename F> template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types); inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types);
/*! /*!
* \brief Create a packed function that from function that takes a packed arguments.
*
* \param f with signature (TVMArgs args, TVMRetValue* rv, void* pack_args, size_t nbytes)
* \param arg_types The arguments that wish to get from
* \tparam F the function type
*
* \return The wrapped packed function.
*/
template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types);
/*!
* \brief Extract number of buffer argument from the argument types. * \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types. * \param arg_types The argument types.
* \return number of buffer arguments * \return number of buffer arguments
...@@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_( ...@@ -179,6 +191,56 @@ inline PackedFunc PackFuncNonBufferArg_(
}; };
return PackedFunc(ret); return PackedFunc(ret);
} }
template<int N, typename F>
inline PackedFunc PackFuncPackedArg_(
F f, const std::vector<ArgConvertCode>& codes) {
int num_args = static_cast<int>(codes.size());
auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
TempArray<uint64_t, N> pack_(num_args);
int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
int32_t* ptr = pack;
static_assert(sizeof(TVMValue) == 8, "invariant");
static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
for (int i = 0; i < num_args; ++i) {
switch (codes[i]) {
case HANDLE_TO_HANDLE: {
std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
ptr += sizeof(void*) / sizeof(int32_t);
break;
}
case INT64_TO_INT64:
case FLOAT64_TO_FLOAT64: {
std::memcpy(ptr, &args.values[i], sizeof(TVMValue));
ptr += 2;
break;
}
case INT64_TO_INT32: {
*ptr = static_cast<int32_t>(args.values[i].v_int64);
++ptr;
break;
}
case INT64_TO_UINT32 : {
*reinterpret_cast<uint32_t*>(ptr) =
static_cast<uint32_t>(args.values[i].v_int64);
++ptr;
break;
}
case FLOAT64_TO_FLOAT32: {
*reinterpret_cast<float*>(ptr) =
static_cast<float>(args.values[i].v_float64);
++ptr;
break;
}
default: {
LOG(FATAL) << "not reached"; break;
}
}
}
f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
};
return PackedFunc(ret);
}
} // namespace detail } // namespace detail
template<typename F> template<typename F>
...@@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type ...@@ -228,6 +290,21 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
return detail::PackFuncNonBufferArg_<0>(f, base, codes); return detail::PackFuncNonBufferArg_<0>(f, base, codes);
} }
} }
template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) {
std::vector<detail::ArgConvertCode> codes;
for (size_t i = 0; i < arg_types.size(); ++i) {
codes.push_back(detail::GetArgConvertCode(arg_types[i]));
}
size_t nargs = codes.size();
// specialization
if (nargs <= 4) {
return detail::PackFuncPackedArg_<4>(f, codes);
} else {
return detail::PackFuncPackedArg_<0>(f, codes);
}
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_ #endif // TVM_RUNTIME_PACK_ARGS_H_
...@@ -133,7 +133,8 @@ class ROCMWrappedFunc { ...@@ -133,7 +133,8 @@ class ROCMWrappedFunc {
// invoke the function with void arguments // invoke the function with void arguments
void operator()(TVMArgs args, void operator()(TVMArgs args,
TVMRetValue* rv, TVMRetValue* rv,
void** void_args) const { void* packed_args,
size_t packed_nbytes) const {
int device_id; int device_id;
ROCM_CALL(hipGetDevice(&device_id)); ROCM_CALL(hipGetDevice(&device_id));
if (fcache_[device_id] == nullptr) { if (fcache_[device_id] == nullptr) {
...@@ -141,6 +142,11 @@ class ROCMWrappedFunc { ...@@ -141,6 +142,11 @@ class ROCMWrappedFunc {
} }
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream); hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
void* config[] = {
HIP_LAUNCH_PARAM_BUFFER_POINTER, &packed_args,
HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes,
HIP_LAUNCH_PARAM_END
};
// HIP supports only extra_args. // HIP supports only extra_args.
ROCM_DRIVER_CALL(hipModuleLaunchKernel( ROCM_DRIVER_CALL(hipModuleLaunchKernel(
fcache_[device_id], fcache_[device_id],
...@@ -150,7 +156,8 @@ class ROCMWrappedFunc { ...@@ -150,7 +156,8 @@ class ROCMWrappedFunc {
wl.block_dim(0), wl.block_dim(0),
wl.block_dim(1), wl.block_dim(1),
wl.block_dim(2), wl.block_dim(2),
0, strm, void_args, 0)); 0, strm, nullptr,
reinterpret_cast<void**>(&config)));
} }
private: private:
...@@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction( ...@@ -180,7 +187,7 @@ PackedFunc ROCMModuleNode::GetFunction(
const FunctionInfo& info = it->second; const FunctionInfo& info = it->second;
ROCMWrappedFunc f; ROCMWrappedFunc f;
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
return PackFuncVoidAddr(f, info.arg_types); return PackFuncPackedArg(f, info.arg_types);
} }
Module ROCMModuleCreate( Module ROCMModuleCreate(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment