Unverified Commit f823c577 by Tianqi Chen Committed by GitHub

[RUNTIME][REFACTOR] Use object protocol to support runtime::Module (#4289)

Previously runtime::Module was supported using shared_ptr.
This PR refactors the codebase to use the Object protocol.

It will open doors to allow easier interpolation between
Object containers and module in the future.
parent aeb5f130
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file tvm_runtime.h * \file tvm_runtime.h
* \brief Pack all tvm runtime source files * \brief Pack all tvm runtime source files
*/ */
...@@ -35,6 +34,7 @@ ...@@ -35,6 +34,7 @@
#include "../src/runtime/file_util.cc" #include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc" #include "../src/runtime/dso_module.cc"
#include "../src/runtime/thread_pool.cc" #include "../src/runtime/thread_pool.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc" #include "../src/runtime/threading_backend.cc"
#include "../src/runtime/ndarray.cc" #include "../src/runtime/ndarray.cc"
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -55,6 +55,7 @@ ...@@ -55,6 +55,7 @@
#include "../src/runtime/threading_backend.cc" #include "../src/runtime/threading_backend.cc"
#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/ndarray.cc" #include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"
#ifdef TVM_OPENCL_RUNTIME #ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc" #include "../src/runtime/opencl/opencl_device_api.cc"
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -32,5 +32,6 @@ ...@@ -32,5 +32,6 @@
#include "../../src/runtime/threading_backend.cc" #include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc" #include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
#include "../../src/runtime/system_lib_module.cc" #include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/graph/graph_runtime.cc" #include "../../src/runtime/graph/graph_runtime.cc"
...@@ -47,6 +47,7 @@ ...@@ -47,6 +47,7 @@
#include "../../src/runtime/threading_backend.cc" #include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc" #include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
// NOTE: all the files after this are optional modules // NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use. // that you can include remove, depending on how much feature you use.
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file TVMRuntime.mm * \file TVMRuntime.mm
*/ */
#include "TVMRuntime.h" #include "TVMRuntime.h"
...@@ -35,6 +34,8 @@ ...@@ -35,6 +34,8 @@
#include "../../../src/runtime/file_util.cc" #include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/dso_module.cc" #include "../../../src/runtime/dso_module.cc"
#include "../../../src/runtime/ndarray.cc" #include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc"
// RPC server // RPC server
#include "../../../src/runtime/rpc/rpc_session.cc" #include "../../../src/runtime/rpc/rpc_session.cc"
#include "../../../src/runtime/rpc/rpc_server_env.cc" #include "../../../src/runtime/rpc/rpc_server_env.cc"
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \brief This is an all in one TVM runtime file. * \brief This is an all in one TVM runtime file.
* \file tvm_runtime_pack.cc * \file tvm_runtime_pack.cc
*/ */
...@@ -32,6 +31,7 @@ ...@@ -32,6 +31,7 @@
#include "src/runtime/threading_backend.cc" #include "src/runtime/threading_backend.cc"
#include "src/runtime/thread_pool.cc" #include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc" #include "src/runtime/ndarray.cc"
#include "src/runtime/object.cc"
// NOTE: all the files after this are optional modules // NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use. // that you can include remove, depending on how much feature you use.
......
...@@ -27,28 +27,31 @@ ...@@ -27,28 +27,31 @@
#define TVM_RUNTIME_MODULE_H_ #define TVM_RUNTIME_MODULE_H_
#include <dmlc/io.h> #include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "c_runtime_api.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
// The internal container of module.
class ModuleNode; class ModuleNode;
class PackedFunc; class PackedFunc;
/*! /*!
* \brief Module container of TVM. * \brief Module container of TVM.
*/ */
class Module { class Module : public ObjectRef {
public: public:
Module() {} Module() {}
// constructor from container. // constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n) explicit Module(ObjectPtr<Object> n)
: node_(n) {} : ObjectRef(n) {}
/*! /*!
* \brief Get packed function from current module by name. * \brief Get packed function from current module by name.
* *
...@@ -59,10 +62,6 @@ class Module { ...@@ -59,10 +62,6 @@ class Module {
* \note Implemented in packed_func.cc * \note Implemented in packed_func.cc
*/ */
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false); inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
// The following functions requires link with runtime. // The following functions requires link with runtime.
/*! /*!
* \brief Import another module into this module. * \brief Import another module into this module.
...@@ -71,7 +70,11 @@ class Module { ...@@ -71,7 +70,11 @@ class Module {
* \note Cyclic dependency is not allowed among modules, * \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected. * An error will be thrown when cyclic dependency is detected.
*/ */
TVM_DLL void Import(Module other); inline void Import(Module other);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
/*! /*!
* \brief Load a module from file. * \brief Load a module from file.
* \param file_name The name of the host function module. * \param file_name The name of the host function module.
...@@ -81,20 +84,41 @@ class Module { ...@@ -81,20 +84,41 @@ class Module {
*/ */
TVM_DLL static Module LoadFromFile(const std::string& file_name, TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = ""); const std::string& format = "");
// refer to the corresponding container.
private: using ContainerType = ModuleNode;
std::shared_ptr<ModuleNode> node_; friend class ModuleNode;
}; };
/*! /*!
* \brief Base node container of module. * \brief Base container of module.
* Do not create this directly, instead use Module. *
* Please subclass ModuleNode to create a specific runtime module.
*
* \code
*
* class MyModuleNode : public ModuleNode {
* public:
* // implement the interface
* };
*
* // use make_object to create a specific
* // instace of MyModuleNode.
* Module CreateMyModule() {
* ObjectPtr<MyModuleNode> n =
* tvm::runtime::make_object<MyModuleNode>();
* return Module(n);
* }
*
* \endcode
*/ */
class ModuleNode { class ModuleNode : public Object {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~ModuleNode() {} virtual ~ModuleNode() {}
/*! \return The module type key */ /*!
* \return The per module type key.
* \note This key is used to for serializing custom modules.
*/
virtual const char* type_key() const = 0; virtual const char* type_key() const = 0;
/*! /*!
* \brief Get a PackedFunc from module. * \brief Get a PackedFunc from module.
...@@ -105,7 +129,7 @@ class ModuleNode { ...@@ -105,7 +129,7 @@ class ModuleNode {
* For benchmarking, use prepare to eliminate * For benchmarking, use prepare to eliminate
* *
* \param name the name of the function. * \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node. * \param sptr_to_self The ObjectPtr that points to this module node.
* *
* \return PackedFunc(nullptr) when it is not available. * \return PackedFunc(nullptr) when it is not available.
* *
...@@ -115,7 +139,7 @@ class ModuleNode { ...@@ -115,7 +139,7 @@ class ModuleNode {
*/ */
virtual PackedFunc GetFunction( virtual PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0; const ObjectPtr<Object>& sptr_to_self) = 0;
/*! /*!
* \brief Save the module to file. * \brief Save the module to file.
* \param file_name The file to be saved to. * \param file_name The file to be saved to.
...@@ -138,6 +162,24 @@ class ModuleNode { ...@@ -138,6 +162,24 @@ class ModuleNode {
*/ */
TVM_DLL virtual std::string GetSource(const std::string& format = ""); TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*! /*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param query_imports Whether also query dependency modules.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Get a function from current environment * \brief Get a function from current environment
* The environment includes all the imports as well as Global functions. * The environment includes all the imports as well as Global functions.
* *
...@@ -150,6 +192,13 @@ class ModuleNode { ...@@ -150,6 +192,13 @@ class ModuleNode {
return imports_; return imports_;
} }
// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
// NOTE: ModuleNode can still be sub-classed
//
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
protected: protected:
friend class Module; friend class Module;
/*! \brief The modules this module depend on */ /*! \brief The modules this module depend on */
...@@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__"; ...@@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol } // namespace symbol
// implementations of inline functions. // implementations of inline functions.
inline void Module::Import(Module other) {
return (*this)->Import(other);
}
inline ModuleNode* Module::operator->() { inline ModuleNode* Module::operator->() {
return node_.get(); return static_cast<ModuleNode*>(get_mutable());
} }
inline const ModuleNode* Module::operator->() const { inline const ModuleNode* Module::operator->() const {
return node_.get(); return static_cast<const ModuleNode*>(get());
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#include "packed_func.h" #include <tvm/runtime/packed_func.h> // NOLINT(*)
#endif // TVM_RUNTIME_MODULE_H_ #endif // TVM_RUNTIME_MODULE_H_
...@@ -53,6 +53,7 @@ enum TypeIndex { ...@@ -53,6 +53,7 @@ enum TypeIndex {
kVMTensor = 1, kVMTensor = 1,
kVMClosure = 2, kVMClosure = 2,
kVMADT = 3, kVMADT = 3,
kRuntimeModule = 4,
kStaticIndexEnd, kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */ /*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd kDynamic = kStaticIndexEnd
...@@ -302,7 +303,7 @@ class Object { ...@@ -302,7 +303,7 @@ class Object {
template<typename> template<typename>
friend class ObjectPtr; friend class ObjectPtr;
friend class TVMRetValue; friend class TVMRetValue;
friend class TVMObjectCAPI; friend class ObjectInternal;
}; };
/*! /*!
...@@ -310,11 +311,11 @@ class Object { ...@@ -310,11 +311,11 @@ class Object {
* *
* It is always important to get a reference type * It is always important to get a reference type
* if we want to return a value as reference or keep * if we want to return a value as reference or keep
* the node alive beyond the scope of the function. * the object alive beyond the scope of the function.
* *
* \param ptr The node pointer * \param ptr The object pointer
* \tparam RefType The reference type * \tparam RefType The reference type
* \tparam ObjectType The node type * \tparam ObjectType The object type
* \return The corresponding RefType * \return The corresponding RefType
*/ */
template <typename RefType, typename ObjectType> template <typename RefType, typename ObjectType>
...@@ -486,6 +487,8 @@ class ObjectPtr { ...@@ -486,6 +487,8 @@ class ObjectPtr {
friend class TVMArgValue; friend class TVMArgValue;
template <typename RefType, typename ObjType> template <typename RefType, typename ObjType>
friend RefType GetRef(const ObjType* ptr); friend RefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
}; };
/*! \brief Base class of all object reference */ /*! \brief Base class of all object reference */
...@@ -513,7 +516,7 @@ class ObjectRef { ...@@ -513,7 +516,7 @@ class ObjectRef {
} }
/*! /*!
* \brief Comparator * \brief Comparator
* \param other Another node ref. * \param other Another object ref.
* \return the compare result. * \return the compare result.
*/ */
bool operator!=(const ObjectRef& other) const { bool operator!=(const ObjectRef& other) const {
...@@ -535,7 +538,7 @@ class ObjectRef { ...@@ -535,7 +538,7 @@ class ObjectRef {
const Object* get() const { const Object* get() const {
return data_.get(); return data_.get();
} }
/*! \return the internal node pointer */ /*! \return the internal object pointer */
const Object* operator->() const { const Object* operator->() const {
return get(); return get();
} }
...@@ -595,6 +598,16 @@ class ObjectRef { ...@@ -595,6 +598,16 @@ class ObjectRef {
friend SubRef Downcast(BaseRef ref); friend SubRef Downcast(BaseRef ref);
}; };
/*!
* \brief Get an object ptr type from a raw object ptr.
*
* \param ptr The object pointer
* \tparam BaseType The reference type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
/*! \brief ObjectRef hash functor */ /*! \brief ObjectRef hash functor */
struct ObjectHash { struct ObjectHash {
...@@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) { ...@@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) {
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr)))); return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
} }
template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
static_assert(std::is_base_of<BaseType, ObjType>::value,
"Can only cast to the ref of same container type");
return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
}
template <typename SubRef, typename BaseRef> template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) { inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>()) CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
......
...@@ -496,6 +496,14 @@ class TVMPODValue_ { ...@@ -496,6 +496,14 @@ class TVMPODValue_ {
return ObjectRef( return ObjectRef(
ObjectPtr<Object>(static_cast<Object*>(value_.v_handle))); ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} }
operator Module() const {
if (type_code_ == kNull) {
return Module(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return Module(
ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const { operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx; return value_.v_ctx;
...@@ -574,6 +582,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -574,6 +582,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator ObjectRef; using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef; using TVMPODValue_::IsObjectRef;
// conversion operator. // conversion operator.
...@@ -610,10 +619,6 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -610,10 +619,6 @@ class TVMArgValue : public TVMPODValue_ {
operator TypedPackedFunc<FType>() const { operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc()); return TypedPackedFunc<FType>(operator PackedFunc());
} }
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
const TVMValue& value() const { const TVMValue& value() const {
return value_; return value_;
} }
...@@ -665,6 +670,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -665,6 +670,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator ObjectRef; using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef; using TVMPODValue_::IsObjectRef;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
...@@ -696,10 +702,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -696,10 +702,6 @@ class TVMRetValue : public TVMPODValue_ {
operator TypedPackedFunc<FType>() const { operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc()); return TypedPackedFunc<FType>(operator PackedFunc());
} }
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
// Assign operators // Assign operators
TVMRetValue& operator=(TVMRetValue&& other) { TVMRetValue& operator=(TVMRetValue&& other) {
this->Clear(); this->Clear();
...@@ -766,17 +768,13 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -766,17 +768,13 @@ class TVMRetValue : public TVMPODValue_ {
TVMRetValue& operator=(ObjectRef other) { TVMRetValue& operator=(ObjectRef other) {
return operator=(std::move(other.data_)); return operator=(std::move(other.data_));
} }
TVMRetValue& operator=(Module m) {
SwitchToObject(kModuleHandle, std::move(m.data_));
return *this;
}
template<typename T> template<typename T>
TVMRetValue& operator=(ObjectPtr<T> other) { TVMRetValue& operator=(ObjectPtr<T> other) {
if (other.data_ != nullptr) { SwitchToObject(kObjectHandle, std::move(other));
this->Clear();
type_code_ = kObjectHandle;
// move the handle out
value_.v_handle = other.data_;
other.data_ = nullptr;
} else {
SwitchToPOD(kNull);
}
return *this; return *this;
} }
TVMRetValue& operator=(PackedFunc f) { TVMRetValue& operator=(PackedFunc f) {
...@@ -787,10 +785,6 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -787,10 +785,6 @@ class TVMRetValue : public TVMPODValue_ {
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) { TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed()); return operator=(f.packed());
} }
TVMRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m);
return *this;
}
TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0 TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
this->Assign(other); this->Assign(other);
return *this; return *this;
...@@ -860,7 +854,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -860,7 +854,7 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
case kModuleHandle: { case kModuleHandle: {
SwitchToClass<Module>(kModuleHandle, other); *this = other.operator Module();
break; break;
} }
case kNDArrayContainer: { case kNDArrayContainer: {
...@@ -907,16 +901,30 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -907,16 +901,30 @@ class TVMRetValue : public TVMPODValue_ {
*static_cast<T*>(value_.v_handle) = v; *static_cast<T*>(value_.v_handle) = v;
} }
} }
void SwitchToObject(int type_code, ObjectPtr<Object> other) {
if (other.data_ != nullptr) {
this->Clear();
type_code_ = type_code;
// move the handle out
value_.v_handle = other.data_;
other.data_ = nullptr;
} else {
SwitchToPOD(kNull);
}
}
void Clear() { void Clear() {
if (type_code_ == kNull) return; if (type_code_ == kNull) return;
switch (type_code_) { switch (type_code_) {
case kStr: delete ptr<std::string>(); break; case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break; case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kNDArrayContainer: { case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break; break;
} }
case kModuleHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
case kObjectHandle: { case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef(); static_cast<Object*>(value_.v_handle)->DecRef();
break; break;
...@@ -1156,8 +1164,12 @@ class TVMArgsSetter { ...@@ -1156,8 +1164,12 @@ class TVMArgsSetter {
operator()(i, value.packed()); operator()(i, value.packed());
} }
void operator()(size_t i, const Module& value) const { // NOLINT(*) void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<Module*>(&value); if (value.defined()) {
type_codes_[i] = kModuleHandle; values_[i].v_handle = value.data_.data_;
type_codes_[i] = kModuleHandle;
} else {
type_codes_[i] = kNull;
}
} }
void operator()(size_t i, const NDArray& value) const { // NOLINT(*) void operator()(size_t i, const NDArray& value) const { // NOLINT(*)
values_[i].v_handle = value.data_; values_[i].v_handle = value.data_;
...@@ -1372,19 +1384,10 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() { ...@@ -1372,19 +1384,10 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
return ExtTypeVTable::RegisterInternal(code, vt); return ExtTypeVTable::RegisterInternal(code, vt);
} }
// Implement Module::GetFunction
// Put implementation in this file so we have seen the PackedFunc
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) { inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_); return (*this)->GetFunction(name, query_imports);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
} }
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_ #endif // TVM_RUNTIME_PACKED_FUNC_H_
...@@ -480,7 +480,7 @@ class Executable : public ModuleNode { ...@@ -480,7 +480,7 @@ class Executable : public ModuleNode {
* \return PackedFunc or nullptr when it is not available. * \return PackedFunc or nullptr when it is not available.
*/ */
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
/*! /*!
* \brief Serialize the executable into global section, constant section, and * \brief Serialize the executable into global section, constant section, and
...@@ -658,7 +658,7 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -658,7 +658,7 @@ class VirtualMachine : public runtime::ModuleNode {
* it should capture sptr_to_self. * it should capture sptr_to_self.
*/ */
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self); const ObjectPtr<Object>& sptr_to_self);
/*! /*!
* \brief Invoke a PackedFunction * \brief Invoke a PackedFunction
......
...@@ -148,7 +148,7 @@ class Executable(object): ...@@ -148,7 +148,7 @@ class Executable(object):
raise TypeError("bytecode is expected to be the type of bytearray " + raise TypeError("bytecode is expected to be the type of bytearray " +
"or TVMByteArray, but received {}".format(type(code))) "or TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, tvm.module.Module): if lib is not None and not isinstance(lib, tvm.module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" + raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib))) ", but received {}".format(type(lib)))
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file llvm_module.cc * \file llvm_module.cc
* \brief LLVM runtime module for TVM * \brief LLVM runtime module for TVM
*/ */
...@@ -54,7 +53,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -54,7 +53,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
if (name == "__tvm_is_system_module") { if (name == "__tvm_is_system_module") {
bool flag = bool flag =
(mptr_->getFunction("__tvm_module_startup") != nullptr); (mptr_->getFunction("__tvm_module_startup") != nullptr);
...@@ -325,7 +324,7 @@ TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id") ...@@ -325,7 +324,7 @@ TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
TVM_REGISTER_API("codegen.build_llvm") TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>(); auto n = make_object<LLVMModuleNode>();
n->Init(args[0], args[1]); n->Init(args[0], args[1]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
...@@ -339,7 +338,7 @@ TVM_REGISTER_API("codegen.llvm_version_major") ...@@ -339,7 +338,7 @@ TVM_REGISTER_API("codegen.llvm_version_major")
TVM_REGISTER_API("module.loadfile_ll") TVM_REGISTER_API("module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>(); auto n = make_object<LLVMModuleNode>();
n->LoadIR(args[0]); n->LoadIR(args[0]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file source_module.cc * \file source_module.cc
* \brief Source code module, only for viewing * \brief Source code module, only for viewing
*/ */
...@@ -51,7 +50,7 @@ class SourceModuleNode : public runtime::ModuleNode { ...@@ -51,7 +50,7 @@ class SourceModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module" LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support"; << " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc(); return PackedFunc();
...@@ -67,8 +66,7 @@ class SourceModuleNode : public runtime::ModuleNode { ...@@ -67,8 +66,7 @@ class SourceModuleNode : public runtime::ModuleNode {
}; };
runtime::Module SourceModuleCreate(std::string code, std::string fmt) { runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
std::shared_ptr<SourceModuleNode> n = auto n = make_object<SourceModuleNode>(code, fmt);
std::make_shared<SourceModuleNode>(code, fmt);
return runtime::Module(n); return runtime::Module(n);
} }
...@@ -84,7 +82,7 @@ class CSourceModuleNode : public runtime::ModuleNode { ...@@ -84,7 +82,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "C Source module cannot execute, to get executable module" LOG(FATAL) << "C Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support"; << " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc(); return PackedFunc();
...@@ -113,8 +111,7 @@ class CSourceModuleNode : public runtime::ModuleNode { ...@@ -113,8 +111,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
}; };
runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
std::shared_ptr<CSourceModuleNode> n = auto n = make_object<CSourceModuleNode>(code, fmt);
std::make_shared<CSourceModuleNode>(code, fmt);
return runtime::Module(n); return runtime::Module(n);
} }
...@@ -134,7 +131,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { ...@@ -134,7 +131,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module" LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support"; << " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc(); return PackedFunc();
...@@ -182,8 +179,7 @@ runtime::Module DeviceSourceModuleCreate( ...@@ -182,8 +179,7 @@ runtime::Module DeviceSourceModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap, std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key, std::string type_key,
std::function<std::string(const std::string&)> fget_source) { std::function<std::string(const std::string&)> fget_source) {
std::shared_ptr<DeviceSourceModuleNode> n = auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
return runtime::Module(n); return runtime::Module(n);
} }
......
...@@ -115,7 +115,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -115,7 +115,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* \return The corresponding member function. * \return The corresponding member function.
*/ */
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_graph_json") { if (name == "get_graph_json") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetGraphJSON(); *rv = this->GetGraphJSON();
...@@ -489,7 +489,7 @@ class RelayBuildModule : public runtime::ModuleNode { ...@@ -489,7 +489,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}; };
runtime::Module RelayBuildCreate() { runtime::Module RelayBuildCreate() {
std::shared_ptr<RelayBuildModule> exec = std::make_shared<RelayBuildModule>(); auto exec = make_object<RelayBuildModule>();
return runtime::Module(exec); return runtime::Module(exec);
} }
......
...@@ -593,7 +593,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { ...@@ -593,7 +593,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
public: public:
GraphRuntimeCodegenModule() {} GraphRuntimeCodegenModule() {}
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
if (name == "init") { if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2) CHECK_EQ(args.num_args, 2)
...@@ -654,8 +654,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { ...@@ -654,8 +654,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
}; };
runtime::Module CreateGraphCodegenMod() { runtime::Module CreateGraphCodegenMod() {
std::shared_ptr<GraphRuntimeCodegenModule> ptr = auto ptr = make_object<GraphRuntimeCodegenModule>();
std::make_shared<GraphRuntimeCodegenModule>();
return runtime::Module(ptr); return runtime::Module(ptr);
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/compiler.cc * \file src/relay/backend/vm/compiler.cc
* \brief A compiler from relay::Module to the VM byte code. * \brief A compiler from relay::Module to the VM byte code.
*/ */
...@@ -745,7 +744,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -745,7 +744,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
PackedFunc VMCompiler::GetFunction(const std::string& name, PackedFunc VMCompiler::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
if (name == "compile") { if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3); CHECK_EQ(args.num_args, 3);
...@@ -974,7 +973,7 @@ void VMCompiler::LibraryCodegen() { ...@@ -974,7 +973,7 @@ void VMCompiler::LibraryCodegen() {
} }
runtime::Module CreateVMCompiler() { runtime::Module CreateVMCompiler() {
std::shared_ptr<VMCompiler> exec = std::make_shared<VMCompiler>(); auto exec = make_object<VMCompiler>();
return runtime::Module(exec); return runtime::Module(exec);
} }
......
...@@ -86,14 +86,14 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -86,14 +86,14 @@ class VMCompiler : public runtime::ModuleNode {
virtual ~VMCompiler() {} virtual ~VMCompiler() {}
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self); const ObjectPtr<Object>& sptr_to_self);
const char* type_key() const { const char* type_key() const {
return "VMCompiler"; return "VMCompiler";
} }
void InitVM() { void InitVM() {
exec_ = std::make_shared<Executable>(); exec_ = make_object<Executable>();
} }
/*! /*!
...@@ -141,7 +141,7 @@ class VMCompiler : public runtime::ModuleNode { ...@@ -141,7 +141,7 @@ class VMCompiler : public runtime::ModuleNode {
/*! \brief Global shared meta data */ /*! \brief Global shared meta data */
VMCompilerContext context_; VMCompilerContext context_;
/*! \brief Compiled executable. */ /*! \brief Compiled executable. */
std::shared_ptr<Executable> exec_; ObjectPtr<Executable> exec_;
/*! \brief parameters */ /*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_; std::unordered_map<std::string, runtime::NDArray> params_;
}; };
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/profiler/compiler.cc * \file src/relay/backend/vm/profiler/compiler.cc
* \brief A compiler from relay::Module to the VM byte code. * \brief A compiler from relay::Module to the VM byte code.
*/ */
...@@ -37,7 +36,7 @@ class VMCompilerDebug : public VMCompiler { ...@@ -37,7 +36,7 @@ class VMCompilerDebug : public VMCompiler {
}; };
runtime::Module CreateVMCompilerDebug() { runtime::Module CreateVMCompilerDebug() {
std::shared_ptr<VMCompilerDebug> exec = std::make_shared<VMCompilerDebug>(); auto exec = make_object<VMCompilerDebug>();
return runtime::Module(exec); return runtime::Module(exec);
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc * \file c_runtime_api.cc
* \brief Device specific implementations * \brief Device specific implementations
*/ */
...@@ -41,6 +40,7 @@ ...@@ -41,6 +40,7 @@
#include <cstdlib> #include <cstdlib>
#include <cctype> #include <cctype>
#include "runtime_base.h" #include "runtime_base.h"
#include "object_internal.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -370,16 +370,20 @@ int TVMModLoadFromFile(const char* file_name, ...@@ -370,16 +370,20 @@ int TVMModLoadFromFile(const char* file_name,
const char* format, const char* format,
TVMModuleHandle* out) { TVMModuleHandle* out) {
API_BEGIN(); API_BEGIN();
Module m = Module::LoadFromFile(file_name, format); TVMRetValue ret;
*out = new Module(m); ret = Module::LoadFromFile(file_name, format);
TVMValue val;
int type_code;
ret.MoveToCHost(&val, &type_code);
*out = val.v_handle;
API_END(); API_END();
} }
int TVMModImport(TVMModuleHandle mod, int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep) { TVMModuleHandle dep) {
API_BEGIN(); API_BEGIN();
static_cast<Module*>(mod)->Import( ObjectInternal::GetModuleNode(mod)->Import(
*static_cast<Module*>(dep)); GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
API_END(); API_END();
} }
...@@ -388,7 +392,7 @@ int TVMModGetFunction(TVMModuleHandle mod, ...@@ -388,7 +392,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
int query_imports, int query_imports,
TVMFunctionHandle *func) { TVMFunctionHandle *func) {
API_BEGIN(); API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction( PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(
func_name, query_imports != 0); func_name, query_imports != 0);
if (pf != nullptr) { if (pf != nullptr) {
*func = new PackedFunc(pf); *func = new PackedFunc(pf);
...@@ -399,9 +403,7 @@ int TVMModGetFunction(TVMModuleHandle mod, ...@@ -399,9 +403,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
} }
int TVMModFree(TVMModuleHandle mod) { int TVMModFree(TVMModuleHandle mod) {
API_BEGIN(); return TVMObjectFree(mod);
delete static_cast<Module*>(mod);
API_END();
} }
int TVMBackendGetFuncFromEnv(void* mod_node, int TVMBackendGetFuncFromEnv(void* mod_node,
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -69,7 +69,7 @@ class CUDAModuleNode : public runtime::ModuleNode { ...@@ -69,7 +69,7 @@ class CUDAModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
...@@ -166,7 +166,7 @@ class CUDAWrappedFunc { ...@@ -166,7 +166,7 @@ class CUDAWrappedFunc {
public: public:
// initialize the CUDA function. // initialize the CUDA function.
void Init(CUDAModuleNode* m, void Init(CUDAModuleNode* m,
std::shared_ptr<ModuleNode> sptr, ObjectPtr<Object> sptr,
const std::string& func_name, const std::string& func_name,
size_t num_void_args, size_t num_void_args,
const std::vector<std::string>& thread_axis_tags) { const std::vector<std::string>& thread_axis_tags) {
...@@ -220,7 +220,7 @@ class CUDAWrappedFunc { ...@@ -220,7 +220,7 @@ class CUDAWrappedFunc {
// internal module // internal module
CUDAModuleNode* m_; CUDAModuleNode* m_;
// the resource holder // the resource holder
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// The name of the function. // The name of the function.
std::string func_name_; std::string func_name_;
// Device function cache per device. // Device function cache per device.
...@@ -233,7 +233,7 @@ class CUDAWrappedFunc { ...@@ -233,7 +233,7 @@ class CUDAWrappedFunc {
class CUDAPrepGlobalBarrier { class CUDAPrepGlobalBarrier {
public: public:
CUDAPrepGlobalBarrier(CUDAModuleNode* m, CUDAPrepGlobalBarrier(CUDAModuleNode* m,
std::shared_ptr<ModuleNode> sptr) ObjectPtr<Object> sptr)
: m_(m), sptr_(sptr) { : m_(m), sptr_(sptr) {
std::fill(pcache_.begin(), pcache_.end(), 0); std::fill(pcache_.begin(), pcache_.end(), 0);
} }
...@@ -252,14 +252,14 @@ class CUDAPrepGlobalBarrier { ...@@ -252,14 +252,14 @@ class CUDAPrepGlobalBarrier {
// internal module // internal module
CUDAModuleNode* m_; CUDAModuleNode* m_;
// the resource holder // the resource holder
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// mark as mutable, to enable lazy initialization // mark as mutable, to enable lazy initialization
mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_; mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
}; };
PackedFunc CUDAModuleNode::GetFunction( PackedFunc CUDAModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this); CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main"; << "Device function do not have main";
...@@ -279,8 +279,7 @@ Module CUDAModuleCreate( ...@@ -279,8 +279,7 @@ Module CUDAModuleCreate(
std::string fmt, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source) { std::string cuda_source) {
std::shared_ptr<CUDAModuleNode> n = auto n = make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source);
std::make_shared<CUDAModuleNode>(data, fmt, fmap, cuda_source);
return Module(n); return Module(n);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file cuda_module.h * \file cuda_module.h
* \brief Execution handling of CUDA kernels * \brief Execution handling of CUDA kernels
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file dso_dll_module.cc * \file dso_dll_module.cc
* \brief Module to load from dynamic shared library. * \brief Module to load from dynamic shared library.
*/ */
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include "module_util.h" #include "module_util.h"
...@@ -50,7 +50,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -50,7 +50,7 @@ class DSOModuleNode final : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr; BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) { if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>( const char* entry_name = reinterpret_cast<const char*>(
...@@ -124,7 +124,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -124,7 +124,7 @@ class DSOModuleNode final : public ModuleNode {
TVM_REGISTER_GLOBAL("module.loadfile_so") TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>(); auto n = make_object<DSOModuleNode>();
n->Init(args[0]); n->Init(args[0]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file graph_runtime_debug.cc * \file graph_runtime_debug.cc
*/ */
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
...@@ -28,6 +27,7 @@ ...@@ -28,6 +27,7 @@
#include <chrono> #include <chrono>
#include <sstream> #include <sstream>
#include "../graph_runtime.h" #include "../graph_runtime.h"
#include "../../object_internal.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -121,7 +121,7 @@ class GraphRuntimeDebug : public GraphRuntime { ...@@ -121,7 +121,7 @@ class GraphRuntimeDebug : public GraphRuntime {
* \param sptr_to_self Packed function pointer. * \param sptr_to_self Packed function pointer.
*/ */
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self); const ObjectPtr<Object>& sptr_to_self);
/*! /*!
* \brief Get the node index given the name of node. * \brief Get the node index given the name of node.
...@@ -169,7 +169,7 @@ void DebugGetNodeOutput(int index, DLTensor* data_out) { ...@@ -169,7 +169,7 @@ void DebugGetNodeOutput(int index, DLTensor* data_out) {
*/ */
PackedFunc GraphRuntimeDebug::GetFunction( PackedFunc GraphRuntimeDebug::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
// return member functions during query. // return member functions during query.
if (name == "get_output_by_layer") { if (name == "get_output_by_layer") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
...@@ -207,7 +207,7 @@ PackedFunc GraphRuntimeDebug::GetFunction( ...@@ -207,7 +207,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(
Module GraphRuntimeDebugCreate(const std::string& sym_json, Module GraphRuntimeDebugCreate(const std::string& sym_json,
const tvm::runtime::Module& m, const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) { const std::vector<TVMContext>& ctxs) {
std::shared_ptr<GraphRuntimeDebug> exec = std::make_shared<GraphRuntimeDebug>(); auto exec = make_object<GraphRuntimeDebug>();
exec->Init(sym_json, m, ctxs); exec->Init(sym_json, m, ctxs);
return Module(exec); return Module(exec);
} }
...@@ -222,15 +222,16 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") ...@@ -222,15 +222,16 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
}); });
TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create") TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for " CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is " "graph_runtime.remote_create is "
"at least 4, but it has " "at least 4, but it has "
<< args.num_args; << args.num_args;
void* mhandle = args[1]; void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args); const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeDebugCreate( *rv = GraphRuntimeDebugCreate(
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts); args[0], GetRef<Module>(mnode), contexts);
}); });
} // namespace runtime } // namespace runtime
......
...@@ -18,11 +18,8 @@ ...@@ -18,11 +18,8 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc * \file graph_runtime.cc
*/ */
#include "graph_runtime.h"
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
...@@ -38,6 +35,9 @@ ...@@ -38,6 +35,9 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "graph_runtime.h"
#include "../object_internal.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
namespace details { namespace details {
...@@ -411,7 +411,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu ...@@ -411,7 +411,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
PackedFunc GraphRuntime::GetFunction( PackedFunc GraphRuntime::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query. // Return member functions during query.
if (name == "set_input") { if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
...@@ -478,7 +478,7 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -478,7 +478,7 @@ PackedFunc GraphRuntime::GetFunction(
Module GraphRuntimeCreate(const std::string& sym_json, Module GraphRuntimeCreate(const std::string& sym_json,
const tvm::runtime::Module& m, const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) { const std::vector<TVMContext>& ctxs) {
std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>(); auto exec = make_object<GraphRuntime>();
exec->Init(sym_json, m, ctxs); exec->Init(sym_json, m, ctxs);
return Module(exec); return Module(exec);
} }
...@@ -513,15 +513,17 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") ...@@ -513,15 +513,17 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
}); });
TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create") TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for " CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is " "graph_runtime.remote_create is "
"at least 4, but it has " "at least 4, but it has "
<< args.num_args; << args.num_args;
void* mhandle = args[1]; void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args); const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate( *rv = GraphRuntimeCreate(
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts); args[0], GetRef<Module>(mnode), contexts);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
*
* \brief Tiny graph runtime that can run graph * \brief Tiny graph runtime that can run graph
* containing only tvm PackedFunc. * containing only tvm PackedFunc.
* \file graph_runtime.h * \file graph_runtime.h
...@@ -83,7 +81,7 @@ class GraphRuntime : public ModuleNode { ...@@ -83,7 +81,7 @@ class GraphRuntime : public ModuleNode {
* \return The corresponding member function. * \return The corresponding member function.
*/ */
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self); const ObjectPtr<Object>& sptr_to_self);
/*! /*!
* \return The type key of the executor. * \return The type key of the executor.
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file metal_module.cc * \file metal_module.cc
*/ */
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
...@@ -54,7 +53,7 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -54,7 +53,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
...@@ -187,7 +186,7 @@ class MetalWrappedFunc { ...@@ -187,7 +186,7 @@ class MetalWrappedFunc {
public: public:
// initialize the METAL function. // initialize the METAL function.
void Init(MetalModuleNode* m, void Init(MetalModuleNode* m,
std::shared_ptr<ModuleNode> sptr, ObjectPtr<Object> sptr,
const std::string& func_name, const std::string& func_name,
size_t num_buffer_args, size_t num_buffer_args,
size_t num_pack_args, size_t num_pack_args,
...@@ -244,7 +243,7 @@ class MetalWrappedFunc { ...@@ -244,7 +243,7 @@ class MetalWrappedFunc {
// internal module // internal module
MetalModuleNode* m_; MetalModuleNode* m_;
// the resource holder // the resource holder
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// The name of the function. // The name of the function.
std::string func_name_; std::string func_name_;
// Number of buffer arguments // Number of buffer arguments
...@@ -260,7 +259,7 @@ class MetalWrappedFunc { ...@@ -260,7 +259,7 @@ class MetalWrappedFunc {
PackedFunc MetalModuleNode::GetFunction( PackedFunc MetalModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this); CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main"; << "Device function do not have main";
...@@ -281,8 +280,7 @@ Module MetalModuleCreate( ...@@ -281,8 +280,7 @@ Module MetalModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap, std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) { std::string source) {
metal::MetalWorkspace::Global()->Init(); metal::MetalWorkspace::Global()->Init();
std::shared_ptr<MetalModuleNode> n = auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
std::make_shared<MetalModuleNode>(data, fmt, fmap, source);
return Module(n); return Module(n);
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file micro_device_api.cc * \file micro_device_api.cc
*/ */
...@@ -50,7 +49,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -50,7 +49,7 @@ class MicroDeviceAPI final : public DeviceAPI {
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final { TVMType type_hint) final {
std::shared_ptr<MicroSession>& session = MicroSession::Current(); ObjectPtr<MicroSession>& session = MicroSession::Current();
void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>(); void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>();
CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap"; CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap";
MicroDevSpace* dev_space = new MicroDevSpace(); MicroDevSpace* dev_space = new MicroDevSpace();
...@@ -82,11 +81,12 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -82,11 +81,12 @@ class MicroDeviceAPI final : public DeviceAPI {
MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from)); MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to)); MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
CHECK(from_space->session == to_space->session) CHECK(from_space->session == to_space->session)
<< "attempt to copy data between different micro sessions (" << from_space->session << "attempt to copy data between different micro sessions ("
<< " != " << to_space->session << ")"; << from_space->session.get()
<< " != " << to_space->session.get() << ")";
CHECK(ctx_from.device_id == ctx_to.device_id) CHECK(ctx_from.device_id == ctx_to.device_id)
<< "can only copy between the same micro device"; << "can only copy between the same micro device";
std::shared_ptr<MicroSession>& session = from_space->session; ObjectPtr<MicroSession>& session = from_space->session;
const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device(); const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset); DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset);
...@@ -99,7 +99,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -99,7 +99,7 @@ class MicroDeviceAPI final : public DeviceAPI {
// Reading from the device. // Reading from the device.
MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from)); MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
std::shared_ptr<MicroSession>& session = from_space->session; ObjectPtr<MicroSession>& session = from_space->session;
const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device(); const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset); DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset);
...@@ -109,7 +109,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -109,7 +109,7 @@ class MicroDeviceAPI final : public DeviceAPI {
// Writing to the device. // Writing to the device.
MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to)); MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
std::shared_ptr<MicroSession>& session = to_space->session; ObjectPtr<MicroSession>& session = to_space->session;
const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device(); const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
void* from_host_ptr = GetHostLoc(from, from_offset); void* from_host_ptr = GetHostLoc(from, from_offset);
...@@ -124,7 +124,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -124,7 +124,7 @@ class MicroDeviceAPI final : public DeviceAPI {
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
std::shared_ptr<MicroSession>& session = MicroSession::Current(); ObjectPtr<MicroSession>& session = MicroSession::Current();
void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>(); void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>();
CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace"; CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace";
...@@ -136,7 +136,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -136,7 +136,7 @@ class MicroDeviceAPI final : public DeviceAPI {
void FreeWorkspace(TVMContext ctx, void* data) final { void FreeWorkspace(TVMContext ctx, void* data) final {
MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(data); MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(data);
std::shared_ptr<MicroSession>& session = dev_space->session; ObjectPtr<MicroSession>& session = dev_space->session;
session->FreeInSection(SectionKind::kWorkspace, session->FreeInSection(SectionKind::kWorkspace,
DevBaseOffset(reinterpret_cast<std::uintptr_t>(dev_space->data))); DevBaseOffset(reinterpret_cast<std::uintptr_t>(dev_space->data)));
delete dev_space; delete dev_space;
......
...@@ -18,9 +18,8 @@ ...@@ -18,9 +18,8 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors * \file micro_module.cc
* \file micro_module.cc */
*/
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
...@@ -48,7 +47,7 @@ class MicroModuleNode final : public ModuleNode { ...@@ -48,7 +47,7 @@ class MicroModuleNode final : public ModuleNode {
} }
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
/*! /*!
* \brief initializes module by establishing device connection and loads binary * \brief initializes module by establishing device connection and loads binary
...@@ -76,13 +75,13 @@ class MicroModuleNode final : public ModuleNode { ...@@ -76,13 +75,13 @@ class MicroModuleNode final : public ModuleNode {
/*! \brief path to module binary */ /*! \brief path to module binary */
std::string binary_path_; std::string binary_path_;
/*! \brief global session pointer */ /*! \brief global session pointer */
std::shared_ptr<MicroSession> session_; ObjectPtr<MicroSession> session_;
}; };
class MicroWrappedFunc { class MicroWrappedFunc {
public: public:
MicroWrappedFunc(MicroModuleNode* m, MicroWrappedFunc(MicroModuleNode* m,
std::shared_ptr<MicroSession> session, ObjectPtr<MicroSession> session,
const std::string& func_name, const std::string& func_name,
DevBaseOffset func_offset) { DevBaseOffset func_offset) {
m_ = m; m_ = m;
...@@ -99,7 +98,7 @@ class MicroWrappedFunc { ...@@ -99,7 +98,7 @@ class MicroWrappedFunc {
/*! \brief internal module */ /*! \brief internal module */
MicroModuleNode* m_; MicroModuleNode* m_;
/*! \brief reference to the session for this function (to keep the session alive) */ /*! \brief reference to the session for this function (to keep the session alive) */
std::shared_ptr<MicroSession> session_; ObjectPtr<MicroSession> session_;
/*! \brief name of the function */ /*! \brief name of the function */
std::string func_name_; std::string func_name_;
/*! \brief offset of the function to be called */ /*! \brief offset of the function to be called */
...@@ -108,7 +107,7 @@ class MicroWrappedFunc { ...@@ -108,7 +107,7 @@ class MicroWrappedFunc {
PackedFunc MicroModuleNode::GetFunction( PackedFunc MicroModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
DevBaseOffset func_offset = DevBaseOffset func_offset =
session_->low_level_device()->ToDevOffset(binary_info_.symbol_map[name]); session_->low_level_device()->ToDevOffset(binary_info_.symbol_map[name]);
MicroWrappedFunc f(this, session_, name, func_offset); MicroWrappedFunc f(this, session_, name, func_offset);
...@@ -118,9 +117,9 @@ PackedFunc MicroModuleNode::GetFunction( ...@@ -118,9 +117,9 @@ PackedFunc MicroModuleNode::GetFunction(
// register loadfile function to load module from Python frontend // register loadfile function to load module from Python frontend
TVM_REGISTER_GLOBAL("module.loadfile_micro_dev") TVM_REGISTER_GLOBAL("module.loadfile_micro_dev")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<MicroModuleNode> n = std::make_shared<MicroModuleNode>(); auto n = make_object<MicroModuleNode>();
n->InitMicroModule(args[0]); n->InitMicroModule(args[0]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -18,13 +18,11 @@ ...@@ -18,13 +18,11 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file micro_session.cc * \file micro_session.cc
*/ */
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <memory>
#include <stack> #include <stack>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -36,18 +34,18 @@ namespace tvm { ...@@ -36,18 +34,18 @@ namespace tvm {
namespace runtime { namespace runtime {
struct TVMMicroSessionThreadLocalEntry { struct TVMMicroSessionThreadLocalEntry {
std::stack<std::shared_ptr<MicroSession>> session_stack; std::stack<ObjectPtr<MicroSession>> session_stack;
}; };
typedef dmlc::ThreadLocalStore<TVMMicroSessionThreadLocalEntry> TVMMicroSessionThreadLocalStore; typedef dmlc::ThreadLocalStore<TVMMicroSessionThreadLocalEntry> TVMMicroSessionThreadLocalStore;
std::shared_ptr<MicroSession>& MicroSession::Current() { ObjectPtr<MicroSession>& MicroSession::Current() {
TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
CHECK_GT(entry->session_stack.size(), 0) << "No current session"; CHECK_GT(entry->session_stack.size(), 0) << "No current session";
return entry->session_stack.top(); return entry->session_stack.top();
} }
void MicroSession::EnterWithScope(std::shared_ptr<MicroSession> session) { void MicroSession::EnterWithScope(ObjectPtr<MicroSession> session) {
TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
entry->session_stack.push(session); entry->session_stack.push(session);
} }
...@@ -121,7 +119,7 @@ void MicroSession::CreateSession(const std::string& device_type, ...@@ -121,7 +119,7 @@ void MicroSession::CreateSession(const std::string& device_type,
void MicroSession::PushToExecQueue(DevBaseOffset func, const TVMArgs& args) { void MicroSession::PushToExecQueue(DevBaseOffset func, const TVMArgs& args) {
int32_t (*func_dev_addr)(void*, void*, int32_t) = int32_t (*func_dev_addr)(void*, void*, int32_t) =
reinterpret_cast<int32_t (*)(void*, void*, int32_t)>( reinterpret_cast<int32_t (*)(void*, void*, int32_t)>(
low_level_device()->ToDevPtr(func).value()); low_level_device()->ToDevPtr(func).value());
// Create an allocator stream for the memory region after the most recent // Create an allocator stream for the memory region after the most recent
// allocation in the args section. // allocation in the args section.
...@@ -355,10 +353,10 @@ void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, ...@@ -355,10 +353,10 @@ void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map,
PackedFunc MicroSession::GetFunction( PackedFunc MicroSession::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
if (name == "enter") { if (name == "enter") {
return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
MicroSession::EnterWithScope(std::dynamic_pointer_cast<MicroSession>(sptr_to_self)); MicroSession::EnterWithScope(GetObjectPtr<MicroSession>(this));
}); });
} else if (name == "exit") { } else if (name == "exit") {
return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
...@@ -378,7 +376,7 @@ TVM_REGISTER_GLOBAL("micro._CreateSession") ...@@ -378,7 +376,7 @@ TVM_REGISTER_GLOBAL("micro._CreateSession")
uint64_t base_addr = args[3]; uint64_t base_addr = args[3];
const std::string& server_addr = args[4]; const std::string& server_addr = args[4];
int port = args[5]; int port = args[5];
std::shared_ptr<MicroSession> session = std::make_shared<MicroSession>(); ObjectPtr<MicroSession> session = make_object<MicroSession>();
session->CreateSession( session->CreateSession(
device_type, binary_path, toolchain_prefix, base_addr, server_addr, port); device_type, binary_path, toolchain_prefix, base_addr, server_addr, port);
*rv = Module(session); *rv = Module(session);
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file micro_session.h * \file micro_session.h
* \brief session to manage multiple micro modules * \brief session to manage multiple micro modules
* *
...@@ -66,7 +65,7 @@ class MicroSession : public ModuleNode { ...@@ -66,7 +65,7 @@ class MicroSession : public ModuleNode {
* \return The corresponding member function. * \return The corresponding member function.
*/ */
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self); const ObjectPtr<Object>& sptr_to_self);
/*! /*!
* \return The type key of the executor. * \return The type key of the executor.
...@@ -85,7 +84,7 @@ class MicroSession : public ModuleNode { ...@@ -85,7 +84,7 @@ class MicroSession : public ModuleNode {
*/ */
~MicroSession(); ~MicroSession();
static std::shared_ptr<MicroSession>& Current(); static ObjectPtr<MicroSession>& Current();
/*! /*!
* \brief creates session by setting up a low-level device and initting allocators for it * \brief creates session by setting up a low-level device and initting allocators for it
...@@ -240,7 +239,7 @@ class MicroSession : public ModuleNode { ...@@ -240,7 +239,7 @@ class MicroSession : public ModuleNode {
* \brief Push a new session context onto the thread-local stack. * \brief Push a new session context onto the thread-local stack.
* The session on top of the stack is used as the current global session. * The session on top of the stack is used as the current global session.
*/ */
static void EnterWithScope(std::shared_ptr<MicroSession> session); static void EnterWithScope(ObjectPtr<MicroSession> session);
/*! /*!
* \brief Pop a session off the thread-local context stack, * \brief Pop a session off the thread-local context stack,
* restoring the previous session as the current context. * restoring the previous session as the current context.
...@@ -258,7 +257,7 @@ struct MicroDevSpace { ...@@ -258,7 +257,7 @@ struct MicroDevSpace {
/*! \brief data being wrapped */ /*! \brief data being wrapped */
void* data; void* data;
/*! \brief shared ptr to session where this data is valid */ /*! \brief shared ptr to session where this data is valid */
std::shared_ptr<MicroSession> session; ObjectPtr<MicroSession> session;
}; };
} // namespace runtime } // namespace runtime
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file tcl_socket.h * \file tcl_socket.h
* \brief TCP socket wrapper for communicating using Tcl commands * \brief TCP socket wrapper for communicating using Tcl commands
*/ */
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file module.cc * \file module.cc
* \brief TVM module system * \brief TVM module system
*/ */
...@@ -34,33 +33,46 @@ ...@@ -34,33 +33,46 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
void Module::Import(Module other) { void ModuleNode::Import(Module other) {
// specially handle rpc // specially handle rpc
if (!std::strcmp((*this)->type_key(), "rpc")) { if (!std::strcmp(this->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr; static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) { if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr); CHECK(fimport_ != nullptr);
} }
(*fimport_)(*this, other); (*fimport_)(GetRef<Module>(this), other);
return; return;
} }
// cyclic detection. // cyclic detection.
std::unordered_set<const ModuleNode*> visited{other.node_.get()}; std::unordered_set<const ModuleNode*> visited{other.operator->()};
std::vector<const ModuleNode*> stack{other.node_.get()}; std::vector<const ModuleNode*> stack{other.operator->()};
while (!stack.empty()) { while (!stack.empty()) {
const ModuleNode* n = stack.back(); const ModuleNode* n = stack.back();
stack.pop_back(); stack.pop_back();
for (const Module& m : n->imports_) { for (const Module& m : n->imports_) {
const ModuleNode* next = m.node_.get(); const ModuleNode* next = m.operator->();
if (visited.count(next)) continue; if (visited.count(next)) continue;
visited.insert(next); visited.insert(next);
stack.push_back(next); stack.push_back(next);
} }
} }
CHECK(!visited.count(node_.get())) CHECK(!visited.count(this))
<< "Cyclic dependency detected during import"; << "Cyclic dependency detected during import";
node_->imports_.emplace_back(std::move(other)); this->imports_.emplace_back(std::move(other));
}
PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
ModuleNode* self = this;
PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
if (pf != nullptr) return pf;
if (query_imports) {
for (Module& m : self->imports_) {
pf = m->GetFunction(name, m.data_);
if (pf != nullptr) return pf;
}
}
return pf;
} }
Module Module::LoadFromFile(const std::string& file_name, Module Module::LoadFromFile(const std::string& file_name,
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file module_util.cc * \file module_util.cc
* \brief Utilities for module. * \brief Utilities for module.
*/ */
...@@ -64,7 +63,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) { ...@@ -64,7 +63,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
} }
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)( int ret = (*faddr)(
const_cast<TVMValue*>(args.values), const_cast<TVMValue*>(args.values),
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file module_util.h * \file module_util.h
* \brief Helper utilities for module building * \brief Helper utilities for module building
*/ */
...@@ -45,7 +44,7 @@ namespace runtime { ...@@ -45,7 +44,7 @@ namespace runtime {
* \param faddr The function address * \param faddr The function address
* \param mptr The module pointer node. * \param mptr The module pointer node.
*/ */
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr); PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
/*! /*!
* \brief Load and append module blob to module list * \brief Load and append module blob to module list
* \param mblob The module blob. * \param mblob The module blob.
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include "object_internal.h"
#include "runtime_base.h" #include "runtime_base.h"
namespace tvm { namespace tvm {
...@@ -200,18 +201,6 @@ uint32_t Object::TypeKey2Index(const std::string& key) { ...@@ -200,18 +201,6 @@ uint32_t Object::TypeKey2Index(const std::string& key) {
return TypeContext::Global()->TypeKey2Index(key); return TypeContext::Global()->TypeKey2Index(key);
} }
class TVMObjectCAPI {
public:
static void Free(TVMObjectHandle obj) {
if (obj != nullptr) {
static_cast<Object*>(obj)->DecRef();
}
}
static uint32_t TypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key);
}
};
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -224,13 +213,13 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { ...@@ -224,13 +213,13 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
int TVMObjectFree(TVMObjectHandle obj) { int TVMObjectFree(TVMObjectHandle obj) {
API_BEGIN(); API_BEGIN();
tvm::runtime::TVMObjectCAPI::Free(obj); tvm::runtime::ObjectInternal::ObjectFree(obj);
API_END(); API_END();
} }
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN(); API_BEGIN();
out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index( out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
type_key); type_key);
API_END(); API_END();
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/runtime/object_internal.h
* \brief Expose a few functions for CFFI purposes.
* This file is not intended to be used
*/
#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_
#define TVM_RUNTIME_OBJECT_INTERNAL_H_
#include <tvm/runtime/object.h>
#include <tvm/runtime/module.h>
#include <string>
namespace tvm {
namespace runtime {
/*!
* \brief Internal object namespace to expose
* certain util functions for FFI.
*/
class ObjectInternal {
public:
/*!
* \brief Free an object handle.
*/
static void ObjectFree(TVMObjectHandle obj) {
if (obj != nullptr) {
static_cast<Object*>(obj)->DecRef();
}
}
/*!
* \brief Expose TypeKey2Index
* \param type_key The original type key.
* \return the corresponding index.
*/
static uint32_t ObjectTypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key);
}
/*!
* \brief Convert ModuleHandle to module node pointer.
* \param handle The module handle.
* \return the corresponding module node pointer.
*/
static ModuleNode* GetModuleNode(TVMModuleHandle handle) {
// NOTE: we will need to convert to Object
// then to ModuleNode in order to get the correct
// address translation
return static_cast<ModuleNode*>(static_cast<Object*>(handle));
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file aocl_common.h * \file aocl_common.h
* \brief AOCL common header * \brief AOCL common header
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file aocl_device_api.cc * \file aocl_device_api.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file aocl_module.h * \file aocl_module.h
* \brief Execution handling of OpenCL kernels for AOCL * \brief Execution handling of OpenCL kernels for AOCL
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -278,7 +278,7 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -278,7 +278,7 @@ class OpenCLModuleNode : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final; const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final; void SaveToBinary(dmlc::Stream* stream) final;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file opencl_module.cc * \file opencl_module.cc
*/ */
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
...@@ -36,7 +35,7 @@ class OpenCLWrappedFunc { ...@@ -36,7 +35,7 @@ class OpenCLWrappedFunc {
public: public:
// initialize the OpenCL function. // initialize the OpenCL function.
void Init(OpenCLModuleNode* m, void Init(OpenCLModuleNode* m,
std::shared_ptr<ModuleNode> sptr, ObjectPtr<Object> sptr,
OpenCLModuleNode::KTRefEntry entry, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::string func_name,
std::vector<size_t> arg_size, std::vector<size_t> arg_size,
...@@ -88,7 +87,7 @@ class OpenCLWrappedFunc { ...@@ -88,7 +87,7 @@ class OpenCLWrappedFunc {
// The module // The module
OpenCLModuleNode* m_; OpenCLModuleNode* m_;
// resource handle // resource handle
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// global kernel id in the kernel table. // global kernel id in the kernel table.
OpenCLModuleNode::KTRefEntry entry_; OpenCLModuleNode::KTRefEntry entry_;
// The name of the function. // The name of the function.
...@@ -122,7 +121,7 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace ...@@ -122,7 +121,7 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace
PackedFunc OpenCLModuleNode::GetFunction( PackedFunc OpenCLModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this); CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main"; << "Device function do not have main";
...@@ -251,8 +250,7 @@ Module OpenCLModuleCreate( ...@@ -251,8 +250,7 @@ Module OpenCLModuleCreate(
std::string fmt, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) { std::string source) {
std::shared_ptr<OpenCLModuleNode> n = auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
std::make_shared<OpenCLModuleNode>(data, fmt, fmap, source);
n->Init(); n->Init();
return Module(n); return Module(n);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file opencl_module.h * \file opencl_module.h
* \brief Execution handling of OPENCL kernels * \brief Execution handling of OPENCL kernels
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -44,7 +44,7 @@ class OpenGLModuleNode final : public ModuleNode { ...@@ -44,7 +44,7 @@ class OpenGLModuleNode final : public ModuleNode {
const char* type_key() const final { return "opengl"; } const char* type_key() const final { return "opengl"; }
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
std::string GetSource(const std::string& format) final; std::string GetSource(const std::string& format) final;
...@@ -74,7 +74,7 @@ class OpenGLModuleNode final : public ModuleNode { ...@@ -74,7 +74,7 @@ class OpenGLModuleNode final : public ModuleNode {
class OpenGLWrappedFunc { class OpenGLWrappedFunc {
public: public:
OpenGLWrappedFunc(OpenGLModuleNode* m, OpenGLWrappedFunc(OpenGLModuleNode* m,
std::shared_ptr<ModuleNode> sptr, ObjectPtr<Object> sptr,
std::string func_name, std::string func_name,
std::vector<size_t> arg_size, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags); const std::vector<std::string>& thread_axis_tags);
...@@ -85,7 +85,7 @@ class OpenGLWrappedFunc { ...@@ -85,7 +85,7 @@ class OpenGLWrappedFunc {
// The module // The module
OpenGLModuleNode* m_; OpenGLModuleNode* m_;
// resource handle // resource handle
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// The name of the function. // The name of the function.
std::string func_name_; std::string func_name_;
// convert code for void argument // convert code for void argument
...@@ -111,7 +111,7 @@ OpenGLModuleNode::OpenGLModuleNode( ...@@ -111,7 +111,7 @@ OpenGLModuleNode::OpenGLModuleNode(
PackedFunc OpenGLModuleNode::GetFunction( PackedFunc OpenGLModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this); CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
...@@ -191,7 +191,7 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( ...@@ -191,7 +191,7 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo(
OpenGLWrappedFunc::OpenGLWrappedFunc( OpenGLWrappedFunc::OpenGLWrappedFunc(
OpenGLModuleNode* m, OpenGLModuleNode* m,
std::shared_ptr<ModuleNode> sptr, ObjectPtr<Object> sptr,
std::string func_name, std::string func_name,
std::vector<size_t> arg_size, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) const std::vector<std::string>& thread_axis_tags)
...@@ -251,9 +251,9 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, ...@@ -251,9 +251,9 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders, Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
std::string fmt, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap) { std::unordered_map<std::string, FunctionInfo> fmap) {
auto n = std::make_shared<OpenGLModuleNode>(std::move(shaders), auto n = make_object<OpenGLModuleNode>(std::move(shaders),
std::move(fmt), std::move(fmt),
std::move(fmap)); std::move(fmap));
return Module(n); return Module(n);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file rocm_module.cc * \file rocm_module.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -68,7 +67,7 @@ class ROCMModuleNode : public runtime::ModuleNode { ...@@ -68,7 +67,7 @@ class ROCMModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
...@@ -158,7 +157,7 @@ class ROCMWrappedFunc { ...@@ -158,7 +157,7 @@ class ROCMWrappedFunc {
public: public:
// initialize the ROCM function. // initialize the ROCM function.
void Init(ROCMModuleNode* m, void Init(ROCMModuleNode* m,
std::shared_ptr<ModuleNode> sptr, ObjectPtr<Object> sptr,
const std::string& func_name, const std::string& func_name,
size_t num_void_args, size_t num_void_args,
const std::vector<std::string>& thread_axis_tags) { const std::vector<std::string>& thread_axis_tags) {
...@@ -204,7 +203,7 @@ class ROCMWrappedFunc { ...@@ -204,7 +203,7 @@ class ROCMWrappedFunc {
// internal module // internal module
ROCMModuleNode* m_; ROCMModuleNode* m_;
// the resource holder // the resource holder
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// The name of the function. // The name of the function.
std::string func_name_; std::string func_name_;
// Device function cache per device. // Device function cache per device.
...@@ -217,7 +216,7 @@ class ROCMWrappedFunc { ...@@ -217,7 +216,7 @@ class ROCMWrappedFunc {
PackedFunc ROCMModuleNode::GetFunction( PackedFunc ROCMModuleNode::GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this); CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main"; << "Device function do not have main";
...@@ -235,8 +234,7 @@ Module ROCMModuleCreate( ...@@ -235,8 +234,7 @@ Module ROCMModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap, std::unordered_map<std::string, FunctionInfo> fmap,
std::string hip_source, std::string hip_source,
std::string assembly) { std::string assembly) {
std::shared_ptr<ROCMModuleNode> n = auto n = make_object<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
std::make_shared<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
return Module(n); return Module(n);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -123,7 +123,7 @@ class RPCModuleNode final : public ModuleNode { ...@@ -123,7 +123,7 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
RPCFuncHandle handle = GetFuncHandle(name); RPCFuncHandle handle = GetFuncHandle(name);
return WrapRemote(handle); return WrapRemote(handle);
} }
...@@ -195,8 +195,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess, ...@@ -195,8 +195,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
return wf->operator()(args, rv); return wf->operator()(args, rv);
}); });
} else if (tcode == kModuleHandle) { } else if (tcode == kModuleHandle) {
std::shared_ptr<RPCModuleNode> n = auto n = make_object<RPCModuleNode>(handle, sess);
std::make_shared<RPCModuleNode>(handle, sess);
*rv = Module(n); *rv = Module(n);
} else if (tcode == kArrayHandle || tcode == kNDArrayContainer) { } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
CHECK_EQ(args.size(), 2); CHECK_EQ(args.size(), 2);
...@@ -209,8 +208,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess, ...@@ -209,8 +208,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
} }
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) { Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
std::shared_ptr<RPCModuleNode> n = auto n = make_object<RPCModuleNode>(nullptr, sess);
std::make_shared<RPCModuleNode>(nullptr, sess);
return Module(n); return Module(n);
} }
...@@ -237,8 +235,7 @@ TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") ...@@ -237,8 +235,7 @@ TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
CHECK_EQ(tkey, "rpc"); CHECK_EQ(tkey, "rpc");
auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess(); auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
std::shared_ptr<RPCModuleNode> n = auto n = make_object<RPCModuleNode>(mhandle, sess);
std::make_shared<RPCModuleNode>(mhandle, sess);
*rv = Module(n); *rv = Module(n);
}); });
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
#include "rpc_session.h" #include "rpc_session.h"
#include "../object_internal.h"
#include "../../common/ring_buffer.h" #include "../../common/ring_buffer.h"
#include "../../common/socket.h" #include "../../common/socket.h"
...@@ -1119,25 +1120,29 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { ...@@ -1119,25 +1120,29 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
} }
std::string file_name = args[0]; std::string file_name = args[0];
TVMRetValue ret = (*fsys_load_)(file_name); TVMRetValue ret = (*fsys_load_)(file_name);
Module m = ret; // pass via void*
*rv = static_cast<void*>(new Module(m)); TVMValue value;
int rcode;
ret.MoveToCHost(&value, &rcode);
CHECK_EQ(rcode, kModuleHandle);
*rv = static_cast<void*>(value.v_handle);
} }
void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
void* pmod = args[0]; void* pmod = args[0];
void* cmod = args[1]; void* cmod = args[1];
static_cast<Module*>(pmod)->Import( ObjectInternal::GetModuleNode(pmod)->Import(
*static_cast<Module*>(cmod)); GetRef<Module>(ObjectInternal::GetModuleNode(cmod)));
} }
void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0]; void* mhandle = args[0];
delete static_cast<Module*>(mhandle); ObjectInternal::ObjectFree(mhandle);
} }
void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0]; void* mhandle = args[0];
PackedFunc pf = static_cast<Module*>(mhandle)->GetFunction( PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction(
args[1], false); args[1], false);
if (pf != nullptr) { if (pf != nullptr) {
*rv = static_cast<void*>(new PackedFunc(pf)); *rv = static_cast<void*>(new PackedFunc(pf));
...@@ -1149,7 +1154,7 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { ...@@ -1149,7 +1154,7 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0]; void* mhandle = args[0];
std::string fmt = args[1]; std::string fmt = args[1];
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt); *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt);
} }
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file rpc_session.h * \file rpc_session.h
* \brief Base RPC session interface. * \brief Base RPC session interface.
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* Implementation stack VM. * Implementation stack VM.
* \file stackvm.cc * \file stackvm.cc
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file stackvm_module.cc * \file stackvm_module.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -42,7 +41,7 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -42,7 +41,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) { if (name == runtime::symbol::tvm_module_main) {
return GetFunction(entry_func_, sptr_to_self); return GetFunction(entry_func_, sptr_to_self);
} }
...@@ -89,8 +88,7 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -89,8 +88,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
static Module Create(std::unordered_map<std::string, StackVM> fmap, static Module Create(std::unordered_map<std::string, StackVM> fmap,
std::string entry_func) { std::string entry_func) {
std::shared_ptr<StackVMModuleNode> n = auto n = make_object<StackVMModuleNode>();
std::make_shared<StackVMModuleNode>();
n->fmap_ = std::move(fmap); n->fmap_ = std::move(fmap);
n->entry_func_ = std::move(entry_func); n->entry_func_ = std::move(entry_func);
return Module(n); return Module(n);
...@@ -101,8 +99,7 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -101,8 +99,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
std::string entry_func, data; std::string entry_func, data;
strm->Read(&fmap); strm->Read(&fmap);
strm->Read(&entry_func); strm->Read(&entry_func);
std::shared_ptr<StackVMModuleNode> n = auto n = make_object<StackVMModuleNode>();
std::make_shared<StackVMModuleNode>();
n->fmap_ = std::move(fmap); n->fmap_ = std::move(fmap);
n->entry_func_ = std::move(entry_func); n->entry_func_ = std::move(entry_func);
uint64_t num_imports; uint64_t num_imports;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
*/ */
/*! /*!
* Copyright (c) 2017 by Contributors
* \file system_lib_module.cc * \file system_lib_module.cc
* \brief SystemLib module. * \brief SystemLib module.
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/c_backend_api.h> #include <tvm/runtime/c_backend_api.h>
#include <mutex> #include <mutex>
#include "module_util.h" #include "module_util.h"
...@@ -40,7 +40,7 @@ class SystemLibModuleNode : public ModuleNode { ...@@ -40,7 +40,7 @@ class SystemLibModuleNode : public ModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (module_blob_ != nullptr) { if (module_blob_ != nullptr) {
...@@ -83,9 +83,8 @@ class SystemLibModuleNode : public ModuleNode { ...@@ -83,9 +83,8 @@ class SystemLibModuleNode : public ModuleNode {
} }
} }
static const std::shared_ptr<SystemLibModuleNode>& Global() { static const ObjectPtr<SystemLibModuleNode>& Global() {
static std::shared_ptr<SystemLibModuleNode> inst = static auto inst = make_object<SystemLibModuleNode>();
std::make_shared<SystemLibModuleNode>();
return inst; return inst;
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/vm/executable.cc * \file tvm/runtime/vm/executable.cc
* \brief The implementation of a virtual machine executable APIs. * \brief The implementation of a virtual machine executable APIs.
*/ */
...@@ -51,7 +50,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr); ...@@ -51,7 +50,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr);
Instruction DeserializeInstruction(const VMInstructionSerializer& instr); Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
PackedFunc Executable::GetFunction(const std::string& name, PackedFunc Executable::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_lib") { if (name == "get_lib") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetLib(); *rv = this->GetLib();
...@@ -440,7 +439,7 @@ void LoadHeader(dmlc::Stream* strm) { ...@@ -440,7 +439,7 @@ void LoadHeader(dmlc::Stream* strm) {
} }
runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) { runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Executable> exec = std::make_shared<Executable>(); auto exec = make_object<Executable>();
exec->lib = lib; exec->lib = lib;
exec->code_ = code; exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_); dmlc::MemoryStringStream strm(&exec->code_);
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/runtime/vm/profiler/vm.cc * \file src/runtime/vm/profiler/vm.cc
* \brief The Relay debug virtual machine. * \brief The Relay debug virtual machine.
*/ */
...@@ -41,7 +40,7 @@ namespace runtime { ...@@ -41,7 +40,7 @@ namespace runtime {
namespace vm { namespace vm {
PackedFunc VirtualMachineDebug::GetFunction( PackedFunc VirtualMachineDebug::GetFunction(
const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) { const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_stat") { if (name == "get_stat") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
double total_duration = 0.0; double total_duration = 0.0;
...@@ -124,7 +123,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, ...@@ -124,7 +123,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
} }
runtime::Module CreateVirtualMachineDebug(const Executable* exec) { runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>(); auto vm = make_object<VirtualMachineDebug>();
vm->LoadExecutable(exec); vm->LoadExecutable(exec);
return runtime::Module(vm); return runtime::Module(vm);
} }
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/runtime/vm/profiler/vm.h * \file src/runtime/vm/profiler/vm.h
* \brief The Relay debug virtual machine. * \brief The Relay debug virtual machine.
*/ */
...@@ -42,7 +41,7 @@ class VirtualMachineDebug : public VirtualMachine { ...@@ -42,7 +41,7 @@ class VirtualMachineDebug : public VirtualMachine {
VirtualMachineDebug() : VirtualMachine() {} VirtualMachineDebug() : VirtualMachine() {}
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final; Index output_size, const std::vector<ObjectRef>& args) final;
......
...@@ -627,7 +627,7 @@ ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { ...@@ -627,7 +627,7 @@ ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
} }
PackedFunc VirtualMachine::GetFunction(const std::string& name, PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
if (name == "invoke") { if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec) << "The executable is not created yet."; CHECK(exec) << "The executable is not created yet.";
...@@ -1052,7 +1052,7 @@ void VirtualMachine::RunLoop() { ...@@ -1052,7 +1052,7 @@ void VirtualMachine::RunLoop() {
} }
runtime::Module CreateVirtualMachine(const Executable* exec) { runtime::Module CreateVirtualMachine(const Executable* exec) {
std::shared_ptr<VirtualMachine> vm = std::make_shared<VirtualMachine>(); auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(exec); vm->LoadExecutable(exec);
return runtime::Module(vm); return runtime::Module(vm);
} }
......
...@@ -663,7 +663,9 @@ class VulkanModuleNode; ...@@ -663,7 +663,9 @@ class VulkanModuleNode;
// a wrapped function class to get packed func. // a wrapped function class to get packed func.
class VulkanWrappedFunc { class VulkanWrappedFunc {
public: public:
void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name, void Init(VulkanModuleNode* m,
ObjectPtr<Object> sptr,
const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args, size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) { const std::vector<std::string>& thread_axis_tags) {
m_ = m; m_ = m;
...@@ -680,7 +682,7 @@ class VulkanWrappedFunc { ...@@ -680,7 +682,7 @@ class VulkanWrappedFunc {
// internal module // internal module
VulkanModuleNode* m_; VulkanModuleNode* m_;
// the resource holder // the resource holder
std::shared_ptr<ModuleNode> sptr_; ObjectPtr<Object> sptr_;
// v The name of the function. // v The name of the function.
std::string func_name_; std::string func_name_;
// Number of buffer arguments // Number of buffer arguments
...@@ -705,7 +707,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { ...@@ -705,7 +707,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
const char* type_key() const final { return "vulkan"; } const char* type_key() const final { return "vulkan"; }
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
CHECK_EQ(sptr_to_self.get(), this); CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name); auto it = fmap_.find(name);
...@@ -939,7 +941,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { ...@@ -939,7 +941,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap, Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) { std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source); auto n = make_object<VulkanModuleNode>(smap, fmap, source);
return Module(n); return Module(n);
} }
......
...@@ -226,7 +226,7 @@ class DPIModule final : public DPIModuleNode { ...@@ -226,7 +226,7 @@ class DPIModule final : public DPIModuleNode {
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const ObjectPtr<Object>& sptr_to_self) final {
if (name == "WriteReg") { if (name == "WriteReg") {
return TypedPackedFunc<void(int, int)>( return TypedPackedFunc<void(int, int)>(
[this](int addr, int value){ [this](int addr, int value){
...@@ -413,8 +413,7 @@ class DPIModule final : public DPIModuleNode { ...@@ -413,8 +413,7 @@ class DPIModule final : public DPIModuleNode {
}; };
Module DPIModuleNode::Load(std::string dll_name) { Module DPIModuleNode::Load(std::string dll_name) {
std::shared_ptr<DPIModule> n = auto n = make_object<DPIModule>();
std::make_shared<DPIModule>();
n->Init(dll_name); n->Init(dll_name);
return Module(n); return Module(n);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "../src/runtime/system_lib_module.cc" #include "../src/runtime/system_lib_module.cc"
#include "../src/runtime/module.cc" #include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc" #include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc" #include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc" #include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc" #include "../src/runtime/dso_module.cc"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment