Unverified Commit 16d3c1f6 by Tianqi Chen Committed by GitHub

[RUNTIME] Add TypedPackedFunc (#1626)

parent 21e13010
......@@ -118,6 +118,163 @@ class PackedFunc {
FType body_;
};
/*!
* \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>"
*/
template<typename FType>
class TypedPackedFunc;
/*!
* \anchor TypedPackedFuncAnchor
* \brief A PackedFunc wrapper to provide typed function signature.
* It is backed by a PackedFunc internally.
*
* TypedPackedFunc enables compile time type checking.
* TypedPackedFunc works with the runtime system:
* - It can be passed as an argument of PackedFunc.
* - It can be assigned to TVMRetValue.
* - It can be directly converted to a type-erased PackedFunc.
*
* Developers should prefer TypedPackedFunc over PackedFunc in C++ code
* as it enables compile time checking.
* We can construct a TypedPackedFunc from a lambda function
* with the same signature.
*
* \code
* // user defined lambda function.
* auto addone = [](int x)->int {
* return x + 1;
* };
* // We can directly convert
* // lambda function to TypedPackedFunc
* TypedPackedFunc<int(int)> ftyped(addone);
* // invoke the function.
* int y = ftyped(1);
* // Can be directly converted to PackedFunc
* PackedFunc packed = ftype;
* \endcode
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
*/
template<typename R, typename ...Args>
class TypedPackedFunc<R(Args...)> {
public:
/*! \brief short hand for this function type */
using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */
TypedPackedFunc() {}
/*!
* \brief construct by wrap a PackedFunc
*
* Example usage:
* \code
* PackedFunc packed([](TVMArgs args, TVMRetValue *rv) {
* int x = args[0];
* *rv = x + 1;
* });
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(packed);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param packed The packed function
*/
explicit TypedPackedFunc(PackedFunc packed)
: packed_(packed) {
}
/*!
* \brief construct from a lambda function with the same signature.
*
* Example usage:
* \code
* auto typed_lambda = [](int x)->int { return x + 1; }
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(typed_lambda);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
explicit TypedPackedFunc(const FLambda& typed_lambda) {
this->AssignTypedLambda(typed_lambda);
}
/*!
* \brief copy assignment operator from typed lambda
*
* Example usage:
* \code
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped;
* ftyped = [](int x) { return x + 1; }
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
* \returns reference to self.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
return *this;
}
/*!
* \brief copy assignment operator from PackedFunc.
* \param packed The packed function.
* \returns reference to self.
*/
TSelf& operator=(PackedFunc packed) {
packed_ = packed;
return *this;
}
/*!
* \brief Invoke the operator.
* \param args The arguments
* \returns The return value.
*/
inline R operator()(Args ...args) const;
/*!
* \brief convert to PackedFunc
* \return the internal PackedFunc
*/
operator PackedFunc() const {
return packed();
}
/*!
* \return reference the internal PackedFunc
*/
const PackedFunc& packed() const {
return packed_;
}
private:
friend class TVMRetValue;
/*! \brief The internal packed function */
PackedFunc packed_;
/*!
* \brief Assign the packed field using a typed lambda function.
*
* \param flambda The lambda function.
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
template<typename FLambda>
inline void AssignTypedLambda(FLambda flambda);
};
/*! \brief Arguments into TVM functions. */
class TVMArgs {
public:
......@@ -361,6 +518,10 @@ class TVMArgValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
......@@ -446,6 +607,10 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
......@@ -512,6 +677,10 @@ class TVMRetValue : public TVMPODValue_ {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
template<typename FType>
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
}
TVMRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m);
return *this;
......@@ -847,6 +1016,10 @@ class TVMArgsSetter {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
operator()(i, value.packed());
}
void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<Module*>(&value);
type_codes_[i] = kModuleHandle;
......@@ -894,6 +1067,84 @@ inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
return rv;
}
namespace detail {
template<typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
unpack_call_dispatcher<R, nleft - 1, index + 1, F>
::run(f, args_pack, rv,
std::forward<Args>(unpacked_args)...,
args_pack[index]);
}
};
template<typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
*rv = R(f(std::forward<Args>(unpacked_args)...));
}
};
template<int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};
template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}
template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) {
return R(pf(std::forward<Args>(args)...));
}
template<typename R>
struct typed_packed_call_dispatcher {
template<typename ...Args>
static inline R run(const PackedFunc& pf, Args&& ...args) {
return pf(std::forward<Args>(args)...);
}
};
template<>
struct typed_packed_call_dispatcher<void> {
template<typename ...Args>
static inline void run(const PackedFunc& pf, Args&& ...args) {
pf(std::forward<Args>(args)...);
}
};
} // namespace detail
template<typename R, typename ...Args>
template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
}
template<typename R, typename ...Args>
inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
return detail::typed_packed_call_dispatcher<R>
::run(packed_, std::forward<Args>(args)...);
}
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
......
......@@ -135,6 +135,29 @@ TEST(PackedFunc, Type) {
CHECK(get_type2("float32x2").operator Type() == Float(32, 2));
}
TEST(TypedPackedFunc, HighOrder) {
using namespace tvm;
using namespace tvm::runtime;
using Int1Func = TypedPackedFunc<int(int)>;
using Int2Func = TypedPackedFunc<int(int, int)>;
using BindFunc = TypedPackedFunc<Int1Func(Int2Func, int value)>;
BindFunc ftyped;
ftyped = [](Int2Func f1, int value) -> Int1Func {
auto binded = [f1, value](int x) {
return f1(value, x);
};
Int1Func x(binded);
return x;
};
auto add = [](int x, int y) { return x + y; };
CHECK_EQ(ftyped(Int2Func(add), 1)(2), 3);
PackedFunc f = ftyped(Int2Func(add), 1);
CHECK_EQ(f(3).operator int(), 4);
// call the type erased version.
Int1Func f1 = ftyped.packed()(Int2Func(add), 1);
CHECK_EQ(f1(3), 4);
}
// new namespoace
namespace test {
// register int vector as extension type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment