/*! * Copyright (c) 2017 by Contributors * \file tvm/runtime/packed_func.h * \brief Type-erased function used across TVM API. */ #ifndef TVM_RUNTIME_PACKED_FUNC_H_ #define TVM_RUNTIME_PACKED_FUNC_H_ #include <dmlc/logging.h> #include <functional> #include <tuple> #include <vector> #include <string> #include <limits> #include <memory> #include <type_traits> #include "./c_runtime_api.h" #include "./module.h" namespace HalideIR { // Forward declare type for extensions // The header works fine without depending on this. struct Type; struct Expr; } // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY #define TVM_RUNTIME_HEADER_ONLY 0 #endif namespace tvm { // Forward declare NodeRef and Node for extensions. // This header works fine without depend on NodeRef // as long as it is not used. class Node; class NodeRef; namespace runtime { // forward declarations class TVMArgs; class TVMArgValue; class TVMRetValue; class TVMArgsSetter; /*! * \brief Packed function is a type-erased function. * The arguments are passed by packed format. * * This is an useful unified interface to call generated functions, * It is the unified function function type of TVM. * It corresponds to TVMFunctionHandle in C runtime API. */ class PackedFunc { public: /*! * \brief The internal std::function * \param args The arguments to the function. * \param rv The return value. * * \code * // Example code on how to implemented FType * void MyPackedFunc(TVMArgs args, TVMRetValue* rv) { * // automatically convert arguments to desired type. * int a0 = args[0]; * float a1 = args[1]; * ... * // automatically assign values to rv * std::string my_return_value = "x"; * *rv = my_return_value; * } * \endcode */ using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>; /*! \brief default constructor */ PackedFunc() {} /*! * \brief constructing a packed function from a std::function. * \param body the internal container of packed function. */ explicit PackedFunc(FType body) : body_(body) {} /*! * \brief Call packed function by directly passing in unpacked format. * \param args Arguments to be passed. * \tparam Args arguments to be passed. * * \code * // Example code on how to call packed function * void CallPacked(PackedFunc f) { * // call like normal functions by pass in arguments * // return value is automatically converted back * int rvalue = f(1, 2.0); * } * \endcode */ template<typename... Args> inline TVMRetValue operator()(Args&& ...args) const; /*! * \brief Call the function in packed format. * \param args The arguments * \param rv The return value. */ inline void CallPacked(TVMArgs args, TVMRetValue* rv) const; /*! \return the internal body function */ inline FType body() const; /*! \return Whether the packed function is nullptr */ bool operator==(std::nullptr_t null) const { return body_ == nullptr; } /*! \return Whether the packed function is not nullptr */ bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } private: /*! \brief internal container of packed function */ FType body_; }; /*! \brief Arguments into TVM functions. */ class TVMArgs { public: const TVMValue* values; const int* type_codes; int num_args; /*! * \brief constructor * \param values The argument values * \param type_codes The argument type codes * \param num_args number of arguments. */ TVMArgs(const TVMValue* values, const int* type_codes, int num_args) : values(values), type_codes(type_codes), num_args(num_args) { } /*! \return size of the arguments */ inline int size() const; /*! * \brief Get i-th argument * \param i the index. * \return the ith argument. */ inline TVMArgValue operator[](int i) const; }; /*! * \brief Convert type code to its name * \param type_code The type code . * \return The name of type code. */ inline const char* TypeCode2Str(int type_code); /*! * \brief convert a string to TVM type. * \param s The string to be converted. * \return The corresponding tvm type. */ inline TVMType String2TVMType(std::string s); /*! * \brief convert a TVM type to string. * \param t The type to be converted. * \return The corresponding tvm type in string. */ inline std::string TVMType2String(TVMType t); // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ CHECK_EQ(CODE, T) << " expected " \ << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ /*! * \brief Type traits to mark if a class is tvm extension type. * * To enable extension type in C++ must be register () ed via marco. * TVM_REGISTER_EXT_TYPE(TypeName) after defining this with this traits. * * Extension class can be passed and returned via PackedFunc in all tvm runtime. * Internally extension class is stored as T*. * * \tparam T the typename */ template<typename T> struct extension_class_info { static const int code = 0; }; /*! * \brief Runtime function table about extension type. */ class ExtTypeVTable { public: /*! \brief function to be called to delete a handle */ void (*destroy)(void* handle); /*! \brief function to be called when clone a handle */ void* (*clone)(void* handle); /*! * \brief Register type * \tparam T The type to be register. * \return The registered vtable. */ template <typename T> static inline ExtTypeVTable* Register_(); /*! * \brief Get a vtable based on type code. * \param type_code The type code * \return The registered vtable. */ TVM_DLL static ExtTypeVTable* Get(int type_code); private: // Internal registration function. TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt); }; /*! * \brief Internal base class to * handle conversion to POD values. */ class TVMPODValue_ { public: operator double() const { TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); return value_.v_float64; } operator int64_t() const { TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator uint64_t() const { TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64; } operator int() const { TVM_CHECK_TYPE_CODE(type_code_, kDLInt); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); return static_cast<int>(value_.v_int64); } operator bool() const { TVM_CHECK_TYPE_CODE(type_code_, kDLInt); return value_.v_int64 != 0; } operator void*() const { if (type_code_ == kNull) return nullptr; if (type_code_ == kArrayHandle) return value_.v_handle; TVM_CHECK_TYPE_CODE(type_code_, kHandle); return value_.v_handle; } operator TVMArray*() const { if (type_code_ == kNull) return nullptr; TVM_CHECK_TYPE_CODE(type_code_, kArrayHandle); return static_cast<TVMArray*>(value_.v_handle); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } template<typename TExtension> const TExtension& AsExtension() const { CHECK_LT(type_code_, kExtEnd); return static_cast<TExtension*>(value_.v_handle)[0]; } int type_code() const { return type_code_; } /*! * \brief return handle as specific pointer type. * \tparam T the data type. * \return The pointer type. */ template<typename T> T* ptr() const { return static_cast<T*>(value_.v_handle); } protected: friend class TVMArgsSetter; friend class TVMRetValue; TVMPODValue_() : type_code_(kNull) {} TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} /*! \brief The value */ TVMValue value_; /*! \brief the type code */ int type_code_; }; /*! * \brief A single argument value to PackedFunc. * Containing both type_code and TVMValue * * Provides utilities to do type cast into other types. */ class TVMArgValue : public TVMPODValue_ { public: /*! * \brief constructor * \param value of the function * \param type_code The type code. */ TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) { } // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; using TVMPODValue_::operator uint64_t; using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator TVMArray*; using TVMPODValue_::operator TVMContext; // conversion operator. operator std::string() const { if (type_code_ == kTVMType) { return TVMType2String(operator TVMType()); } else if (type_code_ == kBytes) { TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle); return std::string(arr->data, arr->size); } else { TVM_CHECK_TYPE_CODE(type_code_, kStr); return std::string(value_.v_str); } } operator TVMType() const { if (type_code_ == kStr) { return String2TVMType(operator std::string()); } TVM_CHECK_TYPE_CODE(type_code_, kTVMType); return value_.v_type; } operator PackedFunc() const { if (type_code_ == kNull) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); return *ptr<PackedFunc>(); } operator Module() const { TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); return *ptr<Module>(); } const TVMValue& value() const { return value_; } // Deferred extension handler. template<typename TNodeRef> inline TNodeRef AsNodeRef() const; template<typename T, typename = typename std::enable_if< std::is_class<T>::value>::type> inline operator T() const; template<typename TNodeRef, typename = typename std::enable_if< std::is_class<TNodeRef>::value>::type> inline bool IsNodeType() const; inline operator HalideIR::Type() const; inline operator HalideIR::Expr() const; // get internal node ptr, if it is node inline std::shared_ptr<Node>& node_sptr(); }; /*! * \brief Return Value container, * Unlike TVMArgValue, which only holds reference and do not delete * the underlying container during destruction. * * TVMRetValue holds value and will manage the underlying containers * when it stores a complicated data type. */ class TVMRetValue : public TVMPODValue_ { public: /*! \brief default constructor */ TVMRetValue() {} /*! * \brief move constructor from anoter return value. * \param other The other return value. */ TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kNull; } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; using TVMPODValue_::operator uint64_t; using TVMPODValue_::operator int; using TVMPODValue_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator TVMArray*; using TVMPODValue_::operator TVMContext; // Disable copy and assign from another value, but allow move. TVMRetValue(const TVMRetValue& other) { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMType) { return TVMType2String(operator TVMType()); } else if (type_code_ == kBytes) { return *ptr<std::string>(); } TVM_CHECK_TYPE_CODE(type_code_, kStr); return *ptr<std::string>(); } operator TVMType() const { if (type_code_ == kStr) { return String2TVMType(operator std::string()); } TVM_CHECK_TYPE_CODE(type_code_, kTVMType); return value_.v_type; } operator PackedFunc() const { if (type_code_ == kNull) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle); return *ptr<PackedFunc>(); } operator Module() const { TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle); return *ptr<Module>(); } // Assign operators TVMRetValue& operator=(TVMRetValue&& other) { this->Clear(); value_ = other.value_; type_code_ = other.type_code_; other.type_code_ = kNull; return *this; } TVMRetValue& operator=(double value) { this->SwitchToPOD(kDLFloat); value_.v_float64 = value; return *this; } TVMRetValue& operator=(std::nullptr_t value) { this->SwitchToPOD(kNull); value_.v_handle = value; return *this; } TVMRetValue& operator=(void* value) { this->SwitchToPOD(kHandle); value_.v_handle = value; return *this; } TVMRetValue& operator=(int64_t value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; return *this; } TVMRetValue& operator=(int value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; return *this; } TVMRetValue& operator=(TVMType t) { this->SwitchToPOD(kTVMType); value_.v_type = t; return *this; } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; return *this; } TVMRetValue& operator=(std::string value) { this->SwitchToClass(kStr, value); return *this; } TVMRetValue& operator=(TVMByteArray value) { this->SwitchToClass(kBytes, std::string(value.data, value.size)); return *this; } TVMRetValue& operator=(PackedFunc f) { this->SwitchToClass(kFuncHandle, f); return *this; } TVMRetValue& operator=(Module m) { this->SwitchToClass(kModuleHandle, m); return *this; } TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 this->Assign(other); return *this; } TVMRetValue& operator=(const TVMArgValue& other) { this->Assign(other); return *this; } template<typename T, typename = typename std::enable_if< extension_class_info<T>::code != 0>::type> TVMRetValue& operator=(const T& other) { this->SwitchToClass<T>( extension_class_info<T>::code, other); return *this; } /*! * \brief Move the value back to front-end via C API. * This marks the current container as null. * The managed resources is moved to front-end and * the front end should take charge in managing them. * * \param ret_value The return value. * \param ret_type_code The return type code. */ void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { // cannot move str; need specially handle. CHECK(type_code_ != kStr && type_code_ != kBytes); *ret_value = value_; *ret_type_code = type_code_; type_code_ = kNull; } /*! \return The value field, if the data is POD */ const TVMValue& value() const { CHECK(type_code_ != kNodeHandle && type_code_ != kFuncHandle && type_code_ != kModuleHandle && type_code_ != kStr) << "TVMRetValue.value can only be used for POD data"; return value_; } // NodeRef related extenstions: in tvm/packed_func_ext.h template<typename T, typename = typename std::enable_if< std::is_class<T>::value>::type> inline operator T() const; template<typename TNodeRef> inline TNodeRef AsNodeRef() const; inline TVMRetValue& operator=(const NodeRef& other); inline TVMRetValue& operator=(const std::shared_ptr<Node>& other); // type related inline operator HalideIR::Type() const; inline TVMRetValue& operator=(const HalideIR::Type& other); private: template<typename T> void Assign(const T& other) { switch (other.type_code()) { case kStr: { SwitchToClass<std::string>(kStr, other); break; } case kBytes: { SwitchToClass<std::string>(kBytes, other); break; } case kFuncHandle: { SwitchToClass<PackedFunc>(kFuncHandle, other); break; } case kModuleHandle: { SwitchToClass<Module>(kModuleHandle, other); break; } case kNodeHandle: { SwitchToClass<std::shared_ptr<Node> >( kNodeHandle, *other.template ptr<std::shared_ptr<Node> >()); break; } default: { if (other.type_code() < kExtBegin) { SwitchToPOD(other.type_code()); value_ = other.value_; } else { #if TVM_RUNTIME_HEADER_ONLY LOG(FATAL) << "Header only mode do not support ext type"; #else this->Clear(); type_code_ = other.type_code(); value_.v_handle = (*(ExtTypeVTable::Get(other.type_code())->clone))( other.value().v_handle); #endif } break; } } } // get the internal container. void SwitchToPOD(int type_code) { if (type_code_ != type_code) { this->Clear(); type_code_ = type_code; } } template<typename T> void SwitchToClass(int type_code, T v) { if (type_code_ != type_code) { this->Clear(); type_code_ = type_code; value_.v_handle = new T(v); } else { *static_cast<T*>(value_.v_handle) = v; } } void Clear() { if (type_code_ == kNull) return; switch (type_code_) { case kStr: delete ptr<std::string>(); break; case kFuncHandle: delete ptr<PackedFunc>(); break; case kModuleHandle: delete ptr<Module>(); break; case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; } if (type_code_ > kExtBegin) { #if TVM_RUNTIME_HEADER_ONLY LOG(FATAL) << "Header only mode do not support ext type"; #else (*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle); #endif } type_code_ = kNull; } }; // implementation details inline const char* TypeCode2Str(int type_code) { switch (type_code) { case kDLInt: return "int"; case kDLUInt: return "uint"; case kDLFloat: return "float"; case kStr: return "str"; case kBytes: return "bytes"; case kHandle: return "handle"; case kNull: return "NULL"; case kNodeHandle: return "NodeHandle"; case kArrayHandle: return "ArrayHandle"; case kTVMType: return "TVMType"; case kTVMContext: return "TVMContext"; case kFuncHandle: return "FunctionHandle"; case kModuleHandle: return "ModuleHandle"; default: LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code); return ""; } } #ifndef _LIBCPP_SGX_NO_IOSTREAMS inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) os << TypeCode2Str(t.code); if (t.code == kHandle) return os; os << static_cast<int>(t.bits); if (t.lanes != 1) { os << 'x' << static_cast<int>(t.lanes); } return os; } #endif inline std::string TVMType2String(TVMType t) { #ifndef _LIBCPP_SGX_NO_IOSTREAMS std::ostringstream os; os << t; return os.str(); #else std::string repr = ""; repr += TypeCode2Str(t.code); if (t.code == kHandle) return repr; repr += std::to_string(static_cast<int>(t.bits)); if (t.lanes != 1) { repr += "x" + std::to_string(static_cast<int>(t.lanes)); } return repr; #endif } inline TVMType String2TVMType(std::string s) { TVMType t; t.bits = 32; t.lanes = 1; const char* scan; if (s.substr(0, 3) == "int") { t.code = kDLInt; scan = s.c_str() + 3; } else if (s.substr(0, 4) == "uint") { t.code = kDLUInt; scan = s.c_str() + 4; } else if (s.substr(0, 5) == "float") { t.code = kDLFloat; scan = s.c_str() + 5; } else if (s.substr(0, 6) == "handle") { t.code = kHandle; t.bits = 64; // handle uses 64 bit by default. scan = s.c_str() + 6; } else { scan = s.c_str(); LOG(FATAL) << "unknown type " << s; } char* xdelim; // emulate sscanf("%ux%u", bits, lanes) uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10)); if (bits != 0) t.bits = bits; if (*xdelim == 'x') { t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10)); } return t; } inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" << " but request arg[" << i << "]."; return TVMArgValue(values[i], type_codes[i]); } inline int TVMArgs::size() const { return num_args; } inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } inline PackedFunc::FType PackedFunc::body() const { return body_; } // internal namespace namespace detail { template<bool stop, std::size_t I, typename F> struct for_each_dispatcher { template<typename T, typename ...Args> static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) f(I, std::forward<T>(value)); for_each_dispatcher<sizeof...(Args) == 0, (I+1), F> ::run(f, std::forward<Args>(args)...); } }; template<std::size_t I, typename F> struct for_each_dispatcher<true, I, F> { static void run(const F& f) {} // NOLINT(*) }; template<typename F, typename ...Args> inline void for_each(const F& f, Args&&... args) { // NOLINT(*) for_each_dispatcher<sizeof...(Args) == 0, 0, F> ::run(f, std::forward<Args>(args)...); } } // namespace detail /* \brief argument settter to PackedFunc */ class TVMArgsSetter { public: TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {} // setters for POD types template<typename T, typename = typename std::enable_if< std::is_integral<T>::value>::type> void operator()(size_t i, T value) const { values_[i].v_int64 = static_cast<int64_t>(value); type_codes_[i] = kDLInt; } void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast<int64_t>(value); CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max())); type_codes_[i] = kDLInt; } void operator()(size_t i, double value) const { values_[i].v_float64 = value; type_codes_[i] = kDLFloat; } void operator()(size_t i, std::nullptr_t value) const { values_[i].v_handle = value; type_codes_[i] = kNull; } void operator()(size_t i, const TVMArgValue& value) const { values_[i] = value.value_; type_codes_[i] = value.type_code_; } void operator()(size_t i, void* value) const { values_[i].v_handle = value; type_codes_[i] = kHandle; } void operator()(size_t i, TVMArray* value) const { values_[i].v_handle = value; type_codes_[i] = kArrayHandle; } void operator()(size_t i, TVMContext value) const { values_[i].v_ctx = value; type_codes_[i] = kTVMContext; } void operator()(size_t i, TVMType value) const { values_[i].v_type = value; type_codes_[i] = kTVMType; } void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kStr; } // setters for container type // They must be reference(instead of const ref) // to make sure they are alive in the tuple(instead of getting converted) void operator()(size_t i, const std::string& value) const { // NOLINT(*) values_[i].v_str = value.c_str(); type_codes_[i] = kStr; } void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*) values_[i].v_handle = const_cast<TVMByteArray*>(&value); type_codes_[i] = kBytes; } void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*) values_[i].v_handle = const_cast<PackedFunc*>(&value); type_codes_[i] = kFuncHandle; } void operator()(size_t i, const Module& value) const { // NOLINT(*) values_[i].v_handle = const_cast<Module*>(&value); type_codes_[i] = kModuleHandle; } void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) if (value.type_code() == kStr) { values_[i].v_str = value.ptr<std::string>()->c_str(); type_codes_[i] = kStr; } else { CHECK_NE(value.type_code(), kBytes) << "not handled."; values_[i] = value.value_; type_codes_[i] = value.type_code(); } } // extension template<typename T, typename = typename std::enable_if< extension_class_info<T>::code != 0>::type> inline void operator()(size_t i, const T& value) const; // NodeRef related extenstions: in tvm/packed_func_ext.h inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const HalideIR::Type& t) const; private: /*! \brief The values fields */ TVMValue* values_; /*! \brief The type code fields */ int* type_codes_; }; template<typename... Args> inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...); TVMRetValue rv; body_(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } // extension and node type handling namespace detail { template<typename T, typename TSrc, bool is_ext> struct TVMValueCast { static T Apply(const TSrc* self) { return self->template AsNodeRef<T>(); } }; template<typename T, typename TSrc> struct TVMValueCast<T, TSrc, true> { static T Apply(const TSrc* self) { return self->template AsExtension<T>(); } }; } // namespace detail template<typename T, typename> inline TVMArgValue::operator T() const { return detail:: TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0> ::Apply(this); } template<typename T, typename> inline TVMRetValue::operator T() const { return detail:: TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0> ::Apply(this); } template<typename T, typename> inline void TVMArgsSetter::operator()(size_t i, const T& value) const { static_assert(extension_class_info<T>::code != 0, "Need to have extesion code"); type_codes_[i] = extension_class_info<T>::code; values_[i].v_handle = const_cast<T*>(&value); } // extension type handling template<typename T> struct ExtTypeInfo { static void destroy(void* handle) { delete static_cast<T*>(handle); } static void* clone(void* handle) { return new T(*static_cast<T*>(handle)); } }; template<typename T> inline ExtTypeVTable* ExtTypeVTable::Register_() { const int code = extension_class_info<T>::code; static_assert(code != 0, "require extension_class_info traits to be declared with non-zero code"); ExtTypeVTable vt; vt.clone = ExtTypeInfo<T>::clone; vt.destroy = ExtTypeInfo<T>::destroy; return ExtTypeVTable::RegisterInternal(code, vt); } // Implement Module::GetFunction // Put implementation in this file so we have seen the PackedFunc inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { PackedFunc pf = node_->GetFunction(name, node_); if (pf != nullptr) return pf; if (query_imports) { for (const Module& m : node_->imports_) { pf = m.node_->GetFunction(name, m.node_); if (pf != nullptr) return pf; } } return pf; } } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_PACKED_FUNC_H_