Unverified Commit 745c8a06 by Tianqi Chen Committed by GitHub

[RUNTIME] Improved Packed FFI for optional. (#5478)

Allows Optional<NDArray> and module to be passed with the right type code.
parent 7ea834f9
......@@ -1346,16 +1346,16 @@ inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
// We use type traits to eliminate un-necessary checks.
template<typename T>
inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
using TObjectRef = typename std::remove_reference<T>::type;
using ContainerType = typename std::remove_reference<T>::type::ContainerType;
if (value.defined()) {
Object* ptr = value.data_.data_;
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
values_[i].v_handle = NDArray::FFIGetHandle(value);
type_codes_[i] = kTVMNDArrayHandle;
} else if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
} else if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
......@@ -1375,12 +1375,12 @@ template<typename TObjectRef, typename>
inline bool TVMPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
return type_code_ == kTVMNDArrayHandle &&
TVMArrayHandleToObjectHandle(
static_cast<TVMArrayHandle>(value_.v_handle))->IsInstance<ContainerType>();
}
if (std::is_base_of<Module, TObjectRef>::value) {
if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
......@@ -1390,8 +1390,10 @@ inline bool TVMPODValue_::IsObjectRef() const {
*static_cast<Object**>(value_.v_handle));
}
return
(std::is_base_of<TObjectRef, NDArray>::value && type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<TObjectRef, Module>::value && type_code_ == kTVMModuleHandle) ||
(std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) ||
(type_code_ == kTVMObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
......@@ -1402,13 +1404,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;
if (type_code_ == kTVMNullptr) {
CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr));
}
// NOTE: the following code can be optimized by constant folding.
if (std::is_base_of<NDArray, TObjectRef>::value) {
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value) {
// Casting to a sub-class of NDArray
TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle);
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
......@@ -1417,7 +1420,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<Module, TObjectRef>::value) {
if (std::is_base_of<Module::ContainerType, ContainerType>::value) {
// Casting to a sub-class of Module
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
......@@ -1438,13 +1441,13 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName()
<< " but get " << ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else if (std::is_base_of<TObjectRef, NDArray>::value &&
} else if (std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
type_code_ == kTVMNDArrayHandle) {
// Casting to a base class that NDArray can sub-class
ObjectPtr<Object> data = NDArray::FFIDataFromHandle(
static_cast<TVMArrayHandle>(value_.v_handle));
return TObjectRef(data);
} else if (std::is_base_of<TObjectRef, Module>::value &&
} else if (std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
......@@ -1456,15 +1459,16 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
template<typename TObjectRef, typename>
inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
using ContainerType = typename TObjectRef::ContainerType;
const Object* ptr = other.get();
if (ptr != nullptr) {
if (std::is_base_of<NDArray, TObjectRef>::value ||
(std::is_base_of<TObjectRef, NDArray>::value &&
if (std::is_base_of<NDArray::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, NDArray::ContainerType>::value &&
ptr->IsInstance<NDArray::ContainerType>())) {
return operator=(NDArray(std::move(other.data_)));
}
if (std::is_base_of<Module, TObjectRef>::value ||
(std::is_base_of<TObjectRef, Module>::value &&
if (std::is_base_of<Module::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
ptr->IsInstance<Module::ContainerType>())) {
return operator=(Module(std::move(other.data_)));
}
......
......@@ -177,6 +177,16 @@ TEST(BuildModule, Heterogeneous) {
runtime::Module mod = (*graph_runtime)(
json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id);
// test FFI for module.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
int tcode = args[1];
CHECK_EQ(args[0].type_code(), tcode);
});
test_ffi(runtime::Module(mod), static_cast<int>(kTVMModuleHandle));
test_ffi(Optional<runtime::Module>(mod), static_cast<int>(kTVMModuleHandle));
PackedFunc set_input = mod.GetFunction("set_input", false);
PackedFunc run = mod.GetFunction("run", false);
PackedFunc get_output = mod.GetFunction("get_output", false);
......
......@@ -468,6 +468,18 @@ TEST(Optional, PackedCall) {
CHECK(packedfunc("xyz", false).operator String() == "xyz");
CHECK(packedfunc("xyz", false).operator Optional<String>() == "xyz");
CHECK(packedfunc(nullptr, true).operator Optional<String>() == nullptr);
// test FFI convention.
auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
int tcode = args[1];
CHECK_EQ(args[0].type_code(), tcode);
});
String s = "xyz";
auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0});
test_ffi(Optional<NDArray>(nd), static_cast<int>(kTVMNDArrayHandle));
test_ffi(Optional<String>(s), static_cast<int>(kTVMObjectRValueRefArg));
test_ffi(s, static_cast<int>(kTVMObjectHandle));
test_ffi(String(s), static_cast<int>(kTVMObjectRValueRefArg));
}
int main(int argc, char** argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment