Unverified Commit 77c47748 by Tianqi Chen Committed by GitHub

[RUNTIME][DSO] Improve TVMBackendPackedCFunc to allow return val (#4637)

* [RUNTIME][DSO] Improve TVMBackendPackedCFunc to allow return value.

Previously the signature of LibraryModule's PackedFunc does not support return value.
This wasn't a limitation for our current usecase but could become one
as we start to generate more interesting functions.

This feature also start to get interesting as we move towards unified
object protocol and start to pass object around.
This PR enhances the function signature to allow return values.

We also created two macros TVM_DLL_EXPORT_PACKED_FUNC and TVM_DLL_EXPORT_TYPED_FUNC
to allow manual creation of functions that can be loaded by a LibraryModule.

Examples are added in apps/dso_plugin_module.
The change to TVMBackendPackedCFunc is backward compatible,
as previous function will simply ignore the return value field.

* address review comments
parent 76efece3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
TVM_ROOT=$(shell cd ../..; pwd)
PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/include\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include
PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
PKG_LDFLAGS += -undefined dynamic_lookup
endif
lib/plugin_module.so: plugin_module.cc
@mkdir -p $(@D)
$(CXX) $(PKG_CFLAGS) -shared -o $@ $^ $(PKG_LDFLAGS)
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
Example Plugin Module
=====================
This folder contains an example that implements a C++ module
that can be directly loaded as TVM's DSOModule (via tvm.module.load)
## Guideline
When possible, we always recommend exposing
functions that modifies memory passed by the caller,
and calls into the runtime API for memory allocations.
## Advanced Usecases
In advanced usecases, we do allow the plugin module to
create and return managed objects.
However, there are several restrictions to keep in mind:
- If the module returns an object, we need to make sure
that the object get destructed before the module get unloaded.
Otherwise segfault can happen because of calling into an unloaded destructor.
- If the module returns a PackedFunc, then
we need to ensure that the libc of the DLL and tvm runtime matches.
Otherwise segfault can happen due to incompatibility of std::function.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Example code that can be compiled and loaded by TVM runtime.
* \file plugin_module.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/ndarray.h>
namespace tvm_dso_plugin {
using namespace tvm::runtime;
class MyModuleNode : public ModuleNode {
public:
explicit MyModuleNode(int value)
: value_(value) {}
virtual const char* type_key() const final {
return "MyModule";
}
virtual PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
if (name == "add") {
return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) {
return value_ + value;
});
} else if (name == "mul") {
return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) {
return value_ * value;
});
} else {
LOG(FATAL) << "unknown function " << name;
return PackedFunc();
}
}
private:
int value_;
};
void CreateMyModule_(TVMArgs args, TVMRetValue* rv) {
int value = args[0];
*rv = Module(make_object<MyModuleNode>(value));
}
int SubOne_(int x) {
return x - 1;
}
// USE TVM_DLL_EXPORT_TYPED_PACKED_FUNC to export a
// typed function as packed function.
TVM_DLL_EXPORT_TYPED_FUNC(SubOne, SubOne_);
// TVM_DLL_EXPORT_TYPED_PACKED_FUNC also works for lambda.
TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int {
return x + 1;
});
// Use TVM_EXPORT_PACKED_FUNC to export a function with
TVM_DLL_EXPORT_PACKED_FUNC(CreateMyModule, tvm_dso_plugin::CreateMyModule_);
} // namespace tvm_dso_plugin
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import os
def test_plugin_module():
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
mod = tvm.module.load(os.path.join(curr_path, "lib", "plugin_module.so"))
# NOTE: we need to make sure all managed resources returned
# from mod get destructed before mod get unloaded.
#
# Failure mode we want to prevent from:
# We retain an object X whose destructor is within mod.
# The program will segfault if X get destructed after mod,
# because the destructor function has already been unloaded.
#
# The easiest way to achieve this is to wrap the
# logics related to mod inside a function.
def run_module(mod):
# normal functions
assert mod["AddOne"](10) == 11
assert mod["SubOne"](10) == 9
# advanced usecase: return a module
mymod = mod["CreateMyModule"](10);
fadd = mymod["add"]
assert fadd(10) == 20
assert mymod["mul"](10) == 100
run_module(mod)
if __name__ == "__main__":
test_plugin_module()
...@@ -22,6 +22,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\ ...@@ -22,6 +22,7 @@ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\
-I${TVM_ROOT}/3rdparty/dlpack/include -I${TVM_ROOT}/3rdparty/dlpack/include
PKG_LDFLAGS =-L${TVM_ROOT}/build PKG_LDFLAGS =-L${TVM_ROOT}/build
UNAME_S := $(shell uname -s) UNAME_S := $(shell uname -s)
......
...@@ -410,7 +410,8 @@ Stmt HoistIfThenElse(Stmt stmt); ...@@ -410,7 +410,8 @@ Stmt HoistIfThenElse(Stmt stmt);
* *
* if num_packed_args is not zero: * if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n) * api_arg_k, api_arg_k+1, ... api_arg_n,
* TVMValue* out_ret_val, int* out_ret_tcode)
* *
* where n == len(api_args), k == num_packed_args * where n == len(api_args), k == num_packed_args
* *
......
...@@ -34,7 +34,23 @@ ...@@ -34,7 +34,23 @@
extern "C" { extern "C" {
#endif #endif
// Backend related functions. /*!
* \brief Signature for backend functions exported as DLL.
*
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param out_ret_value The output value of the the return value.
* \param out_ret_tcode The output type code of the return value.
*
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
*/
typedef int (*TVMBackendPackedCFunc)(TVMValue* args,
int* type_codes,
int num_args,
TVMValue* out_ret_value,
int* out_ret_tcode);
/*! /*!
* \brief Backend function for modules to get function * \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function). * from its environment mod_node (its imports and global function).
......
...@@ -111,7 +111,7 @@ class Module : public ObjectRef { ...@@ -111,7 +111,7 @@ class Module : public ObjectRef {
* *
* \endcode * \endcode
*/ */
class ModuleNode : public Object { class TVM_DLL ModuleNode : public Object {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
TVM_DLL virtual ~ModuleNode() {} TVM_DLL virtual ~ModuleNode() {}
......
...@@ -771,6 +771,23 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -771,6 +771,23 @@ class TVMRetValue : public TVMPODValue_ {
*ret_type_code = type_code_; *ret_type_code = type_code_;
type_code_ = kNull; type_code_ = kNull;
} }
/*!
* \brief Construct a new TVMRetValue by
* moving from return value stored via C API.
* \param value the value.
* \param type_code The type code.
* \return The created TVMRetValue.
*/
static TVMRetValue MoveFromCHost(TVMValue value,
int type_code) {
// Can move POD and everything under the object system.
CHECK(type_code <= kFuncHandle ||
type_code == kNDArrayContainer);
TVMRetValue ret;
ret.value_ = value;
ret.type_code_ = type_code;
return ret;
}
/*! \return The value field, if the data is POD */ /*! \return The value field, if the data is POD */
const TVMValue& value() const { const TVMValue& value() const {
CHECK(type_code_ != kObjectHandle && CHECK(type_code_ != kObjectHandle &&
...@@ -877,6 +894,104 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -877,6 +894,104 @@ class TVMRetValue : public TVMPODValue_ {
} }
}; };
/*!
* \brief Export a function with the PackedFunc signature
* as a PackedFunc that can be loaded by LibraryModule.
*
* \param ExportName The symbol name to be exported.
* \param Function The function with PackedFunc signature.
* \sa PackedFunc
*
* \code
*
* void AddOne_(TVMArgs args, TVMRetValue* rv) {
* int value = args[0];
* *rv = value + 1;
* }
* // Expose the function as "AddOne"
* TVM_DLL_EXPORT_PACKED_FUNC(AddOne, AddOne_);
*
* \endcode
*/
#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \
extern "C" { \
TVM_DLL int ExportName(TVMValue* args, \
int* type_code, \
int num_args, \
TVMValue* out_value, \
int* out_type_code) { \
try { \
::tvm::runtime::TVMRetValue rv; \
Function(::tvm::runtime::TVMArgs( \
args, type_code, num_args), &rv); \
rv.MoveToCHost(out_value, out_type_code); \
return 0; \
} catch (const ::std::runtime_error& _except_) { \
TVMAPISetLastError(_except_.what()); \
return -1; \
} \
} \
}
/*!
* \brief Export typed function as a PackedFunc
* that can be loaded by LibraryModule.
*
* \param ExportName The symbol name to be exported.
* \param Function The typed function.
* \note ExportName and Function must be different,
* see code examples below.
*
* \sa TypedPackedFunc
*
* \code
*
* int AddOne_(int x) {
* return x + 1;
* }
*
* // Expose the function as "AddOne"
* TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_);
*
* // Expose the function as "SubOne"
* TVM_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) {
* return x - 1;
* });
*
* // The following code will cause compilation error.
* // Because the same Function and ExortName
* // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, AddOne_);
*
* // The following code is OK, assuming the macro
* // is in a different namespace from xyz
* // TVM_DLL_EXPORT_TYPED_FUNC(AddOne_, xyz::AddOne_);
*
* \endcode
*/
#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \
extern "C" { \
TVM_DLL int ExportName(TVMValue* args, \
int* type_code, \
int num_args, \
TVMValue* out_value, \
int* out_type_code) { \
try { \
auto f = Function; \
using FType = ::tvm::runtime::detail:: \
function_signature<decltype(f)>::FType; \
::tvm::runtime::TVMRetValue rv; \
::tvm::runtime::detail::unpack_call_by_signature<FType>::run( \
f, \
::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \
rv.MoveToCHost(out_value, out_type_code); \
return 0; \
} catch (const ::std::runtime_error& _except_) { \
TVMAPISetLastError(_except_.what()); \
return -1; \
} \
} \
}
// implementation details // implementation details
inline const char* TypeCode2Str(int type_code) { inline const char* TypeCode2Str(int type_code) {
switch (type_code) { switch (type_code) {
...@@ -1218,6 +1333,20 @@ inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { ...@@ -1218,6 +1333,20 @@ inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv); unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
} }
template<typename FType>
struct unpack_call_by_signature {
};
template<typename R, typename ...Args>
struct unpack_call_by_signature<R(Args...)> {
template<typename F>
static void run(const F& f,
const TVMArgs& args,
TVMRetValue* rv) {
unpack_call<R, sizeof...(Args)>(f, args, rv);
}
};
template<typename R, typename ...Args> template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) { inline R call_packed(const PackedFunc& pf, Args&& ...args) {
return R(pf(std::forward<Args>(args)...)); return R(pf(std::forward<Args>(args)...));
......
...@@ -68,8 +68,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -68,8 +68,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
const std::string& fname = (name == runtime::symbol::tvm_module_main ? const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name); entry_func_ : name);
BackendPackedCFunc faddr = TVMBackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(GetFunctionAddr(fname)); reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(fname));
if (faddr == nullptr) return PackedFunc(); if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self); return WrapPackedFunc(faddr, sptr_to_self);
} }
......
...@@ -53,6 +53,8 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -53,6 +53,8 @@ LoweredFunc MakeAPI(Stmt body,
Var v_packed_args("args", DataType::Handle()); Var v_packed_args("args", DataType::Handle());
Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle()); Var v_packed_arg_type_ids("arg_type_ids", DataType::Handle());
Var v_num_packed_args("num_args", DataType::Int(32)); Var v_num_packed_args("num_args", DataType::Int(32));
Var v_out_ret_value("out_ret_value", DataType::Handle());
Var v_out_ret_tcode("out_ret_tcode", DataType::Handle());
// The arguments of the function. // The arguments of the function.
Array<Var> args; Array<Var> args;
// The device context // The device context
...@@ -151,6 +153,15 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -151,6 +153,15 @@ LoweredFunc MakeAPI(Stmt body,
} }
} }
// allow return value if the function is packed.
if (num_packed_args != 0) {
args.push_back(v_out_ret_value);
args.push_back(v_out_ret_tcode);
}
size_t expected_nargs = num_unpacked_args + (num_packed_args != 0 ? 5 : 0);
CHECK_EQ(args.size(), expected_nargs);
// Arg definitions are defined before buffer binding to avoid the use before // Arg definitions are defined before buffer binding to avoid the use before
// def errors. // def errors.
// //
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <cstdint> #include <utility>
#include "library_module.h" #include "library_module.h"
namespace tvm { namespace tvm {
...@@ -48,15 +48,15 @@ class LibraryModuleNode final : public ModuleNode { ...@@ -48,15 +48,15 @@ class LibraryModuleNode final : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr; TVMBackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) { if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>( const char* entry_name = reinterpret_cast<const char*>(
lib_->GetSymbol(runtime::symbol::tvm_module_main)); lib_->GetSymbol(runtime::symbol::tvm_module_main));
CHECK(entry_name!= nullptr) CHECK(entry_name!= nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; << "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(lib_->GetSymbol(entry_name)); faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(entry_name));
} else { } else {
faddr = reinterpret_cast<BackendPackedCFunc>(lib_->GetSymbol(name.c_str())); faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(name.c_str()));
} }
if (faddr == nullptr) return PackedFunc(); if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self); return WrapPackedFunc(faddr, sptr_to_self);
...@@ -77,14 +77,21 @@ class ModuleInternal { ...@@ -77,14 +77,21 @@ class ModuleInternal {
} }
}; };
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr,
const ObjectPtr<Object>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value;
int ret_type_code = kNull;
int ret = (*faddr)( int ret = (*faddr)(
const_cast<TVMValue*>(args.values), const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes), const_cast<int*>(args.type_codes),
args.num_args); args.num_args,
&ret_value,
&ret_type_code);
CHECK_EQ(ret, 0) << TVMGetLastError(); CHECK_EQ(ret, 0) << TVMGetLastError();
if (ret_type_code != kNull) {
*rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
}
}); });
} }
......
...@@ -29,13 +29,6 @@ ...@@ -29,13 +29,6 @@
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <functional> #include <functional>
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*! /*!
...@@ -58,11 +51,11 @@ class Library : public Object { ...@@ -58,11 +51,11 @@ class Library : public Object {
}; };
/*! /*!
* \brief Wrap a BackendPackedCFunc to packed function. * \brief Wrap a TVMBackendPackedCFunc to packed function.
* \param faddr The function address * \param faddr The function address
* \param mptr The module pointer node. * \param mptr The module pointer node.
*/ */
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr); PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
/*! /*!
* \brief Utility to initialize conext function symbols during startup * \brief Utility to initialize conext function symbols during startup
......
...@@ -37,7 +37,7 @@ def test_makeapi(): ...@@ -37,7 +37,7 @@ def test_makeapi():
f = tvm.ir_pass.MakeAPI( f = tvm.ir_pass.MakeAPI(
stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True) stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype) assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
assert(len(f.args) == 5) assert(len(f.args) == 7)
output_ssa = False output_ssa = False
......
...@@ -39,6 +39,15 @@ cd ../.. ...@@ -39,6 +39,15 @@ cd ../..
TVM_FFI=cython python3 -m pytest -v apps/extension/tests TVM_FFI=cython python3 -m pytest -v apps/extension/tests
TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests
# Test dso plugin
cd apps/dso_plugin_module
rm -rf lib
make
cd ../..
TVM_FFI=cython python3 -m pytest -v apps/dso_plugin_module
TVM_FFI=ctypes python3 -m pytest -v apps/dso_plugin_module
TVM_FFI=ctypes python3 -m pytest -v tests/python/integration TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment