Unverified Commit b72dd9d9 by Tianqi Chen Committed by GitHub

[RUNTIME] Introduce RValue reference(move) support to TypedPackedFunc (#5271)

* [RUNTIME] Introduce RValue reference(move) support to TypedPackedFunc

This PR introduces RValue reference support the PackedFunc calling convention to address the above issue.
Specifically, when an argument is a r-value reference, we will use a assign a different type code(`kObjectRValueRefArg`),
and pass `Object**`  (the address to the Object pointer) instead through the values array.
The callee can choose to move out this Object pointer and set the original Object pointer from the caller side to be nullptr.

We also add an experimental move support to the python side(marked as _move so to indicate the dev nature).
This enhancement will enable copy on write optimizations through out the TVM stack.

* Address review comments

* fix compilation
parent 575d5369
...@@ -123,7 +123,7 @@ class PrimExpr : public BaseExpr { ...@@ -123,7 +123,7 @@ class PrimExpr : public BaseExpr {
private: private:
// Internal function for conversion. // Internal function for conversion.
friend class runtime::TVMPODValue_; friend struct runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr); TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
}; };
...@@ -451,22 +451,24 @@ inline const TTypeNode* RelayExprNode::type_as() const { ...@@ -451,22 +451,24 @@ inline const TTypeNode* RelayExprNode::type_as() const {
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
// Additional implementattion overloads for PackedFunc. template<>
inline TVMPODValue_::operator tvm::PrimExpr() const { struct PackedFuncValueConverter<PrimExpr> {
if (type_code_ == kTVMNullptr) return PrimExpr(); // common rule for both RetValue and ArgValue.
if (type_code_ == kDLInt) { static PrimExpr From(const TVMPODValue_& val) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); if (val.type_code() == kTVMNullptr) {
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); return PrimExpr(ObjectPtr<Object>(nullptr));
return PrimExpr(static_cast<int>(value_.v_int64)); }
if (val.type_code() == kDLInt) {
return PrimExpr(val.operator int());
}
if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double()));
}
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>();
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));
} }
if (type_code_ == kDLFloat) { };
return PrimExpr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
return PrimExpr::FromObject_(ObjectPtr<Object>(ptr));
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_IR_EXPR_H_ #endif // TVM_IR_EXPR_H_
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
......
...@@ -104,6 +104,7 @@ typedef enum { ...@@ -104,6 +104,7 @@ typedef enum {
kTVMStr = 11U, kTVMStr = 11U,
kTVMBytes = 12U, kTVMBytes = 12U,
kTVMNDArrayHandle = 13U, kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
// Extension codes for other frameworks to integrate TVM PackedFunc. // Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and // To make sure each framework's id do not conflict, use first and
// last sections to mark ranges. // last sections to mark ranges.
...@@ -290,7 +291,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, ...@@ -290,7 +291,7 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
* *
* \return 0 when success, -1 when failure happens. * \return 0 when success, -1 when failure happens.
*/ */
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code); TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code);
/*! /*!
* \brief C type of packed function. * \brief C type of packed function.
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <tvm/runtime/memory.h> #include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <cstring> #include <cstring>
#include <initializer_list> #include <initializer_list>
...@@ -590,6 +591,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, ...@@ -590,6 +591,25 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count,
} }
} }
template<>
struct PackedFuncValueConverter<::tvm::runtime::String> {
static String From(const TVMArgValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}
static String From(const TVMRetValue& val) {
if (val.IsObjectRef<tvm::runtime::String>()) {
return val.AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(val.operator std::string());
}
}
};
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -477,6 +477,17 @@ class ObjectPtr { ...@@ -477,6 +477,17 @@ class ObjectPtr {
data_->IncRef(); data_->IncRef();
} }
} }
/*!
* \brief Move an ObjectPtr from an RValueRef argument.
* \param ref The rvalue reference.
* \return the moved result.
*/
static ObjectPtr<T> MoveFromRValueRefArg(Object** ref) {
ObjectPtr<T> ptr;
ptr.data_ = *ref;
*ref = nullptr;
return ptr;
}
// friend classes // friend classes
friend class Object; friend class Object;
friend class ObjectRef; friend class ObjectRef;
...@@ -489,6 +500,7 @@ class ObjectPtr { ...@@ -489,6 +500,7 @@ class ObjectPtr {
friend class TVMArgsSetter; friend class TVMArgsSetter;
friend class TVMRetValue; friend class TVMRetValue;
friend class TVMArgValue; friend class TVMArgValue;
friend class TVMMovableArgValue_;
template <typename RelayRefType, typename ObjType> template <typename RelayRefType, typename ObjType>
friend RelayRefType GetRef(const ObjType* ptr); friend RelayRefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType> template <typename BaseType, typename ObjType>
...@@ -550,6 +562,10 @@ class ObjectRef { ...@@ -550,6 +562,10 @@ class ObjectRef {
bool unique() const { bool unique() const {
return data_.unique(); return data_.unique();
} }
/*! \return The use count of the ptr, for debug purposes */
int use_count() const {
return data_.use_count();
}
/*! /*!
* \brief Try to downcast the internal Object to a * \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type. * raw pointer of a corresponding type.
......
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h> #include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <functional> #include <functional>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -47,15 +46,12 @@ ...@@ -47,15 +46,12 @@
#endif #endif
namespace tvm { namespace tvm {
// forward declarations
class Integer;
class PrimExpr;
namespace runtime { namespace runtime {
// forward declarations // forward declarations
class TVMArgs; class TVMArgs;
class TVMArgValue; class TVMArgValue;
class TVMMovableArgValue_;
class TVMRetValue; class TVMRetValue;
class TVMArgsSetter; class TVMArgsSetter;
...@@ -211,6 +207,11 @@ class TypedPackedFunc<R(Args...)> { ...@@ -211,6 +207,11 @@ class TypedPackedFunc<R(Args...)> {
*/ */
inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*) inline TypedPackedFunc(const TVMArgValue& value); // NOLINT(*)
/*! /*!
* \brief constructor from TVMMovableArgValue_
* \param value The TVMMovableArgValue_
*/
inline TypedPackedFunc(TVMMovableArgValue_&& value); // NOLINT(*)
/*!
* \brief construct from a lambda function with the same signature. * \brief construct from a lambda function with the same signature.
* *
* Example usage: * Example usage:
...@@ -386,8 +387,8 @@ class TVMPODValue_ { ...@@ -386,8 +387,8 @@ class TVMPODValue_ {
} }
operator int() const { operator int() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt); TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
CHECK_LE(value_.v_int64, CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
std::numeric_limits<int>::max()); CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return static_cast<int>(value_.v_int64); return static_cast<int>(value_.v_int64);
} }
operator bool() const { operator bool() const {
...@@ -449,9 +450,6 @@ class TVMPODValue_ { ...@@ -449,9 +450,6 @@ class TVMPODValue_ {
inline bool IsObjectRef() const; inline bool IsObjectRef() const;
template<typename TObjectRef> template<typename TObjectRef>
inline TObjectRef AsObjectRef() const; inline TObjectRef AsObjectRef() const;
// ObjectRef Specializations
inline operator tvm::PrimExpr() const;
inline operator tvm::Integer() const;
protected: protected:
friend class TVMArgsSetter; friend class TVMArgsSetter;
...@@ -497,8 +495,6 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -497,8 +495,6 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator Module; using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef; using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef; using TVMPODValue_::AsObjectRef;
using TVMPODValue_::operator tvm::PrimExpr;
using TVMPODValue_::operator tvm::Integer;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
...@@ -512,13 +508,6 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -512,13 +508,6 @@ class TVMArgValue : public TVMPODValue_ {
return std::string(value_.v_str); return std::string(value_.v_str);
} }
} }
operator tvm::runtime::String() const {
if (IsObjectRef<tvm::runtime::String>()) {
return AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(operator std::string());
}
}
operator DLDataType() const { operator DLDataType() const {
if (type_code_ == kTVMStr) { if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string()); return String2DLDataType(operator std::string());
...@@ -547,6 +536,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -547,6 +536,7 @@ class TVMArgValue : public TVMPODValue_ {
const TVMValue& value() const { const TVMValue& value() const {
return value_; return value_;
} }
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
...@@ -554,6 +544,45 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -554,6 +544,45 @@ class TVMArgValue : public TVMPODValue_ {
}; };
/*! /*!
* \brief Internal auxiliary struct for TypedPackedFunc to indicate a movable argument.
*
* We can only construct a movable argument once from a single argument position.
* If the argument is passed as RValue reference, the result will be moved.
* We should only construct a MovableArg from an argument once,
* as the result will can moved.
*
* \note For internal development purpose only.
*/
class TVMMovableArgValue_ : public TVMArgValue {
public:
TVMMovableArgValue_(TVMValue value, int type_code)
: TVMArgValue(value, type_code) {
}
// reuse converter from parent
using TVMArgValue::operator double;
using TVMArgValue::operator int64_t;
using TVMArgValue::operator uint64_t;
using TVMArgValue::operator int;
using TVMArgValue::operator bool;
using TVMArgValue::operator void*;
using TVMArgValue::operator DLTensor*;
using TVMArgValue::operator TVMContext;
using TVMArgValue::operator std::string;
using TVMArgValue::operator DLDataType;
using TVMArgValue::operator DataType;
using TVMArgValue::operator PackedFunc;
/*!
* \brief Helper converter function.
* Try to move out an argument if possible,
* fall back to normal argument conversion rule otherwise.
*/
template<typename T,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, T>::value>::type>
inline operator T() const;
};
/*!
* \brief Return Value container, * \brief Return Value container,
* Unlike TVMArgValue, which only holds reference and do not delete * Unlike TVMArgValue, which only holds reference and do not delete
* the underlying container during destruction. * the underlying container during destruction.
...@@ -591,8 +620,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -591,8 +620,6 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator Module; using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef; using TVMPODValue_::IsObjectRef;
using TVMPODValue_::AsObjectRef; using TVMPODValue_::AsObjectRef;
using TVMPODValue_::operator tvm::PrimExpr;
using TVMPODValue_::operator tvm::Integer;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other); this->Assign(other);
...@@ -607,13 +634,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -607,13 +634,6 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMStr); TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
return *ptr<std::string>(); return *ptr<std::string>();
} }
operator tvm::runtime::String() const {
if (IsObjectRef<tvm::runtime::String>()) {
return AsObjectRef<tvm::runtime::String>();
} else {
return tvm::runtime::String(operator std::string());
}
}
operator DLDataType() const { operator DLDataType() const {
if (type_code_ == kTVMStr) { if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string()); return String2DLDataType(operator std::string());
...@@ -723,6 +743,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -723,6 +743,10 @@ class TVMRetValue : public TVMPODValue_ {
this->Assign(other); this->Assign(other);
return *this; return *this;
} }
TVMRetValue& operator=(TVMMovableArgValue_&& other) {
this->Assign(other);
return *this;
}
/*! /*!
* \brief Move the value back to front-end via C API. * \brief Move the value back to front-end via C API.
* This marks the current container as null. * This marks the current container as null.
...@@ -806,6 +830,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -806,6 +830,10 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<Object*>(other.value_.v_handle))); static_cast<Object*>(other.value_.v_handle)));
break; break;
} }
case kTVMObjectRValueRefArg: {
operator=(other.operator ObjectRef());
break;
}
default: { default: {
SwitchToPOD(other.type_code()); SwitchToPOD(other.type_code());
value_ = other.value_; value_ = other.value_;
...@@ -864,6 +892,35 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -864,6 +892,35 @@ class TVMRetValue : public TVMPODValue_ {
}; };
/*! /*!
* \brief Type trait to specify special value conversion rules from
* TVMArgValue and TVMRetValue.
*
* The trait can be specialized to add type specific conversion logic
* from the TVMArgvalue and TVMRetValue.
*
* \tparam TObjectRef the specific ObjectRefType.
*/
template<typename TObjectRef>
struct PackedFuncValueConverter {
/*!
* \brief Convert a TObjectRef from an argument value.
* \param val The argument value.
* \return the converted result.
*/
static TObjectRef From(const TVMArgValue& val) {
return val.AsObjectRef<TObjectRef>();
}
/*!
* \brief Convert a TObjectRef from a return value.
* \param val The argument value.
* \return the converted result.
*/
static TObjectRef From(const TVMRetValue& val) {
return val.AsObjectRef<TObjectRef>();
}
};
/*!
* \brief Export a function with the PackedFunc signature * \brief Export a function with the PackedFunc signature
* as a PackedFunc that can be loaded by LibraryModule. * as a PackedFunc that can be loaded by LibraryModule.
* *
...@@ -1132,10 +1189,24 @@ class TVMArgsSetter { ...@@ -1132,10 +1189,24 @@ class TVMArgsSetter {
// ObjectRef handling // ObjectRef handling
template<typename TObjectRef, template<typename TObjectRef,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_base_of<ObjectRef, TObjectRef>::value>::type> std::is_base_of<ObjectRef, TObjectRef>::value>
inline void operator()(size_t i, const TObjectRef& value) const; ::type>
void operator()(size_t i, const TObjectRef& value) const {
this->SetObject(i, value);
}
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_base_of<ObjectRef,
typename std::remove_reference<TObjectRef>::type>::value>
::type>
void operator()(size_t i, TObjectRef&& value) const {
this->SetObject(i, std::forward<TObjectRef>(value));
}
private: private:
template<typename TObjectRef>
inline void SetObject(size_t i, TObjectRef&& value) const;
/*! \brief The values fields */ /*! \brief The values fields */
TVMValue* values_; TVMValue* values_;
/*! \brief The type code fields */ /*! \brief The type code fields */
...@@ -1163,10 +1234,13 @@ struct unpack_call_dispatcher { ...@@ -1163,10 +1234,13 @@ struct unpack_call_dispatcher {
const TVMArgs& args_pack, const TVMArgs& args_pack,
TVMRetValue* rv, TVMRetValue* rv,
Args&&... unpacked_args) { Args&&... unpacked_args) {
// construct a movable argument value
// which allows potential move of argument to the input of F.
unpack_call_dispatcher<R, nleft - 1, index + 1, F> unpack_call_dispatcher<R, nleft - 1, index + 1, F>
::run(f, args_pack, rv, ::run(f, args_pack, rv,
std::forward<Args>(unpacked_args)..., std::forward<Args>(unpacked_args)...,
args_pack[index]); TVMMovableArgValue_(args_pack.values[index],
args_pack.type_codes[index]));
} }
}; };
...@@ -1246,6 +1320,10 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value) ...@@ -1246,6 +1320,10 @@ TypedPackedFunc<R(Args...)>::TypedPackedFunc(const TVMArgValue& value)
: packed_(value.operator PackedFunc()) {} : packed_(value.operator PackedFunc()) {}
template<typename R, typename ...Args> template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(TVMMovableArgValue_&& value)
: packed_(value.operator PackedFunc()) {}
template<typename R, typename ...Args>
template<typename FType> template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) { inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
...@@ -1264,8 +1342,9 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { ...@@ -1264,8 +1342,9 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle
// //
// We use type traits to eliminate un-necessary checks. // We use type traits to eliminate un-necessary checks.
template<typename TObjectRef, typename> template<typename T>
inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
using TObjectRef = typename std::remove_reference<T>::type;
if (value.defined()) { if (value.defined()) {
Object* ptr = value.data_.data_; Object* ptr = value.data_.data_;
if (std::is_base_of<NDArray, TObjectRef>::value || if (std::is_base_of<NDArray, TObjectRef>::value ||
...@@ -1278,8 +1357,11 @@ inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const { ...@@ -1278,8 +1357,11 @@ inline void TVMArgsSetter::operator()(size_t i, const TObjectRef& value) const {
ptr->IsInstance<Module::ContainerType>())) { ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr; values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle; type_codes_[i] = kTVMModuleHandle;
} else if (std::is_rvalue_reference<T>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg;
} else { } else {
values_[i].v_handle = ptr; values_[i].v_handle = value.data_.data_;
type_codes_[i] = kTVMObjectHandle; type_codes_[i] = kTVMObjectHandle;
} }
} else { } else {
...@@ -1300,6 +1382,11 @@ inline bool TVMPODValue_::IsObjectRef() const { ...@@ -1300,6 +1382,11 @@ inline bool TVMPODValue_::IsObjectRef() const {
return type_code_ == kTVMModuleHandle && return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>(); static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
} }
// NOTE: we don't pass NDArray and runtime::Module as RValue ref.
if (type_code_ == kTVMObjectRValueRefArg) {
return ObjectTypeChecker<TObjectRef>::Check(
*static_cast<Object**>(value_.v_handle));
}
return return
(std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) || (std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) || (std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) ||
...@@ -1339,6 +1426,12 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ...@@ -1339,6 +1426,12 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr)); return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (type_code_ == kTVMObjectRValueRefArg) {
Object* ptr = *static_cast<Object**>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<TObjectRef, NDArray>::value && } else if (std::is_base_of<TObjectRef, NDArray>::value &&
type_code_ == kTVMNDArrayHandle) { type_code_ == kTVMNDArrayHandle) {
// Casting to a base class that NDArray can sub-class // Casting to a base class that NDArray can sub-class
...@@ -1376,14 +1469,27 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { ...@@ -1376,14 +1469,27 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
return *this; return *this;
} }
template<typename T, typename> template<typename T, typename>
inline TVMArgValue::operator T() const { inline TVMArgValue::operator T() const {
return AsObjectRef<T>(); return PackedFuncValueConverter<T>::From(*this);
}
template<typename T, typename>
inline TVMMovableArgValue_::operator T() const {
if (type_code_ == kTVMObjectRValueRefArg) {
auto** ref = static_cast<Object**>(value_.v_handle);
if (ObjectTypeChecker<T>::Check(*ref)) {
return T(ObjectPtr<Object>::MoveFromRValueRefArg(ref));
}
}
// fallback
return PackedFuncValueConverter<T>::From(*this);
} }
template<typename T, typename> template<typename T, typename>
inline TVMRetValue::operator T() const { inline TVMRetValue::operator T() const {
return AsObjectRef<T>(); return PackedFuncValueConverter<T>::From(*this);
} }
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
......
...@@ -1338,20 +1338,20 @@ enum TVMStructFieldKind : int { ...@@ -1338,20 +1338,20 @@ enum TVMStructFieldKind : int {
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
// Additional implementattion overloads for PackedFunc. // Additional implementattion overloads for PackedFunc.
inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kTVMNullptr) return Integer(); template<>
if (type_code_ == kDLInt) { struct PackedFuncValueConverter<tvm::Integer> {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); // common rule for RetValue and ArgValue
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); static tvm::Integer From(const TVMPODValue_& val) {
return Integer(static_cast<int>(value_.v_int64)); if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
} }
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); };
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return Integer(ObjectPtr<Object>(ptr));
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -244,8 +244,9 @@ extern "C" int funcInvokeCallback(TVMValue *args, ...@@ -244,8 +244,9 @@ extern "C" int funcInvokeCallback(TVMValue *args,
int tcode = typeCodes[i]; int tcode = typeCodes[i];
if (tcode == kTVMObjectHandle || if (tcode == kTVMObjectHandle ||
tcode == kTVMPackedFuncHandle || tcode == kTVMPackedFuncHandle ||
tcode == kTVMObjectRValueRefArg ||
tcode == kTVMModuleHandle) { tcode == kTVMModuleHandle) {
TVMCbArgToReturn(&arg, tcode); TVMCbArgToReturn(&arg, &tcode);
} }
jobject jarg = tvmRetValueToJava(env, arg, tcode); jobject jarg = tvmRetValueToJava(env, arg, tcode);
env->SetObjectArrayElement(jargs, i, jarg); env->SetObjectArrayElement(jargs, i, jarg);
......
...@@ -60,6 +60,9 @@ RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object ...@@ -60,6 +60,9 @@ RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE) _return_object, TypeCode.OBJECT_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
class ObjectBase(object): class ObjectBase(object):
"""Base object for all object types""" """Base object for all object types"""
......
...@@ -23,7 +23,7 @@ from numbers import Number, Integral ...@@ -23,7 +23,7 @@ from numbers import Number, Integral
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..runtime_ctypes import DataType, TVMByteArray, TVMContext from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
...@@ -164,6 +164,9 @@ def _make_tvm_args(args, temp_args): ...@@ -164,6 +164,9 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, ctypes.c_void_p): elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg values[i].v_handle = arg
type_codes[i] = TypeCode.HANDLE type_codes[i] = TypeCode.HANDLE
elif isinstance(arg, ObjectRValueRef):
values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p)
type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
......
...@@ -73,9 +73,9 @@ def _return_context(value): ...@@ -73,9 +73,9 @@ def _return_context(value):
def _wrap_arg_func(return_f, type_code): def _wrap_arg_func(return_f, type_code):
tcode = ctypes.c_int(type_code)
def _wrap_func(x): def _wrap_func(x):
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), tcode)) tcode = ctypes.c_int(type_code)
check_call(_LIB.TVMCbArgToReturn(ctypes.byref(x), ctypes.byref(tcode)))
return return_f(x) return return_f(x)
return _wrap_func return _wrap_func
......
...@@ -37,6 +37,7 @@ cdef enum TVMTypeCode: ...@@ -37,6 +37,7 @@ cdef enum TVMTypeCode:
kTVMStr = 11 kTVMStr = 11
kTVMBytes = 12 kTVMBytes = 12
kTVMNDArrayHandle = 13 kTVMNDArrayHandle = 13
kTVMObjectRefArg = 14
kTVMExtBegin = 15 kTVMExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
...@@ -113,7 +114,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ...@@ -113,7 +114,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* resource_handle, void* resource_handle,
TVMPackedCFuncFinalizer fin, TVMPackedCFuncFinalizer fin,
TVMPackedFuncHandle *out) TVMPackedFuncHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code) int TVMCbArgToReturn(TVMValue* value, int* code)
int TVMArrayAlloc(tvm_index_t* shape, int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim, tvm_index_t ndim,
DLDataType dtype, DLDataType dtype,
......
...@@ -64,10 +64,7 @@ cdef class ObjectBase: ...@@ -64,10 +64,7 @@ cdef class ObjectBase:
property handle: property handle:
def __get__(self): def __get__(self):
if self.chandle == NULL: return ctypes_handle(self.chandle)
return None
else:
return ctypes_handle(self.chandle)
def __set__(self, value): def __set__(self, value):
self._set_handle(value) self._set_handle(value)
......
...@@ -20,7 +20,7 @@ import traceback ...@@ -20,7 +20,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types, py2cerror from ..base import string_types, py2cerror
from ..runtime_ctypes import DataType, TVMContext, TVMByteArray from ..runtime_ctypes import DataType, TVMContext, TVMByteArray, ObjectRValueRef
cdef void tvm_callback_finalize(void* fhandle): cdef void tvm_callback_finalize(void* fhandle):
...@@ -43,8 +43,9 @@ cdef int tvm_callback(TVMValue* args, ...@@ -43,8 +43,9 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kTVMObjectHandle or if (tcode == kTVMObjectHandle or
tcode == kTVMPackedFuncHandle or tcode == kTVMPackedFuncHandle or
tcode == kTVMModuleHandle or tcode == kTVMModuleHandle or
tcode == kTVMObjectRefArg or
tcode > kTVMExtBegin): tcode > kTVMExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, &tcode))
if tcode != kTVMDLTensorHandle: if tcode != kTVMDLTensorHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
...@@ -167,6 +168,9 @@ cdef inline int make_arg(object arg, ...@@ -167,6 +168,9 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, ctypes.c_void_p): elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg) value[0].v_handle = c_handle(arg)
tcode[0] = kTVMOpaqueHandle tcode[0] = kTVMOpaqueHandle
elif isinstance(arg, ObjectRValueRef):
value[0].v_handle = &((<ObjectBase>(arg.obj)).chandle)
tcode[0] = kTVMObjectRefArg
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_tvm_func(arg)
value[0].v_handle = (<PackedFuncBase>arg).chandle value[0].v_handle = (<PackedFuncBase>arg).chandle
......
...@@ -39,6 +39,7 @@ class TypeCode(object): ...@@ -39,6 +39,7 @@ class TypeCode(object):
STR = 11 STR = 11
BYTES = 12 BYTES = 12
NDARRAY_HANDLE = 13 NDARRAY_HANDLE = 13
OBJECT_RVALUE_REF_ARG = 14
EXT_BEGIN = 15 EXT_BEGIN = 15
...@@ -281,4 +282,18 @@ class TVMArray(ctypes.Structure): ...@@ -281,4 +282,18 @@ class TVMArray(ctypes.Structure):
("strides", ctypes.POINTER(tvm_shape_index_t)), ("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_uint64)] ("byte_offset", ctypes.c_uint64)]
class ObjectRValueRef:
"""Represent an RValue ref to an object that can be moved.
Parameters
----------
obj : tvm.runtime.Object
The object that this value refers to
"""
__slots__ = ["obj"]
def __init__(self, obj):
self.obj = obj
TVMArrayHandle = ctypes.POINTER(TVMArray) TVMArrayHandle = ctypes.POINTER(TVMArray)
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
import ctypes import ctypes
from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str from tvm._ffi.base import _FFI_MODE, _RUNTIME_ONLY, check_call, _LIB, c_str
from tvm._ffi.runtime_ctypes import ObjectRValueRef
from . import _ffi_api, _ffi_node_api from . import _ffi_api, _ffi_node_api
try: try:
...@@ -85,5 +86,35 @@ class Object(ObjectBase): ...@@ -85,5 +86,35 @@ class Object(ObjectBase):
else: else:
self.handle = None self.handle = None
def _move(self):
"""Create an RValue reference to the object and mark the object as moved.
This is a advanced developer API that can be useful when passing an
unique reference to an Object that you no longer needed to a function.
A unique reference can trigger copy on write optimization that avoids
copy when we transform an object.
Note
----
All the reference of the object becomes invalid after it is moved.
Be very careful when using this feature.
Examples
--------
.. code-block:: python
x = tvm.tir.Var("x", "int32")
x0 = x
some_packed_func(x._move())
# both x0 and x will points to None after the function call.
Returns
-------
rvalue : The rvalue reference.
"""
return ObjectRValueRef(self)
_set_class_object(Object) _set_class_object(Object)
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
# pylint: disable=unused-import, invalid-name # pylint: disable=unused-import, invalid-name
from numbers import Number, Integral from numbers import Number, Integral
from tvm._ffi.base import string_types from tvm._ffi.base import string_types
from tvm._ffi.runtime_ctypes import ObjectRValueRef
from . import _ffi_node_api, _ffi_api from . import _ffi_node_api, _ffi_api
from .object import ObjectBase, _set_class_object_generic from .object import ObjectBase, _set_class_object_generic
...@@ -33,7 +34,7 @@ class ObjectGeneric(object): ...@@ -33,7 +34,7 @@ class ObjectGeneric(object):
raise NotImplementedError() raise NotImplementedError()
ObjectTypes = (ObjectBase, NDArrayBase, Module) ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
def convert_to_object(value): def convert_to_object(value):
......
...@@ -261,7 +261,7 @@ unsafe extern "C" fn tvm_callback( ...@@ -261,7 +261,7 @@ unsafe extern "C" fn tvm_callback(
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{ {
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, tcode)); check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _));
} }
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32)); local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
} }
......
...@@ -371,7 +371,7 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") ...@@ -371,7 +371,7 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass")
.set_body_typed( .set_body_typed(
[](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func, [](runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func,
PassInfo pass_info) { PassInfo pass_info) {
return ModulePass(pass_func, pass_info); return ModulePass(pass_func, pass_info);
}); });
TVM_REGISTER_GLOBAL("transform.RunPass") TVM_REGISTER_GLOBAL("transform.RunPass")
......
...@@ -370,7 +370,6 @@ TVM_REGISTER_GLOBAL("node.MapGetItem") ...@@ -370,7 +370,6 @@ TVM_REGISTER_GLOBAL("node.MapGetItem")
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) { if (ptr->IsInstance<MapNode>()) {
CHECK(args[1].type_code() == kTVMObjectHandle);
auto* n = static_cast<const MapNode*>(ptr); auto* n = static_cast<const MapNode*>(ptr);
auto it = n->data.find(args[1].operator ObjectRef()); auto it = n->data.find(args[1].operator ObjectRef());
CHECK(it != n->data.end()) CHECK(it != n->data.end())
......
...@@ -577,13 +577,11 @@ int TVMStreamStreamSynchronize(int device_type, ...@@ -577,13 +577,11 @@ int TVMStreamStreamSynchronize(int device_type,
API_END(); API_END();
} }
int TVMCbArgToReturn(TVMValue* value, int code) { int TVMCbArgToReturn(TVMValue* value, int* code) {
API_BEGIN(); API_BEGIN();
tvm::runtime::TVMRetValue rv; tvm::runtime::TVMRetValue rv;
rv = tvm::runtime::TVMArgValue(*value, code); rv = tvm::runtime::TVMMovableArgValue_(*value, *code);
int tcode; rv.MoveToCHost(value, code);
rv.MoveToCHost(value, &tcode);
CHECK_EQ(tcode, code);
API_END(); API_END();
} }
......
...@@ -107,11 +107,11 @@ TVM_REGISTER_GLOBAL("testing.ErrorTest") ...@@ -107,11 +107,11 @@ TVM_REGISTER_GLOBAL("testing.ErrorTest")
.set_body_typed(ErrorTest); .set_body_typed(ErrorTest);
// internal function used for debug and testing purposes // internal function used for debug and testing purposes
TVM_REGISTER_GLOBAL("testing.ndarray_use_count") TVM_REGISTER_GLOBAL("testing.object_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
runtime::NDArray nd = args[0]; runtime::ObjectRef obj = args[0];
// substract the current one // substract the current one because we always copy
*ret = (nd.use_count() - 1); // and get another value.
*ret = (obj.use_count() - 1);
}); });
} // namespace tvm } // namespace tvm
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
...@@ -51,7 +52,7 @@ TEST(PackedFunc, Node) { ...@@ -51,7 +52,7 @@ TEST(PackedFunc, Node) {
Var x; Var x;
Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK(args.num_args == 1); CHECK(args.num_args == 1);
CHECK(args.type_codes[0] == kTVMObjectHandle); CHECK(args[0].IsObjectRef<ObjectRef>());
Var b = args[0]; Var b = args[0];
CHECK(x.same_as(b)); CHECK(x.same_as(b));
*rv = b; *rv = b;
...@@ -269,6 +270,50 @@ TEST(PackedFunc, ObjectConversion) { ...@@ -269,6 +270,50 @@ TEST(PackedFunc, ObjectConversion) {
pf2(ObjectRef(m), Module()); pf2(ObjectRef(m), Module());
} }
TEST(TypedPackedFunc, RValue) {
using namespace tvm;
using namespace tvm::runtime;
{
auto f = [](tir::Var x, bool move) {
if (move) {
CHECK(x.unique());
} else {
CHECK(!x.unique());
}
CHECK(x->name_hint == "x");
return x;
};
TypedPackedFunc<tir::Var(tir::Var, bool)> tf(f);
tir::Var var("x");
CHECK(var.unique());
f(var, false);
// move the result to the function.
tir::Var ret = f(std::move(var), true);
CHECK(!var.defined());
}
{
// pass child class.
auto f = [](PrimExpr x, bool move) {
if (move) {
CHECK(x.unique());
} else {
CHECK(!x.unique());
}
return x;
};
TypedPackedFunc<PrimExpr(PrimExpr, bool)> tf(f);
tir::Var var("x");
CHECK(var.unique());
f(var, false);
f(std::move(var), true);
// auto conversion.
f(1, true);
}
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe"; testing::FLAGS_gtest_death_test_style = "threadsafe";
......
...@@ -98,6 +98,33 @@ def test_ctx(): ...@@ -98,6 +98,33 @@ def test_ctx():
x = tvm.testing.context_test(x, x.device_type, x.device_id) x = tvm.testing.context_test(x, x.device_type, x.device_id)
assert x == tvm.opencl(10) assert x == tvm.opencl(10)
def test_rvalue_ref():
def callback(x, expected_count):
assert expected_count == tvm.testing.object_use_count(x)
return x
f = tvm.runtime.convert(callback)
def check0():
x = tvm.tir.Var("x", "int32")
assert tvm.testing.object_use_count(x) == 1
f(x, 2)
y = f(x._move(), 1)
assert x.handle.value == None
def check1():
x = tvm.tir.Var("x", "int32")
assert tvm.testing.object_use_count(x) == 1
y = f(x, 2)
z = f(x._move(), 2)
assert x.handle.value == None
assert y.handle.value is not None
check0()
check1()
def test_trace_default_action(): def test_trace_default_action():
n = 2 n = 2
x = te.placeholder((n,n,n), name="X", dtype="float32") x = te.placeholder((n,n,n), name="X", dtype="float32")
...@@ -269,7 +296,11 @@ def test_trace_can_change_traced_value_float(): ...@@ -269,7 +296,11 @@ def test_trace_can_change_traced_value_float():
for t in ["float64", "float32"]: for t in ["float64", "float32"]:
check_assign(t) check_assign(t)
if __name__ == "__main__": if __name__ == "__main__":
test_rvalue_ref()
exit(0)
test_empty_array() test_empty_array()
test_get_global() test_get_global()
test_get_callback_with_node() test_get_callback_with_node()
......
...@@ -212,7 +212,7 @@ def test_rpc_return_ndarray(): ...@@ -212,7 +212,7 @@ def test_rpc_return_ndarray():
if name == "get_arr": if name == "get_arr":
return lambda : nd return lambda : nd
elif name == "ref_count": elif name == "ref_count":
return lambda : tvm.testing.ndarray_use_count(nd) return lambda : tvm.testing.object_use_count(nd)
elif name == "get_elem": elif name == "get_elem":
return lambda idx: nd.asnumpy()[idx] return lambda idx: nd.asnumpy()[idx]
elif name == "get_arr_elem": elif name == "get_arr_elem":
......
...@@ -105,6 +105,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -105,6 +105,7 @@ var tvm_runtime = tvm_runtime || {};
var kTVMPackedFuncHandle = 10; var kTVMPackedFuncHandle = 10;
var kTVMStr = 11; var kTVMStr = 11;
var kTVMBytes = 12; var kTVMBytes = 12;
var kTVMObjectRValueRefArg = 14;
//----------------------------------------- //-----------------------------------------
// TVM CWrap library // TVM CWrap library
// ---------------------------------------- // ----------------------------------------
...@@ -171,7 +172,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -171,7 +172,7 @@ var tvm_runtime = tvm_runtime || {};
("TVMCbArgToReturn", ("TVMCbArgToReturn",
"number", "number",
["number", // TVMValue* value ["number", // TVMValue* value
"number" // int code "number" // int* code
]); ]);
var TVMFuncCreateFromCFunc = Module.cwrap var TVMFuncCreateFromCFunc = Module.cwrap
...@@ -496,12 +497,15 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -496,12 +497,15 @@ var tvm_runtime = tvm_runtime || {};
var args = []; var args = [];
for (var i = 0; i < nargs; ++i) { for (var i = 0; i < nargs; ++i) {
var vptr = arg_value + i * SIZEOF_TVMVALUE; var vptr = arg_value + i * SIZEOF_TVMVALUE;
var tcode = Module.getValue(arg_tcode + i * SIZEOF_INT, "i32"); var tcodeptr = arg_tcode + i * SIZEOF_INT;
var tcode = Module.getValue(tcodeptr, "i32");
if (tcode == kTVMObjectHandle || if (tcode == kTVMObjectHandle ||
tcode == kTVMObjectRValueRefArg ||
tcode == kTVMPackedFuncHandle || tcode == kTVMPackedFuncHandle ||
tcode == kTVMModuleHandle) { tcode == kTVMModuleHandle) {
TVM_CALL(TVMCbArgToReturn(vptr, tcode)); TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr));
} }
tcode = Module.getValue(tcodeptr, "i32");
args.push(TVMRetValueToJS(vptr, tcode)); args.push(TVMRetValueToJS(vptr, tcode));
} }
var rv = funcTable[handle].apply(null, args); var rv = funcTable[handle].apply(null, args);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment