Unverified Commit 745c8a06 by Tianqi Chen Committed by GitHub

[RUNTIME] Improved Packed FFI for optional. (#5478)

Allows Optional<NDArray> and module to be passed with the right type code.
parent 7ea834f9
...@@ -1346,16 +1346,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const { ...@@ -1346,16 +1346,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// We use type traits to eliminate un-necessary checks. // We use type traits to eliminate un-necessary checks.
template<typename T> template<typename T>
inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
using TObjectRef = typename std::remove_reference<T>::type; using ContainerType = typename std::remove_reference<T>::type::ContainerType;
if (value.defined()) { if (value.defined()) {
Object* ptr = value.data_.data_; Object* ptr = value.data_.data_;
if (std::is_base_of<NDArray, TObjectRef>::value || if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
(std::is_base_of<TObjectRef, NDArray>::value && (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) { ptr->IsInstance<NDArray::ContainerType>())) {
values_[i].v_handle = NDArray::FFIGetHandle(value); values_[i].v_handle = NDArray::FFIGetHandle(value);
type_codes_[i] = kTVMNDArrayHandle; type_codes_[i] = kTVMNDArrayHandle;
} else if (std::is_base_of<Module, TObjectRef>::value || } else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
(std::is_base_of<TObjectRef, Module>::value && (std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) { ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr; values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle; type_codes_[i] = kTVMModuleHandle;
...@@ -1375,12 +1375,12 @@ template<typename TObjectRef, typename> ...@@ -1375,12 +1375,12 @@ template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const { inline bool TVMPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType; using ContainerType = typename TObjectRef::ContainerType;
// NOTE: the following code can be optimized by constant folding. // NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) { if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
return type_code_ == kTVMNDArrayHandle && return type_code_ == kTVMNDArrayHandle &&
TVMArrayHandleToObjectHandle( TVMArrayHandleToObjectHandle(
static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>(); static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
} }
if (std::is_base_of<Module, TObjectRef>::value) { if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
return type_code_ == kTVMModuleHandle && return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>(); static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
} }
...@@ -1390,8 +1390,10 @@ inline bool TVMPODValue_::IsObjectRef() const { ...@@ -1390,8 +1390,10 @@ inline bool TVMPODValue_::IsObjectRef() const {
*static_cast<Object**>(value_.v_handle)); *static_cast<Object**>(value_.v_handle));
} }
return return
(std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) || (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
(std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) || type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) ||
(type_code_ == kTVMObjectHandle && (type_code_ == kTVMObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle))); ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
} }
...@@ -1402,13 +1404,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ...@@ -1402,13 +1404,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
std::is_base_of<ObjectRef, TObjectRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef"); "Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType; using ContainerType = typename TObjectRef::ContainerType;
if (type_code_ == kTVMNullptr) { if (type_code_ == kTVMNullptr) {
CHECK(TObjectRef::_type_is_nullable) CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key; << "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr)); return TObjectRef(ObjectPtr<Object>(nullptr));
} }
// NOTE: the following code can be optimized by constant folding. // NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) { if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
// Casting to a sub-class of NDArray // Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
ObjectPtr<Object> data = NDArray::FFIDataFromHandle( ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
...@@ -1417,7 +1420,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ...@@ -1417,7 +1420,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data); return TObjectRef(data);
} }
if (std::is_base_of<Module, TObjectRef>::value) { if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
// Casting to a sub-class of Module // Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)); ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
...@@ -1438,13 +1441,13 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ...@@ -1438,13 +1441,13 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr)); return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<TObjectRef, NDArray>::value && } else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) { type_code_ == kTVMNDArrayHandle) {
// Casting to a base class that NDArray can sub-class // Casting to a base class that NDArray can sub-class
ObjectPtr<Object> data = NDArray::FFIDataFromHandle( ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle)); static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data); return TObjectRef(data);
} else if (std::is_base_of<TObjectRef, Module>::value && } else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) { type_code_ == kTVMModuleHandle) {
// Casting to a base class that Module can sub-class // Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
...@@ -1456,15 +1459,16 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ...@@ -1456,15 +1459,16 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
template<typename TObjectRef, typename> template<typename TObjectRef, typename>
inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
using ContainerType = typename TObjectRef::ContainerType;
const Object* ptr = other.get(); const Object* ptr = other.get();
if (ptr != nullptr) { if (ptr != nullptr) {
if (std::is_base_of<NDArray, TObjectRef>::value || if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
(std::is_base_of<TObjectRef, NDArray>::value && (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) { ptr->IsInstance<NDArray::ContainerType>())) {
return operator=(NDArray(std::move(other.data_))); return operator=(NDArray(std::move(other.data_)));
} }
if (std::is_base_of<Module, TObjectRef>::value || if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
(std::is_base_of<TObjectRef, Module>::value && (std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) { ptr->IsInstance<Module::ContainerType>())) {
return operator=(Module(std::move(other.data_))); return operator=(Module(std::move(other.data_)));
} }
......
...@@ -177,6 +177,16 @@ TEST(BuildModule, Heterogeneous) { ...@@ -177,6 +177,16 @@ TEST(BuildModule, Heterogeneous) {
runtime::Module mod = (*graph_runtime)( runtime::Module mod = (*graph_runtime)(
json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id);
// test FFI for module.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
int tcode = args[1];
CHECK_EQ(args[0].type_code(), tcode);
});
test_ffi(runtime::Module(mod), static_cast<int>(kTVMModuleHandle));
test_ffi(Optional<runtime::Module>(mod), static_cast<int>(kTVMModuleHandle));
PackedFunc set_input = mod.GetFunction("set_input", false); PackedFunc set_input = mod.GetFunction("set_input", false);
PackedFunc run = mod.GetFunction("run", false); PackedFunc run = mod.GetFunction("run", false);
PackedFunc get_output = mod.GetFunction("get_output", false); PackedFunc get_output = mod.GetFunction("get_output", false);
......
...@@ -468,6 +468,18 @@ TEST(Optional, PackedCall) { ...@@ -468,6 +468,18 @@ TEST(Optional, PackedCall) {
CHECK(packedfunc("xyz", false).operator String() == "xyz"); CHECK(packedfunc("xyz", false).operator String() == "xyz");
CHECK(packedfunc("xyz", false).operator Optional<String>() == "xyz"); CHECK(packedfunc("xyz", false).operator Optional<String>() == "xyz");
CHECK(packedfunc(nullptr, true).operator Optional<String>() == nullptr); CHECK(packedfunc(nullptr, true).operator Optional<String>() == nullptr);
// test FFI convention.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
int tcode = args[1];
CHECK_EQ(args[0].type_code(), tcode);
});
String s = "xyz";
auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0});
test_ffi(Optional<NDArray>(nd), static_cast<int>(kTVMNDArrayHandle));
test_ffi(Optional<String>(s), static_cast<int>(kTVMObjectRValueRefArg));
test_ffi(s, static_cast<int>(kTVMObjectHandle));
test_ffi(String(s), static_cast<int>(kTVMObjectRValueRefArg));
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment