Unverified Commit f823c577 by Tianqi Chen Committed by GitHub

[RUNTIME][REFACTOR] Use object protocol to support runtime::Module (#4289)

Previously runtime::Module was supported using shared_ptr.
This PR refactors the codebase to use the Object protocol.

It will open doors to allow easier interpolation between
Object containers and module in the future.
parent aeb5f130
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file tvm_runtime.h
* \brief Pack all tvm runtime source files
*/
......@@ -35,6 +34,7 @@
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
#include "../src/runtime/thread_pool.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/threading_backend.cc"
#include "../src/runtime/ndarray.cc"
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -55,6 +55,7 @@
#include "../src/runtime/threading_backend.cc"
#include "../src/runtime/graph/graph_runtime.cc"
#include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"
#ifdef TVM_OPENCL_RUNTIME
#include "../src/runtime/opencl/opencl_device_api.cc"
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -32,5 +32,6 @@
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
#include "../../src/runtime/system_lib_module.cc"
#include "../../src/runtime/graph/graph_runtime.cc"
......@@ -47,6 +47,7 @@
#include "../../src/runtime/threading_backend.cc"
#include "../../src/runtime/thread_pool.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file TVMRuntime.mm
*/
#include "TVMRuntime.h"
......@@ -35,6 +34,8 @@
#include "../../../src/runtime/file_util.cc"
#include "../../../src/runtime/dso_module.cc"
#include "../../../src/runtime/ndarray.cc"
#include "../../../src/runtime/object.cc"
// RPC server
#include "../../../src/runtime/rpc/rpc_session.cc"
#include "../../../src/runtime/rpc/rpc_server_env.cc"
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \brief This is an all in one TVM runtime file.
* \file tvm_runtime_pack.cc
*/
......@@ -32,6 +31,7 @@
#include "src/runtime/threading_backend.cc"
#include "src/runtime/thread_pool.cc"
#include "src/runtime/ndarray.cc"
#include "src/runtime/object.cc"
// NOTE: all the files after this are optional modules
// that you can include remove, depending on how much feature you use.
......
......@@ -27,28 +27,31 @@
#define TVM_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
// The internal container of module.
class ModuleNode;
class PackedFunc;
/*!
* \brief Module container of TVM.
*/
class Module {
class Module : public ObjectRef {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
explicit Module(ObjectPtr<Object> n)
: ObjectRef(n) {}
/*!
* \brief Get packed function from current module by name.
*
......@@ -59,10 +62,6 @@ class Module {
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
// The following functions requires link with runtime.
/*!
* \brief Import another module into this module.
......@@ -71,7 +70,11 @@ class Module {
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
inline void Import(Module other);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
......@@ -81,20 +84,41 @@ class Module {
*/
TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");
private:
std::shared_ptr<ModuleNode> node_;
// refer to the corresponding container.
using ContainerType = ModuleNode;
friend class ModuleNode;
};
/*!
* \brief Base node container of module.
* Do not create this directly, instead use Module.
* \brief Base container of module.
*
* Please subclass ModuleNode to create a specific runtime module.
*
* \code
*
* class MyModuleNode : public ModuleNode {
* public:
* // implement the interface
* };
*
* // use make_object to create a specific
* // instace of MyModuleNode.
* Module CreateMyModule() {
* ObjectPtr<MyModuleNode> n =
* tvm::runtime::make_object<MyModuleNode>();
* return Module(n);
* }
*
* \endcode
*/
class ModuleNode {
class ModuleNode : public Object {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*! \return The module type key */
/*!
* \return The per module type key.
* \note This key is used to for serializing custom modules.
*/
virtual const char* type_key() const = 0;
/*!
* \brief Get a PackedFunc from module.
......@@ -105,7 +129,7 @@ class ModuleNode {
* For benchmarking, use prepare to eliminate
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
* \param sptr_to_self The ObjectPtr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
......@@ -115,7 +139,7 @@ class ModuleNode {
*/
virtual PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
const ObjectPtr<Object>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
......@@ -138,6 +162,24 @@ class ModuleNode {
*/
TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param query_imports Whether also query dependency modules.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
TVM_DLL PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
*
......@@ -150,6 +192,13 @@ class ModuleNode {
return imports_;
}
// integration with the existing components.
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule;
static constexpr const char* _type_key = "runtime.Module";
// NOTE: ModuleNode can still be sub-classed
//
TVM_DECLARE_FINAL_OBJECT_INFO(ModuleNode, Object);
protected:
friend class Module;
/*! \brief The modules this module depend on */
......@@ -180,16 +229,21 @@ constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol
// implementations of inline functions.
inline void Module::Import(Module other) {
return (*this)->Import(other);
}
inline ModuleNode* Module::operator->() {
return node_.get();
return static_cast<ModuleNode*>(get_mutable());
}
inline const ModuleNode* Module::operator->() const {
return node_.get();
return static_cast<const ModuleNode*>(get());
}
} // namespace runtime
} // namespace tvm
#include "packed_func.h"
#include <tvm/runtime/packed_func.h> // NOLINT(*)
#endif // TVM_RUNTIME_MODULE_H_
......@@ -53,6 +53,7 @@ enum TypeIndex {
kVMTensor = 1,
kVMClosure = 2,
kVMADT = 3,
kRuntimeModule = 4,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
......@@ -302,7 +303,7 @@ class Object {
template<typename>
friend class ObjectPtr;
friend class TVMRetValue;
friend class TVMObjectCAPI;
friend class ObjectInternal;
};
/*!
......@@ -310,11 +311,11 @@ class Object {
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
* the object alive beyond the scope of the function.
*
* \param ptr The node pointer
* \param ptr The object pointer
* \tparam RefType The reference type
* \tparam ObjectType The node type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename RefType, typename ObjectType>
......@@ -486,6 +487,8 @@ class ObjectPtr {
friend class TVMArgValue;
template <typename RefType, typename ObjType>
friend RefType GetRef(const ObjType* ptr);
template <typename BaseType, typename ObjType>
friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
};
/*! \brief Base class of all object reference */
......@@ -513,7 +516,7 @@ class ObjectRef {
}
/*!
* \brief Comparator
* \param other Another node ref.
* \param other Another object ref.
* \return the compare result.
*/
bool operator!=(const ObjectRef& other) const {
......@@ -535,7 +538,7 @@ class ObjectRef {
const Object* get() const {
return data_.get();
}
/*! \return the internal node pointer */
/*! \return the internal object pointer */
const Object* operator->() const {
return get();
}
......@@ -595,6 +598,16 @@ class ObjectRef {
friend SubRef Downcast(BaseRef ref);
};
/*!
* \brief Get an object ptr type from a raw object ptr.
*
* \param ptr The object pointer
* \tparam BaseType The reference type
* \tparam ObjectType The object type
* \return The corresponding RefType
*/
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
/*! \brief ObjectRef hash functor */
struct ObjectHash {
......@@ -781,6 +794,13 @@ inline RefType GetRef(const ObjType* ptr) {
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
template <typename BaseType, typename ObjType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
static_assert(std::is_base_of<BaseType, ObjType>::value,
"Can only cast to the ref of same container type");
return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
}
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
......
......@@ -496,6 +496,14 @@ class TVMPODValue_ {
return ObjectRef(
ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator Module() const {
if (type_code_ == kNull) {
return Module(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return Module(
ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
......@@ -574,6 +582,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef;
// conversion operator.
......@@ -610,10 +619,6 @@ class TVMArgValue : public TVMPODValue_ {
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
const TVMValue& value() const {
return value_;
}
......@@ -665,6 +670,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator ObjectRef;
using TVMPODValue_::operator Module;
using TVMPODValue_::IsObjectRef;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
......@@ -696,10 +702,6 @@ class TVMRetValue : public TVMPODValue_ {
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
// Assign operators
TVMRetValue& operator=(TVMRetValue&& other) {
this->Clear();
......@@ -766,17 +768,13 @@ class TVMRetValue : public TVMPODValue_ {
TVMRetValue& operator=(ObjectRef other) {
return operator=(std::move(other.data_));
}
TVMRetValue& operator=(Module m) {
SwitchToObject(kModuleHandle, std::move(m.data_));
return *this;
}
template<typename T>
TVMRetValue& operator=(ObjectPtr<T> other) {
if (other.data_ != nullptr) {
this->Clear();
type_code_ = kObjectHandle;
// move the handle out
value_.v_handle = other.data_;
other.data_ = nullptr;
} else {
SwitchToPOD(kNull);
}
SwitchToObject(kObjectHandle, std::move(other));
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
......@@ -787,10 +785,6 @@ class TVMRetValue : public TVMPODValue_ {
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
}
TVMRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m);
return *this;
}
TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
this->Assign(other);
return *this;
......@@ -860,7 +854,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kModuleHandle: {
SwitchToClass<Module>(kModuleHandle, other);
*this = other.operator Module();
break;
}
case kNDArrayContainer: {
......@@ -907,16 +901,30 @@ class TVMRetValue : public TVMPODValue_ {
*static_cast<T*>(value_.v_handle) = v;
}
}
void SwitchToObject(int type_code, ObjectPtr<Object> other) {
if (other.data_ != nullptr) {
this->Clear();
type_code_ = type_code;
// move the handle out
value_.v_handle = other.data_;
other.data_ = nullptr;
} else {
SwitchToPOD(kNull);
}
}
void Clear() {
if (type_code_ == kNull) return;
switch (type_code_) {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kModuleHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
......@@ -1156,8 +1164,12 @@ class TVMArgsSetter {
operator()(i, value.packed());
}
void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<Module*>(&value);
type_codes_[i] = kModuleHandle;
if (value.defined()) {
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kModuleHandle;
} else {
type_codes_[i] = kNull;
}
}
void operator()(size_t i, const NDArray& value) const { // NOLINT(*)
values_[i].v_handle = value.data_;
......@@ -1372,19 +1384,10 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
return ExtTypeVTable::RegisterInternal(code, vt);
}
// Implement Module::GetFunction
// Put implementation in this file so we have seen the PackedFunc
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
return (*this)->GetFunction(name, query_imports);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_
......@@ -480,7 +480,7 @@ class Executable : public ModuleNode {
* \return PackedFunc or nullptr when it is not available.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
/*!
* \brief Serialize the executable into global section, constant section, and
......@@ -658,7 +658,7 @@ class VirtualMachine : public runtime::ModuleNode {
* it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
const ObjectPtr<Object>& sptr_to_self);
/*!
* \brief Invoke a PackedFunction
......
......@@ -148,7 +148,7 @@ class Executable(object):
raise TypeError("bytecode is expected to be the type of bytearray " +
"or TVMByteArray, but received {}".format(type(code)))
if not isinstance(lib, tvm.module.Module):
if lib is not None and not isinstance(lib, tvm.module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_module.cc
* \brief LLVM runtime module for TVM
*/
......@@ -54,7 +53,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
if (name == "__tvm_is_system_module") {
bool flag =
(mptr_->getFunction("__tvm_module_startup") != nullptr);
......@@ -325,7 +324,7 @@ TVM_REGISTER_API("codegen.llvm_lookup_intrinsic_id")
TVM_REGISTER_API("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
auto n = make_object<LLVMModuleNode>();
n->Init(args[0], args[1]);
*rv = runtime::Module(n);
});
......@@ -339,7 +338,7 @@ TVM_REGISTER_API("codegen.llvm_version_major")
TVM_REGISTER_API("module.loadfile_ll")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
auto n = make_object<LLVMModuleNode>();
n->LoadIR(args[0]);
*rv = runtime::Module(n);
});
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file source_module.cc
* \brief Source code module, only for viewing
*/
......@@ -51,7 +50,7 @@ class SourceModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
......@@ -67,8 +66,7 @@ class SourceModuleNode : public runtime::ModuleNode {
};
runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
std::shared_ptr<SourceModuleNode> n =
std::make_shared<SourceModuleNode>(code, fmt);
auto n = make_object<SourceModuleNode>(code, fmt);
return runtime::Module(n);
}
......@@ -84,7 +82,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "C Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
......@@ -113,8 +111,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
};
runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
std::shared_ptr<CSourceModuleNode> n =
std::make_shared<CSourceModuleNode>(code, fmt);
auto n = make_object<CSourceModuleNode>(code, fmt);
return runtime::Module(n);
}
......@@ -134,7 +131,7 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
LOG(FATAL) << "Source module cannot execute, to get executable module"
<< " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc();
......@@ -182,8 +179,7 @@ runtime::Module DeviceSourceModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key,
std::function<std::string(const std::string&)> fget_source) {
std::shared_ptr<DeviceSourceModuleNode> n =
std::make_shared<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
return runtime::Module(n);
}
......
......@@ -115,7 +115,7 @@ class RelayBuildModule : public runtime::ModuleNode {
* \return The corresponding member function.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_graph_json") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetGraphJSON();
......@@ -489,7 +489,7 @@ class RelayBuildModule : public runtime::ModuleNode {
};
runtime::Module RelayBuildCreate() {
std::shared_ptr<RelayBuildModule> exec = std::make_shared<RelayBuildModule>();
auto exec = make_object<RelayBuildModule>();
return runtime::Module(exec);
}
......
......@@ -593,7 +593,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
public:
GraphRuntimeCodegenModule() {}
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2)
......@@ -654,8 +654,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
};
runtime::Module CreateGraphCodegenMod() {
std::shared_ptr<GraphRuntimeCodegenModule> ptr =
std::make_shared<GraphRuntimeCodegenModule>();
auto ptr = make_object<GraphRuntimeCodegenModule>();
return runtime::Module(ptr);
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/compiler.cc
* \brief A compiler from relay::Module to the VM byte code.
*/
......@@ -745,7 +744,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
PackedFunc VMCompiler::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
if (name == "compile") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 3);
......@@ -974,7 +973,7 @@ void VMCompiler::LibraryCodegen() {
}
runtime::Module CreateVMCompiler() {
std::shared_ptr<VMCompiler> exec = std::make_shared<VMCompiler>();
auto exec = make_object<VMCompiler>();
return runtime::Module(exec);
}
......
......@@ -86,14 +86,14 @@ class VMCompiler : public runtime::ModuleNode {
virtual ~VMCompiler() {}
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
const ObjectPtr<Object>& sptr_to_self);
const char* type_key() const {
return "VMCompiler";
}
void InitVM() {
exec_ = std::make_shared<Executable>();
exec_ = make_object<Executable>();
}
/*!
......@@ -141,7 +141,7 @@ class VMCompiler : public runtime::ModuleNode {
/*! \brief Global shared meta data */
VMCompilerContext context_;
/*! \brief Compiled executable. */
std::shared_ptr<Executable> exec_;
ObjectPtr<Executable> exec_;
/*! \brief parameters */
std::unordered_map<std::string, runtime::NDArray> params_;
};
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/relay/backend/vm/profiler/compiler.cc
* \brief A compiler from relay::Module to the VM byte code.
*/
......@@ -37,7 +36,7 @@ class VMCompilerDebug : public VMCompiler {
};
runtime::Module CreateVMCompilerDebug() {
std::shared_ptr<VMCompilerDebug> exec = std::make_shared<VMCompilerDebug>();
auto exec = make_object<VMCompilerDebug>();
return runtime::Module(exec);
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc
* \brief Device specific implementations
*/
......@@ -41,6 +40,7 @@
#include <cstdlib>
#include <cctype>
#include "runtime_base.h"
#include "object_internal.h"
namespace tvm {
namespace runtime {
......@@ -370,16 +370,20 @@ int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out) {
API_BEGIN();
Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m);
TVMRetValue ret;
ret = Module::LoadFromFile(file_name, format);
TVMValue val;
int type_code;
ret.MoveToCHost(&val, &type_code);
*out = val.v_handle;
API_END();
}
int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep) {
API_BEGIN();
static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep));
ObjectInternal::GetModuleNode(mod)->Import(
GetRef<Module>(ObjectInternal::GetModuleNode(dep)));
API_END();
}
......@@ -388,7 +392,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
int query_imports,
TVMFunctionHandle *func) {
API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(
func_name, query_imports != 0);
if (pf != nullptr) {
*func = new PackedFunc(pf);
......@@ -399,9 +403,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
}
int TVMModFree(TVMModuleHandle mod) {
API_BEGIN();
delete static_cast<Module*>(mod);
API_END();
return TVMObjectFree(mod);
}
int TVMBackendGetFuncFromEnv(void* mod_node,
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -69,7 +69,7 @@ class CUDAModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
......@@ -166,7 +166,7 @@ class CUDAWrappedFunc {
public:
// initialize the CUDA function.
void Init(CUDAModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
ObjectPtr<Object> sptr,
const std::string& func_name,
size_t num_void_args,
const std::vector<std::string>& thread_axis_tags) {
......@@ -220,7 +220,7 @@ class CUDAWrappedFunc {
// internal module
CUDAModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// The name of the function.
std::string func_name_;
// Device function cache per device.
......@@ -233,7 +233,7 @@ class CUDAWrappedFunc {
class CUDAPrepGlobalBarrier {
public:
CUDAPrepGlobalBarrier(CUDAModuleNode* m,
std::shared_ptr<ModuleNode> sptr)
ObjectPtr<Object> sptr)
: m_(m), sptr_(sptr) {
std::fill(pcache_.begin(), pcache_.end(), 0);
}
......@@ -252,14 +252,14 @@ class CUDAPrepGlobalBarrier {
// internal module
CUDAModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// mark as mutable, to enable lazy initialization
mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
};
PackedFunc CUDAModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
......@@ -279,8 +279,7 @@ Module CUDAModuleCreate(
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source) {
std::shared_ptr<CUDAModuleNode> n =
std::make_shared<CUDAModuleNode>(data, fmt, fmap, cuda_source);
auto n = make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source);
return Module(n);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file cuda_module.h
* \brief Execution handling of CUDA kernels
*/
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,11 +18,11 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file dso_dll_module.cc
* \brief Module to load from dynamic shared library.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include "module_util.h"
......@@ -50,7 +50,7 @@ class DSOModuleNode final : public ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
BackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
......@@ -124,7 +124,7 @@ class DSOModuleNode final : public ModuleNode {
TVM_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
auto n = make_object<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file graph_runtime_debug.cc
*/
#include <tvm/runtime/packed_func.h>
......@@ -28,6 +27,7 @@
#include <chrono>
#include <sstream>
#include "../graph_runtime.h"
#include "../../object_internal.h"
namespace tvm {
namespace runtime {
......@@ -121,7 +121,7 @@ class GraphRuntimeDebug : public GraphRuntime {
* \param sptr_to_self Packed function pointer.
*/
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
const ObjectPtr<Object>& sptr_to_self);
/*!
* \brief Get the node index given the name of node.
......@@ -169,7 +169,7 @@ void DebugGetNodeOutput(int index, DLTensor* data_out) {
*/
PackedFunc GraphRuntimeDebug::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
// return member functions during query.
if (name == "get_output_by_layer") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......@@ -207,7 +207,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(
Module GraphRuntimeDebugCreate(const std::string& sym_json,
const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) {
std::shared_ptr<GraphRuntimeDebug> exec = std::make_shared<GraphRuntimeDebug>();
auto exec = make_object<GraphRuntimeDebug>();
exec->Init(sym_json, m, ctxs);
return Module(exec);
}
......@@ -222,15 +222,16 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create")
});
TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeDebugCreate(
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
args[0], GetRef<Module>(mnode), contexts);
});
} // namespace runtime
......
......@@ -18,11 +18,8 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc
*/
#include "graph_runtime.h"
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
......@@ -38,6 +35,9 @@
#include <utility>
#include <vector>
#include "graph_runtime.h"
#include "../object_internal.h"
namespace tvm {
namespace runtime {
namespace details {
......@@ -411,7 +411,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
PackedFunc GraphRuntime::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......@@ -478,7 +478,7 @@ PackedFunc GraphRuntime::GetFunction(
Module GraphRuntimeCreate(const std::string& sym_json,
const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs) {
std::shared_ptr<GraphRuntime> exec = std::make_shared<GraphRuntime>();
auto exec = make_object<GraphRuntime>();
exec->Init(sym_json, m, ctxs);
return Module(exec);
}
......@@ -513,15 +513,17 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime.create")
});
TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 4) << "The expected number of arguments for "
"graph_runtime.remote_create is "
"at least 4, but it has "
<< args.num_args;
void* mhandle = args[1];
ModuleNode* mnode = ObjectInternal::GetModuleNode(mhandle);
const auto& contexts = GetAllContext(args);
*rv = GraphRuntimeCreate(
args[0], *static_cast<tvm::runtime::Module*>(mhandle), contexts);
args[0], GetRef<Module>(mnode), contexts);
});
} // namespace runtime
} // namespace tvm
......@@ -18,8 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
*
* \brief Tiny graph runtime that can run graph
* containing only tvm PackedFunc.
* \file graph_runtime.h
......@@ -83,7 +81,7 @@ class GraphRuntime : public ModuleNode {
* \return The corresponding member function.
*/
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
const ObjectPtr<Object>& sptr_to_self);
/*!
* \return The type key of the executor.
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file metal_module.cc
*/
#include <dmlc/memory_io.h>
......@@ -54,7 +53,7 @@ class MetalModuleNode final :public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
......@@ -187,7 +186,7 @@ class MetalWrappedFunc {
public:
// initialize the METAL function.
void Init(MetalModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
ObjectPtr<Object> sptr,
const std::string& func_name,
size_t num_buffer_args,
size_t num_pack_args,
......@@ -244,7 +243,7 @@ class MetalWrappedFunc {
// internal module
MetalModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// The name of the function.
std::string func_name_;
// Number of buffer arguments
......@@ -260,7 +259,7 @@ class MetalWrappedFunc {
PackedFunc MetalModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
......@@ -281,8 +280,7 @@ Module MetalModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
metal::MetalWorkspace::Global()->Init();
std::shared_ptr<MetalModuleNode> n =
std::make_shared<MetalModuleNode>(data, fmt, fmap, source);
auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
return Module(n);
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file micro_device_api.cc
*/
......@@ -50,7 +49,7 @@ class MicroDeviceAPI final : public DeviceAPI {
size_t nbytes,
size_t alignment,
TVMType type_hint) final {
std::shared_ptr<MicroSession>& session = MicroSession::Current();
ObjectPtr<MicroSession>& session = MicroSession::Current();
void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>();
CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap";
MicroDevSpace* dev_space = new MicroDevSpace();
......@@ -82,11 +81,12 @@ class MicroDeviceAPI final : public DeviceAPI {
MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
CHECK(from_space->session == to_space->session)
<< "attempt to copy data between different micro sessions (" << from_space->session
<< " != " << to_space->session << ")";
<< "attempt to copy data between different micro sessions ("
<< from_space->session.get()
<< " != " << to_space->session.get() << ")";
CHECK(ctx_from.device_id == ctx_to.device_id)
<< "can only copy between the same micro device";
std::shared_ptr<MicroSession>& session = from_space->session;
ObjectPtr<MicroSession>& session = from_space->session;
const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset);
......@@ -99,7 +99,7 @@ class MicroDeviceAPI final : public DeviceAPI {
// Reading from the device.
MicroDevSpace* from_space = static_cast<MicroDevSpace*>(const_cast<void*>(from));
std::shared_ptr<MicroSession>& session = from_space->session;
ObjectPtr<MicroSession>& session = from_space->session;
const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
DevBaseOffset from_dev_offset = GetDevLoc(from_space, from_offset);
......@@ -109,7 +109,7 @@ class MicroDeviceAPI final : public DeviceAPI {
// Writing to the device.
MicroDevSpace* to_space = static_cast<MicroDevSpace*>(const_cast<void*>(to));
std::shared_ptr<MicroSession>& session = to_space->session;
ObjectPtr<MicroSession>& session = to_space->session;
const std::shared_ptr<LowLevelDevice>& lld = session->low_level_device();
void* from_host_ptr = GetHostLoc(from, from_offset);
......@@ -124,7 +124,7 @@ class MicroDeviceAPI final : public DeviceAPI {
}
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
std::shared_ptr<MicroSession>& session = MicroSession::Current();
ObjectPtr<MicroSession>& session = MicroSession::Current();
void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>();
CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace";
......@@ -136,7 +136,7 @@ class MicroDeviceAPI final : public DeviceAPI {
void FreeWorkspace(TVMContext ctx, void* data) final {
MicroDevSpace* dev_space = static_cast<MicroDevSpace*>(data);
std::shared_ptr<MicroSession>& session = dev_space->session;
ObjectPtr<MicroSession>& session = dev_space->session;
session->FreeInSection(SectionKind::kWorkspace,
DevBaseOffset(reinterpret_cast<std::uintptr_t>(dev_space->data)));
delete dev_space;
......
......@@ -18,9 +18,8 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file micro_module.cc
*/
* \file micro_module.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
......@@ -48,7 +47,7 @@ class MicroModuleNode final : public ModuleNode {
}
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
/*!
* \brief initializes module by establishing device connection and loads binary
......@@ -76,13 +75,13 @@ class MicroModuleNode final : public ModuleNode {
/*! \brief path to module binary */
std::string binary_path_;
/*! \brief global session pointer */
std::shared_ptr<MicroSession> session_;
ObjectPtr<MicroSession> session_;
};
class MicroWrappedFunc {
public:
MicroWrappedFunc(MicroModuleNode* m,
std::shared_ptr<MicroSession> session,
ObjectPtr<MicroSession> session,
const std::string& func_name,
DevBaseOffset func_offset) {
m_ = m;
......@@ -99,7 +98,7 @@ class MicroWrappedFunc {
/*! \brief internal module */
MicroModuleNode* m_;
/*! \brief reference to the session for this function (to keep the session alive) */
std::shared_ptr<MicroSession> session_;
ObjectPtr<MicroSession> session_;
/*! \brief name of the function */
std::string func_name_;
/*! \brief offset of the function to be called */
......@@ -108,7 +107,7 @@ class MicroWrappedFunc {
PackedFunc MicroModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
DevBaseOffset func_offset =
session_->low_level_device()->ToDevOffset(binary_info_.symbol_map[name]);
MicroWrappedFunc f(this, session_, name, func_offset);
......@@ -118,9 +117,9 @@ PackedFunc MicroModuleNode::GetFunction(
// register loadfile function to load module from Python frontend
TVM_REGISTER_GLOBAL("module.loadfile_micro_dev")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<MicroModuleNode> n = std::make_shared<MicroModuleNode>();
auto n = make_object<MicroModuleNode>();
n->InitMicroModule(args[0]);
*rv = runtime::Module(n);
});
});
} // namespace runtime
} // namespace tvm
......@@ -18,13 +18,11 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file micro_session.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <memory>
#include <stack>
#include <tuple>
#include <vector>
......@@ -36,18 +34,18 @@ namespace tvm {
namespace runtime {
struct TVMMicroSessionThreadLocalEntry {
std::stack<std::shared_ptr<MicroSession>> session_stack;
std::stack<ObjectPtr<MicroSession>> session_stack;
};
typedef dmlc::ThreadLocalStore<TVMMicroSessionThreadLocalEntry> TVMMicroSessionThreadLocalStore;
std::shared_ptr<MicroSession>& MicroSession::Current() {
ObjectPtr<MicroSession>& MicroSession::Current() {
TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
CHECK_GT(entry->session_stack.size(), 0) << "No current session";
return entry->session_stack.top();
}
void MicroSession::EnterWithScope(std::shared_ptr<MicroSession> session) {
void MicroSession::EnterWithScope(ObjectPtr<MicroSession> session) {
TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get();
entry->session_stack.push(session);
}
......@@ -121,7 +119,7 @@ void MicroSession::CreateSession(const std::string& device_type,
void MicroSession::PushToExecQueue(DevBaseOffset func, const TVMArgs& args) {
int32_t (*func_dev_addr)(void*, void*, int32_t) =
reinterpret_cast<int32_t (*)(void*, void*, int32_t)>(
low_level_device()->ToDevPtr(func).value());
low_level_device()->ToDevPtr(func).value());
// Create an allocator stream for the memory region after the most recent
// allocation in the args section.
......@@ -355,10 +353,10 @@ void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map,
PackedFunc MicroSession::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
if (name == "enter") {
return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
MicroSession::EnterWithScope(std::dynamic_pointer_cast<MicroSession>(sptr_to_self));
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
MicroSession::EnterWithScope(GetObjectPtr<MicroSession>(this));
});
} else if (name == "exit") {
return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) {
......@@ -378,7 +376,7 @@ TVM_REGISTER_GLOBAL("micro._CreateSession")
uint64_t base_addr = args[3];
const std::string& server_addr = args[4];
int port = args[5];
std::shared_ptr<MicroSession> session = std::make_shared<MicroSession>();
ObjectPtr<MicroSession> session = make_object<MicroSession>();
session->CreateSession(
device_type, binary_path, toolchain_prefix, base_addr, server_addr, port);
*rv = Module(session);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file micro_session.h
* \brief session to manage multiple micro modules
*
......@@ -66,7 +65,7 @@ class MicroSession : public ModuleNode {
* \return The corresponding member function.
*/
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);
const ObjectPtr<Object>& sptr_to_self);
/*!
* \return The type key of the executor.
......@@ -85,7 +84,7 @@ class MicroSession : public ModuleNode {
*/
~MicroSession();
static std::shared_ptr<MicroSession>& Current();
static ObjectPtr<MicroSession>& Current();
/*!
* \brief creates session by setting up a low-level device and initting allocators for it
......@@ -240,7 +239,7 @@ class MicroSession : public ModuleNode {
* \brief Push a new session context onto the thread-local stack.
* The session on top of the stack is used as the current global session.
*/
static void EnterWithScope(std::shared_ptr<MicroSession> session);
static void EnterWithScope(ObjectPtr<MicroSession> session);
/*!
* \brief Pop a session off the thread-local context stack,
* restoring the previous session as the current context.
......@@ -258,7 +257,7 @@ struct MicroDevSpace {
/*! \brief data being wrapped */
void* data;
/*! \brief shared ptr to session where this data is valid */
std::shared_ptr<MicroSession> session;
ObjectPtr<MicroSession> session;
};
} // namespace runtime
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tcl_socket.h
* \brief TCP socket wrapper for communicating using Tcl commands
*/
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file module.cc
* \brief TVM module system
*/
......@@ -34,33 +33,46 @@
namespace tvm {
namespace runtime {
void Module::Import(Module other) {
void ModuleNode::Import(Module other) {
// specially handle rpc
if (!std::strcmp((*this)->type_key(), "rpc")) {
if (!std::strcmp(this->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr);
}
(*fimport_)(*this, other);
(*fimport_)(GetRef<Module>(this), other);
return;
}
// cyclic detection.
std::unordered_set<const ModuleNode*> visited{other.node_.get()};
std::vector<const ModuleNode*> stack{other.node_.get()};
std::unordered_set<const ModuleNode*> visited{other.operator->()};
std::vector<const ModuleNode*> stack{other.operator->()};
while (!stack.empty()) {
const ModuleNode* n = stack.back();
stack.pop_back();
for (const Module& m : n->imports_) {
const ModuleNode* next = m.node_.get();
const ModuleNode* next = m.operator->();
if (visited.count(next)) continue;
visited.insert(next);
stack.push_back(next);
}
}
CHECK(!visited.count(node_.get()))
CHECK(!visited.count(this))
<< "Cyclic dependency detected during import";
node_->imports_.emplace_back(std::move(other));
this->imports_.emplace_back(std::move(other));
}
PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
ModuleNode* self = this;
PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
if (pf != nullptr) return pf;
if (query_imports) {
for (Module& m : self->imports_) {
pf = m->GetFunction(name, m.data_);
if (pf != nullptr) return pf;
}
}
return pf;
}
Module Module::LoadFromFile(const std::string& file_name,
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file module_util.cc
* \brief Utilities for module.
*/
......@@ -64,7 +63,7 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
}
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
const_cast<TVMValue*>(args.values),
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file module_util.h
* \brief Helper utilities for module building
*/
......@@ -45,7 +44,7 @@ namespace runtime {
* \param faddr The function address
* \param mptr The module pointer node.
*/
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const std::shared_ptr<ModuleNode>& mptr);
PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr<Object>& mptr);
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
......
......@@ -27,6 +27,7 @@
#include <vector>
#include <utility>
#include <unordered_map>
#include "object_internal.h"
#include "runtime_base.h"
namespace tvm {
......@@ -200,18 +201,6 @@ uint32_t Object::TypeKey2Index(const std::string& key) {
return TypeContext::Global()->TypeKey2Index(key);
}
class TVMObjectCAPI {
public:
static void Free(TVMObjectHandle obj) {
if (obj != nullptr) {
static_cast<Object*>(obj)->DecRef();
}
}
static uint32_t TypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key);
}
};
} // namespace runtime
} // namespace tvm
......@@ -224,13 +213,13 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
int TVMObjectFree(TVMObjectHandle obj) {
API_BEGIN();
tvm::runtime::TVMObjectCAPI::Free(obj);
tvm::runtime::ObjectInternal::ObjectFree(obj);
API_END();
}
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = tvm::runtime::TVMObjectCAPI::TypeKey2Index(
out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
type_key);
API_END();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/runtime/object_internal.h
* \brief Expose a few functions for CFFI purposes.
* This file is not intended to be used
*/
#ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_
#define TVM_RUNTIME_OBJECT_INTERNAL_H_
#include <tvm/runtime/object.h>
#include <tvm/runtime/module.h>
#include <string>
namespace tvm {
namespace runtime {
/*!
* \brief Internal object namespace to expose
* certain util functions for FFI.
*/
class ObjectInternal {
public:
/*!
* \brief Free an object handle.
*/
static void ObjectFree(TVMObjectHandle obj) {
if (obj != nullptr) {
static_cast<Object*>(obj)->DecRef();
}
}
/*!
* \brief Expose TypeKey2Index
* \param type_key The original type key.
* \return the corresponding index.
*/
static uint32_t ObjectTypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key);
}
/*!
* \brief Convert ModuleHandle to module node pointer.
* \param handle The module handle.
* \return the corresponding module node pointer.
*/
static ModuleNode* GetModuleNode(TVMModuleHandle handle) {
// NOTE: we will need to convert to Object
// then to ModuleNode in order to get the correct
// address translation
return static_cast<ModuleNode*>(static_cast<Object*>(handle));
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file aocl_common.h
* \brief AOCL common header
*/
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file aocl_device_api.cc
*/
#include <tvm/runtime/registry.h>
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file aocl_module.h
* \brief Execution handling of OpenCL kernels for AOCL
*/
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -278,7 +278,7 @@ class OpenCLModuleNode : public ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final;
void SaveToBinary(dmlc::Stream* stream) final;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file opencl_module.cc
*/
#include <dmlc/memory_io.h>
......@@ -36,7 +35,7 @@ class OpenCLWrappedFunc {
public:
// initialize the OpenCL function.
void Init(OpenCLModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
ObjectPtr<Object> sptr,
OpenCLModuleNode::KTRefEntry entry,
std::string func_name,
std::vector<size_t> arg_size,
......@@ -88,7 +87,7 @@ class OpenCLWrappedFunc {
// The module
OpenCLModuleNode* m_;
// resource handle
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// global kernel id in the kernel table.
OpenCLModuleNode::KTRefEntry entry_;
// The name of the function.
......@@ -122,7 +121,7 @@ const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace
PackedFunc OpenCLModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
......@@ -251,8 +250,7 @@ Module OpenCLModuleCreate(
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
std::shared_ptr<OpenCLModuleNode> n =
std::make_shared<OpenCLModuleNode>(data, fmt, fmap, source);
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
n->Init();
return Module(n);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file opencl_module.h
* \brief Execution handling of OPENCL kernels
*/
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -44,7 +44,7 @@ class OpenGLModuleNode final : public ModuleNode {
const char* type_key() const final { return "opengl"; }
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
std::string GetSource(const std::string& format) final;
......@@ -74,7 +74,7 @@ class OpenGLModuleNode final : public ModuleNode {
class OpenGLWrappedFunc {
public:
OpenGLWrappedFunc(OpenGLModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
ObjectPtr<Object> sptr,
std::string func_name,
std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags);
......@@ -85,7 +85,7 @@ class OpenGLWrappedFunc {
// The module
OpenGLModuleNode* m_;
// resource handle
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// The name of the function.
std::string func_name_;
// convert code for void argument
......@@ -111,7 +111,7 @@ OpenGLModuleNode::OpenGLModuleNode(
PackedFunc OpenGLModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
......@@ -191,7 +191,7 @@ const FunctionInfo& OpenGLModuleNode::GetFunctionInfo(
OpenGLWrappedFunc::OpenGLWrappedFunc(
OpenGLModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
ObjectPtr<Object> sptr,
std::string func_name,
std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags)
......@@ -251,9 +251,9 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap) {
auto n = std::make_shared<OpenGLModuleNode>(std::move(shaders),
std::move(fmt),
std::move(fmap));
auto n = make_object<OpenGLModuleNode>(std::move(shaders),
std::move(fmt),
std::move(fmap));
return Module(n);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file rocm_module.cc
*/
#include <tvm/runtime/registry.h>
......@@ -68,7 +67,7 @@ class ROCMModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
......@@ -158,7 +157,7 @@ class ROCMWrappedFunc {
public:
// initialize the ROCM function.
void Init(ROCMModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
ObjectPtr<Object> sptr,
const std::string& func_name,
size_t num_void_args,
const std::vector<std::string>& thread_axis_tags) {
......@@ -204,7 +203,7 @@ class ROCMWrappedFunc {
// internal module
ROCMModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// The name of the function.
std::string func_name_;
// Device function cache per device.
......@@ -217,7 +216,7 @@ class ROCMWrappedFunc {
PackedFunc ROCMModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
......@@ -235,8 +234,7 @@ Module ROCMModuleCreate(
std::unordered_map<std::string, FunctionInfo> fmap,
std::string hip_source,
std::string assembly) {
std::shared_ptr<ROCMModuleNode> n =
std::make_shared<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
auto n = make_object<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly);
return Module(n);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -123,7 +123,7 @@ class RPCModuleNode final : public ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
RPCFuncHandle handle = GetFuncHandle(name);
return WrapRemote(handle);
}
......@@ -195,8 +195,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
return wf->operator()(args, rv);
});
} else if (tcode == kModuleHandle) {
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(handle, sess);
auto n = make_object<RPCModuleNode>(handle, sess);
*rv = Module(n);
} else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
CHECK_EQ(args.size(), 2);
......@@ -209,8 +208,7 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
}
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(nullptr, sess);
auto n = make_object<RPCModuleNode>(nullptr, sess);
return Module(n);
}
......@@ -237,8 +235,7 @@ TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
CHECK_EQ(tkey, "rpc");
auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
std::shared_ptr<RPCModuleNode> n =
std::make_shared<RPCModuleNode>(mhandle, sess);
auto n = make_object<RPCModuleNode>(mhandle, sess);
*rv = Module(n);
});
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -35,6 +35,7 @@
#include <cmath>
#include <algorithm>
#include "rpc_session.h"
#include "../object_internal.h"
#include "../../common/ring_buffer.h"
#include "../../common/socket.h"
......@@ -1119,25 +1120,29 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
}
std::string file_name = args[0];
TVMRetValue ret = (*fsys_load_)(file_name);
Module m = ret;
*rv = static_cast<void*>(new Module(m));
// pass via void*
TVMValue value;
int rcode;
ret.MoveToCHost(&value, &rcode);
CHECK_EQ(rcode, kModuleHandle);
*rv = static_cast<void*>(value.v_handle);
}
void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
void* pmod = args[0];
void* cmod = args[1];
static_cast<Module*>(pmod)->Import(
*static_cast<Module*>(cmod));
ObjectInternal::GetModuleNode(pmod)->Import(
GetRef<Module>(ObjectInternal::GetModuleNode(cmod)));
}
void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
delete static_cast<Module*>(mhandle);
ObjectInternal::ObjectFree(mhandle);
}
void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
PackedFunc pf = static_cast<Module*>(mhandle)->GetFunction(
PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction(
args[1], false);
if (pf != nullptr) {
*rv = static_cast<void*>(new PackedFunc(pf));
......@@ -1149,7 +1154,7 @@ void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
void* mhandle = args[0];
std::string fmt = args[1];
*rv = (*static_cast<Module*>(mhandle))->GetSource(fmt);
*rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt);
}
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file rpc_session.h
* \brief Base RPC session interface.
*/
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* Implementation stack VM.
* \file stackvm.cc
*/
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file stackvm_module.cc
*/
#include <tvm/runtime/registry.h>
......@@ -42,7 +41,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) {
return GetFunction(entry_func_, sptr_to_self);
}
......@@ -89,8 +88,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
static Module Create(std::unordered_map<std::string, StackVM> fmap,
std::string entry_func) {
std::shared_ptr<StackVMModuleNode> n =
std::make_shared<StackVMModuleNode>();
auto n = make_object<StackVMModuleNode>();
n->fmap_ = std::move(fmap);
n->entry_func_ = std::move(entry_func);
return Module(n);
......@@ -101,8 +99,7 @@ class StackVMModuleNode : public runtime::ModuleNode {
std::string entry_func, data;
strm->Read(&fmap);
strm->Read(&entry_func);
std::shared_ptr<StackVMModuleNode> n =
std::make_shared<StackVMModuleNode>();
auto n = make_object<StackVMModuleNode>();
n->fmap_ = std::move(fmap);
n->entry_func_ = std::move(entry_func);
uint64_t num_imports;
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,11 +18,11 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* \file system_lib_module.cc
* \brief SystemLib module.
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/c_backend_api.h>
#include <mutex>
#include "module_util.h"
......@@ -40,7 +40,7 @@ class SystemLibModuleNode : public ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
std::lock_guard<std::mutex> lock(mutex_);
if (module_blob_ != nullptr) {
......@@ -83,9 +83,8 @@ class SystemLibModuleNode : public ModuleNode {
}
}
static const std::shared_ptr<SystemLibModuleNode>& Global() {
static std::shared_ptr<SystemLibModuleNode> inst =
std::make_shared<SystemLibModuleNode>();
static const ObjectPtr<SystemLibModuleNode>& Global() {
static auto inst = make_object<SystemLibModuleNode>();
return inst;
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/vm/executable.cc
* \brief The implementation of a virtual machine executable APIs.
*/
......@@ -51,7 +50,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr);
Instruction DeserializeInstruction(const VMInstructionSerializer& instr);
PackedFunc Executable::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_lib") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetLib();
......@@ -440,7 +439,7 @@ void LoadHeader(dmlc::Stream* strm) {
}
runtime::Module Executable::Load(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Executable> exec = std::make_shared<Executable>();
auto exec = make_object<Executable>();
exec->lib = lib;
exec->code_ = code;
dmlc::MemoryStringStream strm(&exec->code_);
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/runtime/vm/profiler/vm.cc
* \brief The Relay debug virtual machine.
*/
......@@ -41,7 +40,7 @@ namespace runtime {
namespace vm {
PackedFunc VirtualMachineDebug::GetFunction(
const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) {
const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
if (name == "get_stat") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
double total_duration = 0.0;
......@@ -124,7 +123,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
}
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
auto vm = make_object<VirtualMachineDebug>();
vm->LoadExecutable(exec);
return runtime::Module(vm);
}
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file src/runtime/vm/profiler/vm.h
* \brief The Relay debug virtual machine.
*/
......@@ -42,7 +41,7 @@ class VirtualMachineDebug : public VirtualMachine {
VirtualMachineDebug() : VirtualMachine() {}
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
const ObjectPtr<Object>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final;
......
......@@ -627,7 +627,7 @@ ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
}
PackedFunc VirtualMachine::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
const ObjectPtr<Object>& sptr_to_self) {
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec) << "The executable is not created yet.";
......@@ -1052,7 +1052,7 @@ void VirtualMachine::RunLoop() {
}
runtime::Module CreateVirtualMachine(const Executable* exec) {
std::shared_ptr<VirtualMachine> vm = std::make_shared<VirtualMachine>();
auto vm = make_object<VirtualMachine>();
vm->LoadExecutable(exec);
return runtime::Module(vm);
}
......
......@@ -663,7 +663,9 @@ class VulkanModuleNode;
// a wrapped function class to get packed func.
class VulkanWrappedFunc {
public:
void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name,
void Init(VulkanModuleNode* m,
ObjectPtr<Object> sptr,
const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
m_ = m;
......@@ -680,7 +682,7 @@ class VulkanWrappedFunc {
// internal module
VulkanModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
ObjectPtr<Object> sptr_;
// v The name of the function.
std::string func_name_;
// Number of buffer arguments
......@@ -705,7 +707,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
const char* type_key() const final { return "vulkan"; }
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
......@@ -939,7 +941,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source);
auto n = make_object<VulkanModuleNode>(smap, fmap, source);
return Module(n);
}
......
......@@ -226,7 +226,7 @@ class DPIModule final : public DPIModuleNode {
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
const ObjectPtr<Object>& sptr_to_self) final {
if (name == "WriteReg") {
return TypedPackedFunc<void(int, int)>(
[this](int addr, int value){
......@@ -413,8 +413,7 @@ class DPIModule final : public DPIModuleNode {
};
Module DPIModuleNode::Load(std::string dll_name) {
std::shared_ptr<DPIModule> n =
std::make_shared<DPIModule>();
auto n = make_object<DPIModule>();
n->Init(dll_name);
return Module(n);
}
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -31,6 +31,7 @@
#include "../src/runtime/system_lib_module.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"
#include "../src/runtime/registry.cc"
#include "../src/runtime/file_util.cc"
#include "../src/runtime/dso_module.cc"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment