Commit a0bd3786 by Tianqi Chen Committed by Zhi

[RFC][RUNTIME] Introduce new object protocol. (#4115)

* [RUNTIME] Introduce new object protocol.

This PR introduces a new object protocol to unify the node and object.
We also updated the existing runtime::vm code to make use of the new system.

Update to the node will be done in a follow up PR.

Other changes:

- Remove object related code in json serializer as that code logic was not complete
  and we have a separate serializer for VM, can revisit later.

* address review  comment

* Fix the child slot logic
parent 68472596
...@@ -70,7 +70,7 @@ cpplint: ...@@ -70,7 +70,7 @@ cpplint:
python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src
python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include; python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include;
python3 3rdparty/dmlc-core/scripts/lint.py nnvm cpp nnvm/include nnvm/src; python3 3rdparty/dmlc-core/scripts/lint.py nnvm cpp nnvm/include nnvm/src;
python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src verilog\ python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \
examples/extension/src examples/graph_executor/src examples/extension/src examples/graph_executor/src
pylint: pylint:
......
...@@ -42,7 +42,7 @@ namespace runtime { ...@@ -42,7 +42,7 @@ namespace runtime {
// forward declaration // forward declaration
class NDArray; class NDArray;
// forward declaration // forward declaration
class Object; class ObjectRef;
} // namespace runtime } // namespace runtime
/*! /*!
...@@ -63,7 +63,7 @@ class TVM_DLL AttrVisitor { ...@@ -63,7 +63,7 @@ class TVM_DLL AttrVisitor {
virtual void Visit(const char* key, DataType* value) = 0; virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0; virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0; virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::Object* value) = 0; virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum, template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type> typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) { void Visit(const char* key, ENum* ptr) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/memory.h
* \brief Runtime memory management.
*/
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_
#include <utility>
#include <type_traits>
#include "object.h"
namespace tvm {
namespace runtime {
/*!
* \brief Allocate an object using default allocator.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
/*!
* \brief Base class of object allocators that implements make.
* Use curiously recurring template pattern.
*
* \tparam Derived The derived class.
*/
template<typename Derived>
class ObjAllocatorBase {
public:
/*!
* \tparam T The type to be allocated.
* \tparam Args The constructor signature.
* \param args The arguments.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make(Args&&... args) {
using Handler = typename Derived::template Handler<T>;
static_assert(std::is_base_of<Object, T>::value,
"make_node can only be used to create NodeBase");
T* ptr = Handler::New(static_cast<Derived*>(this),
std::forward<Args>(args)...);
ptr->type_index_ = T::type_index();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
}
};
// Simple allocator that uses new/delete.
class SimpleObjAllocator :
public ObjAllocatorBase<SimpleObjAllocator> {
public:
template<typename T>
class Handler {
public:
template<typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) {
// NOTE: the first argument is not needed for SimpleObjAllocator
// It is reserved for special allocators that needs to recycle
// the object to itself (e.g. in the case of object pool).
//
// In the case of an object pool, an allocator needs to create
// a special chunk memory that hides reference to the allocator
// and call allocator's release function in the deleter.
return new T(std::forward<Args>(args)...);
}
static Object::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(Object* ptr) {
delete static_cast<T*>(ptr);
}
};
};
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
return SimpleObjAllocator().make<T>(std::forward<Args>(args)...);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MEMORY_H_
...@@ -16,117 +16,249 @@ ...@@ -16,117 +16,249 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/object.h * \file tvm/runtime/object.h
* \brief A managed object in the TVM runtime. * \brief A managed object in the TVM runtime.
*/ */
#ifndef TVM_RUNTIME_OBJECT_H_ #ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_ #define TVM_RUNTIME_OBJECT_H_
#include <tvm/runtime/ndarray.h> #include <type_traits>
#include <memory> #include <string>
#include <utility> #include <utility>
#include <vector> #include "c_runtime_api.h"
/*!
* \brief Whether or not use atomic reference counter.
* If the reference counter is not atomic,
* an object cannot be owned by multiple threads.
* We can, however, move an object across threads
*/
#ifndef TVM_OBJECT_ATOMIC_REF_COUNTER
#define TVM_OBJECT_ATOMIC_REF_COUNTER 1
#endif
#if TVM_OBJECT_ATOMIC_REF_COUNTER
#include <atomic>
#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
template <typename T> /*! \brief list of the type index. */
class ObjectPtr; enum TypeIndex {
class Object; /*! \brief Root object type. */
kRoot = 0,
enum struct ObjectTag { kVMTensor = 1,
/*! \brief The tag of a tensor. */ kVMClosure = 2,
kTensor = 0U, kVMDatatype = 3,
/*! \brief The tag of a closure. */ kStaticIndexEnd,
kClosure = 1U, /*! \brief Type index is allocated during runtime. */
/*! \brief The tag of a structure. */ kDynamic = kStaticIndexEnd
kDatatype = 2U,
}; };
std::ostream& operator<<(std::ostream& os, const ObjectTag&); /*!
* \brief base class of all object containers.
struct ObjectCell { *
* Sub-class of objects should declare the following static constexpr fields:
*
* - _type_index:
* Static type index of the object, if assigned to TypeIndex::kDynamic
* the type index will be assigned during runtime.
* Runtime type index can be accessed by ObjectType::type_index();
* - _type_key:
* The unique string identifier of tyep type.
* - _type_final:
* Whether the type is terminal type(there is no subclass of the type in the object system).
* This field is automatically set by marco TVM_DECLARE_FINAL_OBJECT_INFO
* It is still OK to sub-class a terminal object type T and construct it using make_object.
* But IsInstance check will only show that the object type is T(instead of the sub-class).
*
* The following two fields are necessary for base classes that can be sub-classed.
*
* - _type_child_slots:
* Number of reserved type index slots for child classes.
* Used for runtime optimization for type checking in IsInstance.
* If an object's type_index is within range of [type_index, type_index + _type_child_slots]
* Then the object can be quickly decided as sub-class of the current object class.
* If not, a fallback mechanism is used to check the global type table.
* Recommendation: set to estimate number of children needed.
* - _type_child_slots_can_overflow:
* Whether we can add additional child classes even if the number of child classes
* exceeds the _type_child_slots. A fallback mechanism to check global type table will be used.
* Recommendation: set to false for optimal runtime speed if we know exact number of children.
*
* Two macros are used to declare helper functions in the object:
* - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
* - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed.
*
* New objects can be created using make_object function.
* Which will automatically populate the type_index and deleter of the object.
*
* \sa make_object
* \sa ObjectPtr
* \sa ObjectRef
*
* \code
*
* // Create a base object
* class BaseObj : public Object {
* public:
* // object fields
* int field0;
*
* // object properties
* static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
* static constexpr const char* _type_key = "test.BaseObj";
* TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object);
* };
*
* class ObjLeaf : public ObjBase {
* public:
* // fields
* int child_field0;
* // object properties
* static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
* static constexpr const char* _type_key = "test.LeafObj";
* TVM_DECLARE_BASE_OBJECT_INFO(LeaffObj, Object);
* };
*
* // The following code should be put into a cc file.
* TVM_REGISTER_OBJECT_TYPE(ObjBase);
* TVM_REGISTER_OBJECT_TYPE(ObjLeaf);
*
* // Usage example.
* void TestObjects() {
* // create an object
* ObjectRef leaf_ref(make_object<LeafObj>());
* // cast to a specific instance
* const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>();
* CHECK(leaf_ptr != nullptr);
* // can also cast to the base class.
* CHECK(leaf_ref.as<BaseObj>() != nullptr);
* }
*
* \endcode
*/
class Object {
public: public:
/*! /*!
* \brief The type of object deleter. * \brief Object deleter
* \param The self pointer to the ObjectCell. * \param self pointer to the Object.
*/ */
typedef void (*FDeleter)(ObjectCell* self); typedef void (*FDeleter)(Object* self);
/*! \return The internal type index of the object. */
/*! \brief The tag of the object. uint32_t type_index() const {
* return type_index_;
* Describes which type of value }
* is represented by this object. /*!
* Check if the object is an instance of TargetType.
* \tparam TargetType The target type to be checked.
* \return Whether the target type is true.
*/ */
ObjectTag tag; template<typename TargetType>
inline bool IsInstance() const;
#if TVM_OBJECT_ATOMIC_REF_COUNTER
using RefCounterType = std::atomic<int32_t>;
#else
using RefCounterType = int32_t;
#endif
// Object type properties
static constexpr const char* _type_key = "Object";
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
static const uint32_t _GetOrAllocRuntimeTypeIndex() {
return 0;
}
protected:
// The fields of the base object cell.
/*! \brief Type index(tag) that indicates the type of the object. */
uint32_t type_index_{0};
/*! \brief The internal reference counter */
RefCounterType ref_counter_{0};
/*! /*!
* \brief Increment the reference count. * \brief deleter of this object to enable customized allocation.
* If the deleter is nullptr, no deletion will be performed.
* The creator of the object must always set the deleter field properly.
*/ */
void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } FDeleter deleter_ = nullptr;
// Invariant checks.
static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
alignof(int32_t) == sizeof(RefCounterType),
"RefCounter ABI check.");
/*! /*!
* \brief Decrement the reference count. * \brief Get the type index using type key.
*
* When the function is first time called for a type,
* it will register the type to the type table in the runtime.
* If the static_tindex is TypeIndex::kDynamic, the function will
* allocate a runtime type index.
* Otherwise, we will populate the type table and return the static index.
*
* \param key the type key.
* \param static_tindex The current _type_index field.
* can be TypeIndex::kDynamic.
* \param parent_tindex The index of the parent.
* \param type_child_slots Number of slots reserved for its children.
* \param type_child_slots_can_overflow Whether to allow child to overflow the slots.
* \return The allocated type index.
*/ */
void DecRef() { TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { const char* key,
std::atomic_thread_fence(std::memory_order_acquire); uint32_t static_tindex,
if (this->deleter_ != nullptr) { uint32_t parent_tindex,
(*this->deleter_)(this); uint32_t type_child_slots,
} bool type_child_slots_can_overflow);
}
}
protected: /*!
// default constructor and copy constructor * \brief Get the type key of the corresponding index from runtime.
ObjectCell() {} * \param tindex The type index.
*/
explicit ObjectCell(ObjectTag tag) : tag(tag) {} TVM_DLL static std::string TypeIndex2Key(uint32_t tindex);
// override the copy and assign constructors to do nothing.
// This is to make sure only contents, but not deleter and ref_counter
// are copied when a child class copies itself.
ObjectCell(const ObjectCell& other) { // NOLINT(*)
}
ObjectCell(ObjectCell&& other) { // NOLINT(*)
}
ObjectCell& operator=(const ObjectCell& other) { // NOLINT(*)
return *this;
}
ObjectCell& operator=(ObjectCell&& other) { // NOLINT(*) /*!
return *this; * \brief Get the type index of the corresponding key from runtime.
} * \param key The type key.
*/
TVM_DLL static uint32_t TypeKey2Index(const char* key);
private: private:
/*! \brief Internal reference counter */ // reference counter related operations
std::atomic<int> ref_counter_{0}; /*! \brief developer function, increases reference counter. */
inline void IncRef();
/*! /*!
* \brief deleter of this object to enable customized allocation. * \brief developer function, decrease reference counter.
* If the deleter is nullptr, no deletion will be performed. * \note The deleter will be called when ref_counter_ becomes zero.
* The creator of the Node must always set the deleter field properly.
*/ */
FDeleter deleter_ = nullptr; inline void DecRef();
/*!
int use_count() const { return ref_counter_.load(std::memory_order_relaxed); } * \return The usage count of the cell.
* \note We use stl style naming to be consistent with known API in shared_ptr.
// friend declaration */
template <typename> inline int use_count() const;
/*!
* \brief Check of this object is derived from the parent.
* \param parent_tindex The parent type index.
* \return The derivation results.
*/
TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const;
// friend classes
template<typename>
friend class ObjAllocatorBase;
template<typename>
friend class ObjectPtr; friend class ObjectPtr;
friend class TVMRetValue;
template <typename Y, typename... Args>
friend ObjectPtr<Y> MakeObject(Args&&...);
}; };
/*! /*!
* \brief A custom smart pointer for Object. * \brief A custom smart pointer for Object.
* must be subclass of NodeBase
* \tparam T the content data type. * \tparam T the content data type.
* \sa make_object
*/ */
template <typename T> template <typename T>
class ObjectPtr { class ObjectPtr {
...@@ -159,7 +291,6 @@ class ObjectPtr { ...@@ -159,7 +291,6 @@ class ObjectPtr {
: data_(other.data_) { : data_(other.data_) {
other.data_ = nullptr; other.data_ = nullptr;
} }
/*! /*!
* \brief move constructor * \brief move constructor
* \param other The value to be moved * \param other The value to be moved
...@@ -171,10 +302,10 @@ class ObjectPtr { ...@@ -171,10 +302,10 @@ class ObjectPtr {
"can only assign of child class ObjectPtr to parent"); "can only assign of child class ObjectPtr to parent");
other.data_ = nullptr; other.data_ = nullptr;
} }
/*! \brief destructor */ /*! \brief destructor */
~ObjectPtr() { this->reset(); } ~ObjectPtr() {
this->reset();
}
/*! /*!
* \brief Swap this array with another Object * \brief Swap this array with another Object
* \param other The other Object * \param other The other Object
...@@ -182,24 +313,24 @@ class ObjectPtr { ...@@ -182,24 +313,24 @@ class ObjectPtr {
void swap(ObjectPtr<T>& other) { // NOLINT(*) void swap(ObjectPtr<T>& other) { // NOLINT(*)
std::swap(data_, other.data_); std::swap(data_, other.data_);
} }
/*! /*!
* \return Get the content of the pointer * \return Get the content of the pointer
*/ */
T* get() const { return static_cast<T*>(data_); } T* get() const {
return static_cast<T*>(data_);
}
/*! /*!
* \return The pointer * \return The pointer
*/ */
T* operator->() const { return get(); } T* operator->() const {
return get();
}
/*! /*!
* \return The reference * \return The reference
*/ */
T& operator*() const { // NOLINT(*) T& operator*() const { // NOLINT(*)
return *get(); return *get();
} }
/*! /*!
* \brief copy assignmemt * \brief copy assignmemt
* \param other The value to be assigned. * \param other The value to be assigned.
...@@ -211,7 +342,6 @@ class ObjectPtr { ...@@ -211,7 +342,6 @@ class ObjectPtr {
ObjectPtr(other).swap(*this); // NOLINT(*) ObjectPtr(other).swap(*this); // NOLINT(*)
return *this; return *this;
} }
/*! /*!
* \brief move assignmemt * \brief move assignmemt
* \param other The value to be assigned. * \param other The value to be assigned.
...@@ -222,7 +352,6 @@ class ObjectPtr { ...@@ -222,7 +352,6 @@ class ObjectPtr {
ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) ObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
return *this; return *this;
} }
/*! \brief reset the content of ptr to be nullptr */ /*! \brief reset the content of ptr to be nullptr */
void reset() { void reset() {
if (data_ != nullptr) { if (data_ != nullptr) {
...@@ -230,163 +359,238 @@ class ObjectPtr { ...@@ -230,163 +359,238 @@ class ObjectPtr {
data_ = nullptr; data_ = nullptr;
} }
} }
/*! \return The use count of the ptr, for debug purposes */ /*! \return The use count of the ptr, for debug purposes */
int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } int use_count() const {
return data_ != nullptr ? data_->use_count() : 0;
}
/*! \return whether the reference is unique */ /*! \return whether the reference is unique */
bool unique() const { return data_ != nullptr && data_->use_count() == 1; } bool unique() const {
return data_ != nullptr && data_->use_count() == 1;
}
/*! \return Whether two ObjectPtr do not equal each other */ /*! \return Whether two ObjectPtr do not equal each other */
bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; } bool operator==(const ObjectPtr<T>& other) const {
return data_ == other.data_;
}
/*! \return Whether two ObjectPtr equals each other */ /*! \return Whether two ObjectPtr equals each other */
bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; } bool operator!=(const ObjectPtr<T>& other) const {
return data_ != other.data_;
}
/*! \return Whether the pointer is nullptr */ /*! \return Whether the pointer is nullptr */
bool operator==(std::nullptr_t null) const { return data_ == nullptr; } bool operator==(std::nullptr_t null) const {
return data_ == nullptr;
}
/*! \return Whether the pointer is not nullptr */ /*! \return Whether the pointer is not nullptr */
bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } bool operator!=(std::nullptr_t null) const {
return data_ != nullptr;
/* ObjectPtr's support custom allocators.
*
* The below allocator represents the simplest
* possible impl. It can be easily swapped
* for customized executor's, different allocation
* strategies, and so on.
*
* See memory.h for more discussion on NodePtr's
* allocator.
*/
class StdAllocator {
public:
template <typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static ObjectCell::FDeleter Deleter() { return Deleter_; }
private:
static void Deleter_(ObjectCell* ptr) { delete static_cast<T*>(ptr); }
};
template <typename U>
ObjectPtr<U> As() const {
auto ptr = reinterpret_cast<U*>(get());
return ObjectPtr<U>(ptr);
} }
private: private:
/*! \brief internal pointer field */ /*! \brief internal pointer field */
ObjectCell* data_{nullptr}; Object* data_{nullptr};
/*! /*!
* \brief constructor from NodeBase * \brief constructor from NodeBase
* \param data The node base pointer * \param data The data pointer
*/ */
// TODO(jroesch): NodePtr design doesn't really work here due to the passing. explicit ObjectPtr(Object* data) : data_(data) {
public:
explicit ObjectPtr(ObjectCell* data) : data_(data) {
if (data != nullptr) { if (data != nullptr) {
data_->IncRef(); data_->IncRef();
} }
} }
// friend classes
private: friend class Object;
template <typename Y, typename... Args> friend class ObjectRef;
friend ObjectPtr<Y> MakeObject(Args&&...); template<typename>
template <typename>
friend class ObjectPtr; friend class ObjectPtr;
friend class NDArray; template<typename>
friend class ObjAllocatorBase;
friend class TVMPODValue_; friend class TVMPODValue_;
friend class TVMArgValue; friend class TVMArgsSetter;
friend class TVMRetValue; friend class TVMRetValue;
friend class RPCWrappedFunc;
}; };
struct TensorCell; /*! \brief Base class of all object reference */
struct DatatypeCell; class ObjectRef {
struct ClosureCell; public:
/*! \brief default constructor */
ObjectRef() = default;
/*! \brief Constructor from existing object ptr */
explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
/*! \return the internal object pointer */
inline const Object* get() const;
/*! \return the internal node pointer */
inline const Object* operator->() const;
/*!
* \brief Try to downcast the internal Object to a
* raw pointer of a corresponding type.
*
* The function will return a nullptr if the cast failed.
*
* if (const Add *add = node_ref.As<Add>()) {
* // This is an add node
* }
* \tparam ObjectType the target type, must be a subtype of Object/
*/
template <typename ObjectType>
inline const ObjectType* as() const;
/*! \brief type indicate the container type */
using ContainerType = Object;
protected:
/*! \brief Internal pointer that backs the reference. */
ObjectPtr<Object> data_;
// friend classes.
friend class TVMRetValue;
friend class TVMArgsSetter;
};
/*! /*!
* \brief A managed object in the TVM runtime. * \brief helper macro to declare a base object type that can be inheritated.
* * \param TypeName The name of the current type.
* For example a tuple, list, closure, and so on. * \param ParentType The name of the ParentType
*/
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static const uint32_t type_index() { \
if (_type_index != TypeIndex::kDynamic) return _type_index; \
return _GetOrAllocRuntimeTypeIndex(); \
} \
static const uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, \
TypeName::_type_index, \
ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, \
TypeName::_type_child_slots_can_overflow); \
return tidx; \
} \
/*!
* \brief helper macro to declare type information in a final class.
* \param TypeName The name of the current type.
* \param ParentType The name of the ParentType
*/
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
static const constexpr bool _type_final = true; \
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
/*!
* \brief Helper macro to register the object type to runtime.
* Makes sure that the runtime type table is correctly populated.
* *
* Maintains a reference count for the object. * Use this macro in the cc file for each terminal class.
*/ */
class Object { #define TVM_REGISTER_OBJECT_TYPE(TypeName) \
public: static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \
ObjectPtr<ObjectCell> ptr_; TypeName::_GetOrAllocRuntimeTypeIndex()
explicit Object(ObjectPtr<ObjectCell> ptr) : ptr_(ptr) {}
explicit Object(ObjectCell* ptr) : ptr_(ptr) {}
Object() : ptr_() {}
Object(const Object& obj) : ptr_(obj.ptr_) {}
ObjectCell* operator->() { return this->ptr_.operator->(); }
const ObjectCell* operator->() const { return this->ptr_.operator->(); }
/*! \brief Construct a tensor object. */
static Object Tensor(const NDArray& data);
/*! \brief Construct a datatype object. */
static Object Datatype(size_t tag, const std::vector<Object>& fields);
/*! \brief Construct a tuple object. */
static Object Tuple(const std::vector<Object>& fields);
/*! \brief Construct a closure object. */
static Object Closure(size_t func_index, const std::vector<Object>& free_vars);
ObjectPtr<TensorCell> AsTensor() const;
ObjectPtr<DatatypeCell> AsDatatype() const;
ObjectPtr<ClosureCell> AsClosure() const;
};
/*! \brief An object containing an NDArray. */
struct TensorCell : public ObjectCell {
/*! \brief The NDArray. */
NDArray data;
explicit TensorCell(const NDArray& data) : ObjectCell(ObjectTag::kTensor), data(data) {}
};
/*! \brief An object representing a structure or enumeration. */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
struct DatatypeCell : public ObjectCell { TypeName() {} \
/*! \brief The tag representing the constructor used. */ explicit TypeName( \
size_t tag; ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
/*! \brief The fields of the structure. */ : ParentType(n) {} \
std::vector<Object> fields; const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;
DatatypeCell(size_t tag, const std::vector<Object>& fields)
: ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {}
};
/*! \brief An object representing a closure. */ // Implementations details below
struct ClosureCell : public ObjectCell { // Object reference counting.
/*! \brief The index into the VM function table. */ #if TVM_OBJECT_ATOMIC_REF_COUNTER
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<Object> free_vars;
ClosureCell(size_t func_index, const std::vector<Object>& free_vars) inline void Object::IncRef() {
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {} ref_counter_.fetch_add(1, std::memory_order_relaxed);
}; }
/*! \brief Extract the NDArray from a tensor object. */ inline void Object::DecRef() {
NDArray ToNDArray(const Object& obj); if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
/*! inline int Object::use_count() const {
* \brief Allocate a node object. return ref_counter_.load(std::memory_order_relaxed);
* \param args arguments to the constructor. }
* \tparam T the node type.
* \return The NodePtr to the allocated object. #else
*/
template <typename T, typename... Args> inline void Object::IncRef() {
inline ObjectPtr<T> MakeObject(Args&&... args) { ++ref_counter_;
using Allocator = typename ObjectPtr<T>::StdAllocator;
static_assert(std::is_base_of<ObjectCell, T>::value, "MakeObject can only be used to create ");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return ObjectPtr<T>(node);
} }
inline void Object::DecRef() {
if (--ref_counter == 0) {
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
inline int Object::use_count() const {
return ref_counter_;
}
#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
template<typename TargetType>
inline bool Object::IsInstance() const {
const Object* self = this;
// NOTE: the following code can be optimized by
// compiler dead-code elimination for already known constants.
if (self != nullptr) {
// Everything is a subclass of object.
if (std::is_same<TargetType, Object>::value) return true;
if (TargetType::_type_final) {
// if the target type is a final type
// then we only need to check the equivalence.
return self->type_index_ == TargetType::type_index();
} else {
// if target type is a non-leaf type
// Check if type index falls into the range of reserved slots.
uint32_t begin = TargetType::type_index();
// The condition will be optimized by constant-folding.
if (TargetType::_type_child_slots != 0) {
uint32_t end = begin + TargetType::_type_child_slots;
if (self->type_index_ >= begin && self->type_index_ < end) return true;
} else {
if (self->type_index_ == begin) return true;
}
if (!TargetType::_type_child_slots_can_overflow) return false;
// Invariance: parent index is always smaller than the child.
if (self->type_index_ < TargetType::type_index()) return false;
// The rare slower-path, check type hierachy.
return self->DerivedFrom(TargetType::type_index());
}
} else {
return false;
}
}
inline const Object* ObjectRef::get() const {
return data_.data_;
}
inline const Object* ObjectRef::operator->() const {
return get();
}
template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
if (data_ != nullptr &&
data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
} else {
return nullptr;
}
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_ #endif // TVM_RUNTIME_OBJECT_H_
...@@ -489,10 +489,10 @@ class TVMPODValue_ { ...@@ -489,10 +489,10 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle)); return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
} }
operator Object() const { operator ObjectRef() const {
if (type_code_ == kNull) return Object(); if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell); TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
return Object(static_cast<ObjectCell*>(value_.v_handle)); return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} }
operator TVMContext() const { operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
...@@ -566,7 +566,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -566,7 +566,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Object; using TVMPODValue_::operator ObjectRef;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
...@@ -662,7 +662,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -662,7 +662,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Object; using TVMPODValue_::operator ObjectRef;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other); this->Assign(other);
} }
...@@ -759,11 +759,12 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -759,11 +759,12 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr; other.data_ = nullptr;
return *this; return *this;
} }
TVMRetValue& operator=(Object other) { TVMRetValue& operator=(ObjectRef other) {
this->Clear(); this->Clear();
type_code_ = kObjectCell; type_code_ = kObjectCell;
value_.v_handle = other.ptr_.data_; // move the handle out
other.ptr_.data_ = nullptr; value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
return *this; return *this;
} }
TVMRetValue& operator=(PackedFunc f) { TVMRetValue& operator=(PackedFunc f) {
...@@ -862,7 +863,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -862,7 +863,7 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
case kObjectCell: { case kObjectCell: {
*this = other.operator Object(); *this = other.operator ObjectRef();
break; break;
} }
default: { default: {
...@@ -913,7 +914,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -913,7 +914,7 @@ class TVMRetValue : public TVMPODValue_ {
break; break;
} }
case kObjectCell: { case kObjectCell: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef(); static_cast<Object*>(value_.v_handle)->DecRef();
break; break;
} }
} }
...@@ -1161,6 +1162,10 @@ class TVMArgsSetter { ...@@ -1161,6 +1162,10 @@ class TVMArgsSetter {
values_[i].v_handle = value.data_; values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer; type_codes_[i] = kNDArrayContainer;
} }
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*) void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) { if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str(); values_[i].v_str = value.ptr<std::string>()->c_str();
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors
* \file tvm/runtime/vm.h * \file tvm/runtime/vm.h
* \brief A virtual machine for executing Relay programs. * \brief A virtual machine for executing Relay programs.
*/ */
...@@ -36,6 +35,75 @@ namespace tvm { ...@@ -36,6 +35,75 @@ namespace tvm {
namespace runtime { namespace runtime {
namespace vm { namespace vm {
/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;
static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};
/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);
TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};
/*! \brief An object representing a structure or enumeration. */
class DatatypeObj : public Object {
public:
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<ObjectRef> fields;
static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype;
static constexpr const char* _type_key = "vm.Datatype";
TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object);
};
/*! \brief reference to data type. */
class Datatype : public ObjectRef {
public:
Datatype(size_t tag, std::vector<ObjectRef> fields);
/*!
* \brief construct a tuple object.
* \param fields The fields of the tuple.
* \return The constructed tuple type.
*/
static Datatype Tuple(std::vector<ObjectRef> fields);
TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj);
};
/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;
static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
};
/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};
/*! \brief Magic number for NDArray list file */ /*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
...@@ -193,7 +261,7 @@ struct Instruction { ...@@ -193,7 +261,7 @@ struct Instruction {
static Instruction Ret(RegName return_reg); static Instruction Ret(RegName return_reg);
/*! \brief Construct a fatal instruction. /*! \brief Construct a fatal instruction.
* \return The fatal instruction. * \return The fatal instruction.
* */ * */
static Instruction Fatal(); static Instruction Fatal();
/*! \brief Construct a invoke packed instruction. /*! \brief Construct a invoke packed instruction.
* \param packed_index The index of the packed function. * \param packed_index The index of the packed function.
...@@ -348,7 +416,7 @@ struct VMFrame { ...@@ -348,7 +416,7 @@ struct VMFrame {
const Instruction* code; const Instruction* code;
/*! \brief Statically allocated space for objects */ /*! \brief Statically allocated space for objects */
std::vector<Object> register_file; std::vector<ObjectRef> register_file;
/*! \brief Register in caller's frame to put return value */ /*! \brief Register in caller's frame to put return value */
RegName caller_return_register; RegName caller_return_register;
...@@ -406,8 +474,11 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -406,8 +474,11 @@ class VirtualMachine : public runtime::ModuleNode {
* *
* \note The return value will be stored in the last output_size slots of args. * \note The return value will be stored in the last output_size slots of args.
*/ */
virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, virtual void InvokePacked(Index packed_index,
Index output_size, const std::vector<Object>& args); const PackedFunc& func,
Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args);
virtual ~VirtualMachine() {} virtual ~VirtualMachine() {}
...@@ -424,7 +495,7 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -424,7 +495,7 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The current stack of call frames. */ /*! \brief The current stack of call frames. */
std::vector<VMFrame> frames; std::vector<VMFrame> frames;
/*! \brief The global constant pool. */ /*! \brief The global constant pool. */
std::vector<Object> constants; std::vector<ObjectRef> constants;
/*! \brief The fuction table index of the current function. */ /*! \brief The fuction table index of the current function. */
Index func_index; Index func_index;
/*! \brief The current pointer to the code section. */ /*! \brief The current pointer to the code section. */
...@@ -433,7 +504,7 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -433,7 +504,7 @@ class VirtualMachine : public runtime::ModuleNode {
Index pc; Index pc;
/*! \brief The special return register. */ /*! \brief The special return register. */
Object return_register; ObjectRef return_register;
/*! \brief The set of TVM contexts the VM is currently executing on. */ /*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs; std::vector<TVMContext> ctxs;
...@@ -449,13 +520,13 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -449,13 +520,13 @@ class VirtualMachine : public runtime::ModuleNode {
* \param reg The register to write to. * \param reg The register to write to.
* \param obj The object to write to. * \param obj The object to write to.
*/ */
inline void WriteRegister(RegName reg, const Object& obj); inline void WriteRegister(RegName reg, const ObjectRef& obj);
/*! \brief Read a VM register. /*! \brief Read a VM register.
* \param reg The register to read from. * \param reg The register to read from.
* \return The read object. * \return The read object.
*/ */
inline Object ReadRegister(RegName reg) const; inline ObjectRef ReadRegister(RegName reg) const;
/*! \brief Read a VM register and cast it to int32_t /*! \brief Read a VM register and cast it to int32_t
* \param reg The register to read from. * \param reg The register to read from.
...@@ -468,15 +539,16 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -468,15 +539,16 @@ class VirtualMachine : public runtime::ModuleNode {
* \param args The arguments to the function. * \param args The arguments to the function.
* \return The object representing the result. * \return The object representing the result.
*/ */
Object Invoke(const VMFunction& func, const std::vector<Object>& args); ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& args);
// TODO(@jroesch): I really would like this to be a global variable. // TODO(@jroesch): I really would like this to be a global variable.
/*! \brief Invoke a VM function by name. /*!
* \brief Invoke a VM function by name.
* \param name The function's name. * \param name The function's name.
* \param args The arguments to the function. * \param args The arguments to the function.
* \return The object representing the result. * \return The object representing the result.
*/ */
Object Invoke(const std::string& name, const std::vector<Object>& args); ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {} VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
...@@ -513,11 +585,10 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -513,11 +585,10 @@ class VirtualMachine : public runtime::ModuleNode {
* *
* This does not begin execution of the VM. * This does not begin execution of the VM.
*/ */
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args); void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
/*! \brief The parameter name to data mapping. */ /*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_; std::unordered_map<std::string, ObjectRef> params_;
}; };
} // namespace vm } // namespace vm
......
...@@ -44,9 +44,9 @@ except IMPORT_EXCEPT: ...@@ -44,9 +44,9 @@ except IMPORT_EXCEPT:
class ObjectTag(object): class ObjectTag(object):
"""Type code used in API calls""" """Type code used in API calls"""
TENSOR = 0 TENSOR = 1
CLOSURE = 1 CLOSURE = 2
DATATYPE = 2 DATATYPE = 3
class Object(_ObjectBase): class Object(_ObjectBase):
......
...@@ -92,7 +92,7 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -92,7 +92,7 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true; found_ref_object = true;
} }
} }
void Visit(const char* key, runtime::Object* value) final { void Visit(const char* key, runtime::ObjectRef* value) final {
if (skey == key) { if (skey == key) {
*ret = value[0]; *ret = value[0];
found_ref_object = true; found_ref_object = true;
...@@ -133,7 +133,7 @@ struct APIAttrDir : public AttrVisitor { ...@@ -133,7 +133,7 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, runtime::Object* value) final { void Visit(const char* key, runtime::ObjectRef* value) final {
names->push_back(key); names->push_back(key);
} }
}; };
......
...@@ -54,7 +54,7 @@ inline Type String2Type(std::string s) { ...@@ -54,7 +54,7 @@ inline Type String2Type(std::string s) {
} }
using runtime::Object; using runtime::Object;
using runtime::ObjectCell; using runtime::ObjectRef;
// indexer to index all the ndoes // indexer to index all the ndoes
class NodeIndexer : public AttrVisitor { class NodeIndexer : public AttrVisitor {
...@@ -63,8 +63,6 @@ class NodeIndexer : public AttrVisitor { ...@@ -63,8 +63,6 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr}; std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index; std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list; std::vector<DLTensor*> tensor_list;
std::unordered_map<ObjectCell*, size_t> vm_obj_index;
std::vector<ObjectCell*> vm_obj_list;
void Visit(const char* key, double* value) final {} void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {} void Visit(const char* key, int64_t* value) final {}
...@@ -86,12 +84,8 @@ class NodeIndexer : public AttrVisitor { ...@@ -86,12 +84,8 @@ class NodeIndexer : public AttrVisitor {
tensor_list.push_back(ptr); tensor_list.push_back(ptr);
} }
void Visit(const char* key, Object* value) final { void Visit(const char* key, ObjectRef* value) final {
ObjectCell* ptr = value->ptr_.get(); LOG(FATAL) << "Do not support json serialize non-node object";
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
} }
// make index of all the children of node // make index of all the children of node
...@@ -177,7 +171,6 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -177,7 +171,6 @@ class JSONAttrGetter : public AttrVisitor {
public: public:
const std::unordered_map<Node*, size_t>* node_index_; const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_; const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<ObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_; JSONNode* node_;
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
...@@ -212,9 +205,8 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -212,9 +205,8 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string( node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->()))); tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
} }
void Visit(const char* key, Object* value) final { void Visit(const char* key, ObjectRef* value) final {
node_->attrs[key] = std::to_string( LOG(FATAL) << "Do not support json serialize non-node object";
vm_obj_index_->at(value->ptr_.get()));
} }
// Get the node // Get the node
void Get(Node* node) { void Get(Node* node) {
...@@ -269,7 +261,6 @@ class JSONAttrSetter : public AttrVisitor { ...@@ -269,7 +261,6 @@ class JSONAttrSetter : public AttrVisitor {
public: public:
const std::vector<NodePtr<Node> >* node_list_; const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_; const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<Object>* vm_obj_list_;
JSONNode* node_; JSONNode* node_;
...@@ -325,11 +316,8 @@ class JSONAttrSetter : public AttrVisitor { ...@@ -325,11 +316,8 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size()); CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index); *value = tensor_list_->at(index);
} }
void Visit(const char* key, Object* value) final { void Visit(const char* key, ObjectRef* value) final {
size_t index; LOG(FATAL) << "Do not support json serialize non-node object";
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
} }
// set node to be current JSONNode // set node to be current JSONNode
void Set(Node* node) { void Set(Node* node) {
...@@ -508,8 +496,8 @@ class NodeAttrSetter : public AttrVisitor { ...@@ -508,8 +496,8 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray(); *value = GetAttr(key).operator runtime::NDArray();
} }
void Visit(const char* key, Object* value) final { void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator Object(); *value = GetAttr(key).operator ObjectRef();
} }
private: private:
......
...@@ -885,7 +885,7 @@ void VMCompiler::Compile(Module mod, ...@@ -885,7 +885,7 @@ void VMCompiler::Compile(Module mod,
// populate constants // populate constants
for (auto data : context_.constants) { for (auto data : context_.constants) {
vm_->constants.push_back(Object::Tensor(data)); vm_->constants.push_back(runtime::vm::Tensor(data));
} }
LibraryCodegen(); LibraryCodegen();
......
...@@ -102,7 +102,7 @@ void Deserializer::DeserializeConstantSection() { ...@@ -102,7 +102,7 @@ void Deserializer::DeserializeConstantSection() {
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
runtime::NDArray constant; runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm_), "constant"); STREAM_CHECK(constant.Load(strm_), "constant");
runtime::Object obj = runtime::Object::Tensor(constant); runtime::ObjectRef obj = runtime::vm::Tensor(constant);
vm_->constants.push_back(obj); vm_->constants.push_back(obj);
} }
} }
......
...@@ -98,8 +98,8 @@ std::string Serializer::Stats() const { ...@@ -98,8 +98,8 @@ std::string Serializer::Stats() const {
// Get the number of constants and the shape of each of them. // Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << vm_->constants.size() << "): ["; oss << " Constant shapes (# " << vm_->constants.size() << "): [";
for (const auto& it : vm_->constants) { for (const auto& it : vm_->constants) {
auto cell = it.AsTensor(); auto* cell = it.as<runtime::vm::TensorObj>();
CHECK(cell.operator->()); CHECK(cell != nullptr);
runtime::NDArray data = cell->data; runtime::NDArray data = cell->data;
const auto& shape = data.Shape(); const auto& shape = data.Shape();
...@@ -175,7 +175,8 @@ void Serializer::SerializeGlobalSection() { ...@@ -175,7 +175,8 @@ void Serializer::SerializeGlobalSection() {
void Serializer::SerializeConstantSection() { void Serializer::SerializeConstantSection() {
std::vector<DLTensor*> arrays; std::vector<DLTensor*> arrays;
for (const auto& obj : vm_->constants) { for (const auto& obj : vm_->constants) {
auto cell = obj.AsTensor(); const auto* cell = obj.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data; runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->())); arrays.push_back(const_cast<DLTensor*>(data.operator->()));
} }
......
...@@ -930,7 +930,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { ...@@ -930,7 +930,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument"; LOG(FATAL) << "do not allow NDarray as argument";
} }
void Visit(const char* key, runtime::Object* obj) final { void Visit(const char* key, runtime::ObjectRef* obj) final {
LOG(FATAL) << "do not allow Object as argument"; LOG(FATAL) << "do not allow Object as argument";
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/runtime/object.cc
* \brief Object type management system.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/object.h>
#include <mutex>
#include <string>
#include <vector>
#include <unordered_map>
namespace tvm {
namespace runtime {
/*! \brief Type information */
struct TypeInfo {
/*! \brief The current index. */
uint32_t index{0};
/*! \brief Index of the parent in the type hierachy */
uint32_t parent_index{0};
// NOTE: the indices in [index, index + num_reserved_slots) are
// reserved for the child-class of this type.
/*! \brief Total number of slots reserved for the type and its children. */
uint32_t num_slots{0};
/*! \brief number of allocated child slots. */
uint32_t allocated_slots{0};
/*! \brief Whether child can overflow. */
bool child_slots_can_overflow{true};
/*! \brief name of the type. */
std::string name;
};
/*!
* \brief Type context that manages the type hierachy information.
*/
class TypeContext {
public:
// NOTE: this is a relatively slow path for child checking
// Most types are already checked by the fast-path via reserved slot checking.
bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
// invariance: child's type index is always bigger than its parent.
if (child_tindex < parent_tindex) return false;
if (child_tindex == parent_tindex) return true;
{
std::lock_guard<std::mutex> lock(mutex_);
CHECK_LT(child_tindex, type_table_.size());
while (child_tindex > parent_tindex) {
child_tindex = type_table_[child_tindex].parent_index;
}
}
return child_tindex == parent_tindex;
}
uint32_t GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
std::lock_guard<std::mutex> lock(mutex_);
std::string skey = key;
auto it = type_key2index_.find(skey);
if (it != type_key2index_.end()) {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);
// if parent cannot overflow, then this class cannot.
if (!pinfo.child_slots_can_overflow) {
child_slots_can_overflow = false;
}
// total number of slots include the type itself.
uint32_t num_slots = num_child_slots + 1;
uint32_t allocated_tindex;
if (static_tindex != TypeIndex::kDynamic) {
// statically assigned type
allocated_tindex = static_tindex;
CHECK_LT(static_tindex, type_table_.size());
CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
<< "Conflicting static index " << static_tindex
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< key;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
pinfo.allocated_slots += num_slots;
} else {
CHECK(pinfo.child_slots_can_overflow)
<< "Reach maximum number of sub-classes for " << pinfo.name;
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
type_table_[allocated_tindex].index = allocated_tindex;
type_table_[allocated_tindex].parent_index = parent_tindex;
type_table_[allocated_tindex].num_slots = num_slots;
type_table_[allocated_tindex].allocated_slots = 1;
type_table_[allocated_tindex].child_slots_can_overflow =
child_slots_can_overflow;
type_table_[allocated_tindex].name = skey;
// update the key2index mapping.
type_key2index_[skey] = allocated_tindex;
return allocated_tindex;
}
std::string TypeIndex2Key(uint32_t tindex) {
std::lock_guard<std::mutex> lock(mutex_);
CHECK(tindex < type_table_.size() &&
type_table_[tindex].allocated_slots != 0)
<< "Unknown type index " << tindex;
return type_table_[tindex].name;
}
uint32_t TypeKey2Index(const char* key) {
std::string skey = key;
auto it = type_key2index_.find(skey);
CHECK(it != type_key2index_.end())
<< "Cannot find type " << key;
return it->second;
}
static TypeContext* Global() {
static TypeContext inst;
return &inst;
}
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
std::vector<TypeInfo> type_table_;
std::unordered_map<std::string, uint32_t> type_key2index_;
};
uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
}
bool Object::DerivedFrom(uint32_t parent_tindex) const {
return TypeContext::Global()->DerivedFrom(
this->type_index_, parent_tindex);
}
std::string Object::TypeIndex2Key(uint32_t tindex) {
return TypeContext::Global()->TypeIndex2Key(tindex);
}
uint32_t Object::TypeKey2Index(const char* key) {
return TypeContext::Global()->TypeKey2Index(key);
}
} // namespace runtime
} // namespace tvm
...@@ -18,134 +18,110 @@ ...@@ -18,134 +18,110 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors * \file src/runtime/vm/object.cc
* \file object.cc * \brief VM related objects.
* \brief A managed object in the TVM runtime.
*/ */
#include <tvm/logging.h> #include <tvm/logging.h>
#include <tvm/runtime/object.h> #include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <iostream>
#include "../runtime_base.h" #include "../runtime_base.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
namespace vm {
std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) { Tensor::Tensor(NDArray data) {
switch (tag) { auto ptr = make_object<TensorObj>();
case ObjectTag::kClosure: ptr->data = std::move(data);
os << "Closure"; data_ = std::move(ptr);
break;
case ObjectTag::kDatatype:
os << "Datatype";
break;
case ObjectTag::kTensor:
os << "Tensor";
break;
default:
LOG(FATAL) << "Invalid object tag: found " << static_cast<int>(tag);
}
return os;
}
Object Object::Tensor(const NDArray& data) {
ObjectPtr<ObjectCell> ptr = MakeObject<TensorCell>(data);
return Object(ptr);
}
Object Object::Datatype(size_t tag, const std::vector<Object>& fields) {
ObjectPtr<ObjectCell> ptr = MakeObject<DatatypeCell>(tag, fields);
return Object(ptr);
} }
Object Object::Tuple(const std::vector<Object>& fields) { return Object::Datatype(0, fields); } Datatype::Datatype(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<DatatypeObj>();
Object Object::Closure(size_t func_index, const std::vector<Object>& free_vars) { ptr->tag = tag;
ObjectPtr<ObjectCell> ptr = MakeObject<ClosureCell>(func_index, free_vars); ptr->fields = std::move(fields);
return Object(ptr); data_ = std::move(ptr);
} }
ObjectPtr<TensorCell> Object::AsTensor() const { Datatype Datatype::Tuple(std::vector<ObjectRef> fields) {
CHECK(ptr_.get()); return Datatype(0, fields);
CHECK(ptr_.get()->tag == ObjectTag::kTensor);
return ptr_.As<TensorCell>();
} }
ObjectPtr<DatatypeCell> Object::AsDatatype() const { Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
CHECK(ptr_.get()); auto ptr = make_object<ClosureObj>();
CHECK(ptr_.get()->tag == ObjectTag::kDatatype); ptr->func_index = func_index;
return ptr_.As<DatatypeCell>(); ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
} }
ObjectPtr<ClosureCell> Object::AsClosure() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kClosure);
return ptr_.As<ClosureCell>();
}
NDArray ToNDArray(const Object& obj) {
auto tensor = obj.AsTensor();
return tensor->data;
}
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData") TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0]; ObjectRef obj = args[0];
auto cell = obj.AsTensor(); const auto* cell = obj.as<TensorObj>();
CHECK(cell != nullptr);
*rv = cell->data; *rv = cell->data;
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag") TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0]; ObjectRef obj = args[0];
auto cell = obj.AsDatatype(); const auto* cell = obj.as<DatatypeObj>();
*rv = static_cast<int>(cell->tag); CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->tag);
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields") TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0]; ObjectRef obj = args[0];
auto cell = obj.AsDatatype(); const auto* cell = obj.as<DatatypeObj>();
*rv = static_cast<int>(cell->fields.size()); CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->fields.size());
}); });
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields") TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0]; ObjectRef obj = args[0];
int idx = args[1]; int idx = args[1];
auto cell = obj.AsDatatype(); const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
CHECK_LT(idx, cell->fields.size()); CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx]; *rv = cell->fields[idx];
}); });
TVM_REGISTER_GLOBAL("_vmobj.Tensor") TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Object::Tensor(args[0]); *rv = Tensor(args[0].operator NDArray());
}); });
TVM_REGISTER_GLOBAL("_vmobj.Tuple") TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<Object> fields; std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) { for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]); fields.push_back(args[i]);
} }
*rv = Object::Tuple(fields); *rv = Datatype::Tuple(fields);
}); });
TVM_REGISTER_GLOBAL("_vmobj.Datatype") TVM_REGISTER_GLOBAL("_vmobj.Datatype")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0]; int itag = args[0];
size_t tag = static_cast<size_t>(itag); size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields; std::vector<ObjectRef> fields;
for (int i = 1; i < args.size(); i++) { for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]); fields.push_back(args[i]);
} }
*rv = Object::Datatype(tag, fields); *rv = Datatype(tag, fields);
}); });
TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(DatatypeObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -153,6 +129,7 @@ using namespace tvm::runtime; ...@@ -153,6 +129,7 @@ using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) { int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN(); API_BEGIN();
*tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag); int res = static_cast<int>(static_cast<Object*>(handle)->type_index());
*tag = res;
API_END(); API_END();
} }
...@@ -96,7 +96,7 @@ void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) { ...@@ -96,7 +96,7 @@ void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
void VirtualMachineDebug::InvokePacked(Index packed_index, void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count, const PackedFunc& func, Index arg_count,
Index output_size, Index output_size,
const std::vector<Object>& args) { const std::vector<ObjectRef>& args) {
auto ctx = VirtualMachine::GetParamsContext(); auto ctx = VirtualMachine::GetParamsContext();
// warmup // warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
......
...@@ -45,7 +45,7 @@ class VirtualMachineDebug : public VirtualMachine { ...@@ -45,7 +45,7 @@ class VirtualMachineDebug : public VirtualMachine {
const std::shared_ptr<ModuleNode>& sptr_to_self) final; const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<Object>& args) final; Index output_size, const std::vector<ObjectRef>& args) final;
~VirtualMachineDebug() {} ~VirtualMachineDebug() {}
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
*/ */
/*! /*!
* Copyright (c) 2019 by Contributors
* \file src/runtime/vm/vm.cc * \file src/runtime/vm/vm.cc
* \brief The Relay virtual machine. * \brief The Relay virtual machine.
*/ */
...@@ -558,12 +557,12 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) { ...@@ -558,12 +557,12 @@ std::ostream& operator<<(std::ostream& os, const VMFunction& vm_func) {
return os; return os;
} }
Object CopyTo(Object src, const DLContext& ctx) { ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) {
if (src->tag == ObjectTag::kTensor) { if (const TensorObj* obj = src.as<TensorObj>()) {
auto tensor = ToNDArray(src); auto tensor = obj->data;
if (tensor->ctx.device_type != ctx.device_type) { if (tensor->ctx.device_type != ctx.device_type) {
auto copy = tensor.CopyTo(ctx); auto copy = tensor.CopyTo(ctx);
return Object::Tensor(copy); return Tensor(copy);
} else { } else {
return src; return src;
} }
...@@ -585,7 +584,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -585,7 +584,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
auto ctx = this->GetParamsContext(); auto ctx = this->GetParamsContext();
// Prepare the func args // Prepare the func args
std::vector<Object> func_args(param_names.size()); std::vector<ObjectRef> func_args(param_names.size());
std::vector<size_t> empty_slots; std::vector<size_t> empty_slots;
for (size_t i = 0; i < param_names.size(); ++i) { for (size_t i = 0; i < param_names.size(); ++i) {
...@@ -599,7 +598,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -599,7 +598,7 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
CHECK_EQ(empty_slots.size(), args.size() - 1) CHECK_EQ(empty_slots.size(), args.size() - 1)
<< "The number of provided parameters doesn't match the number of arguments"; << "The number of provided parameters doesn't match the number of arguments";
for (int i = 1; i < args.size(); ++i) { for (int i = 1; i < args.size(); ++i) {
Object obj = CopyTo(args[i], ctx); ObjectRef obj = CopyTo(args[i], ctx);
func_args[empty_slots[i - 1]] = obj; func_args[empty_slots[i - 1]] = obj;
} }
...@@ -660,7 +659,7 @@ void VirtualMachine::LoadParams(const std::string& params) { ...@@ -660,7 +659,7 @@ void VirtualMachine::LoadParams(const std::string& params) {
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
NDArray arr; NDArray arr;
CHECK(arr.Load(strm)) << "Invalid parameter file"; CHECK(arr.Load(strm)) << "Invalid parameter file";
runtime::Object obj = runtime::Object::Tensor(arr); ObjectRef obj = Tensor(arr);
auto copy = CopyTo(obj, ctx); auto copy = CopyTo(obj, ctx);
params_.emplace(std::make_pair(names[i], copy)); params_.emplace(std::make_pair(names[i], copy));
} }
...@@ -682,7 +681,7 @@ Index VirtualMachine::PopFrame() { ...@@ -682,7 +681,7 @@ Index VirtualMachine::PopFrame() {
return call_stack_size; return call_stack_size;
} }
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Object>& args) { void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) {
DLOG(INFO) << "Invoking global " << func.name << " " << args.size(); DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
PushFrame(func.params.size(), this->pc + 1, func); PushFrame(func.params.size(), this->pc + 1, func);
...@@ -695,7 +694,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Obje ...@@ -695,7 +694,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Obje
pc = 0; pc = 0;
} }
Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>& args) { ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
DLOG(INFO) << "Executing Function: " << std::endl << func; DLOG(INFO) << "Executing Function: " << std::endl << func;
InvokeGlobal(func, args); InvokeGlobal(func, args);
...@@ -705,7 +704,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>& ...@@ -705,7 +704,7 @@ Object VirtualMachine::Invoke(const VMFunction& func, const std::vector<Object>&
return return_register; return return_register;
} }
Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>& args) { ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
auto func_index = this->global_map[name]; auto func_index = this->global_map[name];
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index;
return Invoke(this->functions[func_index], args); return Invoke(this->functions[func_index], args);
...@@ -713,11 +712,11 @@ Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object> ...@@ -713,11 +712,11 @@ Object VirtualMachine::Invoke(const std::string& name, const std::vector<Object>
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
Index arg_count, Index output_size, Index arg_count, Index output_size,
const std::vector<Object>& args) { const std::vector<ObjectRef>& args) {
size_t arity = 0; size_t arity = 0;
for (Index i = 0; i < arg_count; i++) { for (Index i = 0; i < arg_count; i++) {
if (args[i].ptr_->tag == ObjectTag::kDatatype) { if (const auto* obj = args[i].as<DatatypeObj>()) {
arity += args[i].AsDatatype()->fields.size(); arity += obj->fields.size();
} else { } else {
++arity; ++arity;
} }
...@@ -728,15 +727,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -728,15 +727,16 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
runtime::TVMArgsSetter setter(values.data(), codes.data()); runtime::TVMArgsSetter setter(values.data(), codes.data());
int idx = 0; int idx = 0;
for (Index i = 0; i < arg_count; i++) { for (Index i = 0; i < arg_count; i++) {
if (args[i].ptr_->tag == ObjectTag::kDatatype) { if (const auto* dt_cell = args[i].as<DatatypeObj>()) {
auto dt_cell = args[i].AsDatatype();
for (auto obj : dt_cell->fields) { for (auto obj : dt_cell->fields) {
NDArray data = ToNDArray(obj); const auto* tensor = obj.as<TensorObj>();
setter(idx++, data); CHECK(tensor != nullptr);
setter(idx++, tensor->data);
} }
} else { } else {
NDArray data = ToNDArray(args[i]); const auto* tensor = args[i].as<TensorObj>();
setter(idx++, data); CHECK(tensor != nullptr);
setter(idx++, tensor->data);
} }
} }
...@@ -761,18 +761,20 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { ...@@ -761,18 +761,20 @@ void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
} }
} }
inline void VirtualMachine::WriteRegister(Index r, const Object& val) { inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames.back().register_file[r] = val; frames.back().register_file[r] = val;
} }
inline Object VirtualMachine::ReadRegister(Index r) const { inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
return frames.back().register_file[r]; return frames.back().register_file[r];
} }
inline int32_t VirtualMachine::LoadScalarInt(Index r) const { inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
int32_t result; int32_t result;
const auto& obj = ReadRegister(r); const auto& obj = ReadRegister(r);
NDArray array = ToNDArray(obj).CopyTo({kDLCPU, 0}); const auto* tensor = obj.as<TensorObj>();
CHECK(tensor != nullptr);
NDArray array = tensor->data.CopyTo({kDLCPU, 0});
if (array->dtype.bits <= 8) { if (array->dtype.bits <= 8) {
result = reinterpret_cast<int8_t*>(array->data)[0]; result = reinterpret_cast<int8_t*>(array->data)[0];
...@@ -798,7 +800,7 @@ void VirtualMachine::RunLoop() { ...@@ -798,7 +800,7 @@ void VirtualMachine::RunLoop() {
switch (instr.op) { switch (instr.op) {
case Opcode::Move: { case Opcode::Move: {
Object from_obj; ObjectRef from_obj;
from_obj = ReadRegister(instr.from); from_obj = ReadRegister(instr.from);
WriteRegister(instr.dst, from_obj); WriteRegister(instr.dst, from_obj);
pc++; pc++;
...@@ -817,12 +819,12 @@ void VirtualMachine::RunLoop() { ...@@ -817,12 +819,12 @@ void VirtualMachine::RunLoop() {
case Opcode::LoadConsti: { case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val; reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Object::Tensor(tensor)); WriteRegister(instr.dst, Tensor(tensor));
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::Invoke: { case Opcode::Invoke: {
std::vector<Object> args; std::vector<ObjectRef> args;
for (Index i = 0; i < instr.num_args; ++i) { for (Index i = 0; i < instr.num_args; ++i) {
args.push_back(ReadRegister(instr.invoke_args_registers[i])); args.push_back(ReadRegister(instr.invoke_args_registers[i]));
} }
...@@ -833,7 +835,7 @@ void VirtualMachine::RunLoop() { ...@@ -833,7 +835,7 @@ void VirtualMachine::RunLoop() {
case Opcode::InvokePacked: { case Opcode::InvokePacked: {
const auto& func = packed_funcs[instr.packed_index]; const auto& func = packed_funcs[instr.packed_index];
const auto& arity = instr.arity; const auto& arity = instr.arity;
std::vector<Object> args; std::vector<ObjectRef> args;
for (Index i = 0; i < arity; ++i) { for (Index i = 0; i < arity; ++i) {
args.push_back(ReadRegister(instr.packed_args[i])); args.push_back(ReadRegister(instr.packed_args[i]));
} }
...@@ -847,8 +849,9 @@ void VirtualMachine::RunLoop() { ...@@ -847,8 +849,9 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::InvokeClosure: { case Opcode::InvokeClosure: {
auto object = ReadRegister(instr.closure); auto object = ReadRegister(instr.closure);
const auto& closure = object.AsClosure(); const auto* closure = object.as<ClosureObj>();
std::vector<Object> args;
std::vector<ObjectRef> args;
for (auto free_var : closure->free_vars) { for (auto free_var : closure->free_vars) {
args.push_back(free_var); args.push_back(free_var);
} }
...@@ -861,10 +864,10 @@ void VirtualMachine::RunLoop() { ...@@ -861,10 +864,10 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::GetField: { case Opcode::GetField: {
auto object = ReadRegister(instr.object); auto object = ReadRegister(instr.object);
CHECK(object->tag == ObjectTag::kDatatype) const auto* tuple = object.as<DatatypeObj>();
CHECK(tuple != nullptr)
<< "Object is not data type object, register " << instr.object << ", Object tag " << "Object is not data type object, register " << instr.object << ", Object tag "
<< static_cast<int>(object->tag); << object->type_index();
const auto& tuple = object.AsDatatype();
auto field = tuple->fields[instr.field_index]; auto field = tuple->fields[instr.field_index];
WriteRegister(instr.dst, field); WriteRegister(instr.dst, field);
pc++; pc++;
...@@ -872,15 +875,15 @@ void VirtualMachine::RunLoop() { ...@@ -872,15 +875,15 @@ void VirtualMachine::RunLoop() {
} }
case Opcode::GetTag: { case Opcode::GetTag: {
auto object = ReadRegister(instr.get_tag.object); auto object = ReadRegister(instr.get_tag.object);
CHECK(object->tag == ObjectTag::kDatatype) const auto* data = object.as<DatatypeObj>();
CHECK(data != nullptr)
<< "Object is not data type object, register " << "Object is not data type object, register "
<< instr.get_tag.object << ", Object tag " << instr.get_tag.object << ", Object tag "
<< static_cast<int>(object->tag); << object->type_index();
const auto& data = object.AsDatatype();
auto tag = data->tag; auto tag = data->tag;
auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag; reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
WriteRegister(instr.dst, Object::Tensor(tag_tensor)); WriteRegister(instr.dst, Tensor(tag_tensor));
pc++; pc++;
goto main_loop; goto main_loop;
} }
...@@ -909,7 +912,7 @@ void VirtualMachine::RunLoop() { ...@@ -909,7 +912,7 @@ void VirtualMachine::RunLoop() {
} }
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor.dtype, ctxs[0]);
auto obj = Object::Tensor(data); auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
goto main_loop; goto main_loop;
...@@ -920,7 +923,9 @@ void VirtualMachine::RunLoop() { ...@@ -920,7 +923,9 @@ void VirtualMachine::RunLoop() {
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register);
NDArray shape_tensor = ToNDArray(shape_tensor_obj).CopyTo(cpu_ctx); const auto* tensor = shape_tensor_obj.as<TensorObj>();
CHECK(tensor != nullptr);
NDArray shape_tensor = tensor->data.CopyTo(cpu_ctx);
int64_t* dims = static_cast<int64_t*>(shape_tensor->data); int64_t* dims = static_cast<int64_t*>(shape_tensor->data);
auto num_dims = shape_tensor->shape[0]; auto num_dims = shape_tensor->shape[0];
...@@ -928,27 +933,27 @@ void VirtualMachine::RunLoop() { ...@@ -928,27 +933,27 @@ void VirtualMachine::RunLoop() {
shape.assign(dims, dims + num_dims); shape.assign(dims, dims + num_dims);
auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]);
auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]); auto data = allocator->Empty(shape, instr.alloc_tensor_reg.dtype, ctxs[0]);
auto obj = Object::Tensor(data); auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocDatatype: { case Opcode::AllocDatatype: {
std::vector<Object> fields; std::vector<ObjectRef> fields;
for (Index i = 0; i < instr.num_fields; ++i) { for (Index i = 0; i < instr.num_fields; ++i) {
fields.push_back(ReadRegister(instr.datatype_fields[i])); fields.push_back(ReadRegister(instr.datatype_fields[i]));
} }
Object obj = Object::Datatype(instr.constructor_tag, fields); ObjectRef obj = Datatype(instr.constructor_tag, fields);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocClosure: { case Opcode::AllocClosure: {
std::vector<Object> free_vars; std::vector<ObjectRef> free_vars;
for (Index i = 0; i < instr.num_freevar; i++) { for (Index i = 0; i < instr.num_freevar; i++) {
free_vars.push_back(ReadRegister(instr.free_vars[i])); free_vars.push_back(ReadRegister(instr.free_vars[i]));
} }
WriteRegister(instr.dst, Object::Closure(instr.func_index, free_vars)); WriteRegister(instr.dst, Closure(instr.func_index, free_vars));
pc++; pc++;
goto main_loop; goto main_loop;
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
namespace tvm {
namespace test {
using namespace tvm::runtime;
class ObjBase : public Object {
public:
// dynamically allocate slow
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 1;
static constexpr const char* _type_key = "test.ObjBase";
TVM_DECLARE_BASE_OBJECT_INFO(ObjBase, Object);
};
class ObjA : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const char* _type_key = "test.ObjA";
TVM_DECLARE_BASE_OBJECT_INFO(ObjA, ObjBase);
};
class ObjB : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.ObjB";
TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase);
};
class ObjAA : public ObjA {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.ObjAA";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA);
};
TVM_REGISTER_OBJECT_TYPE(ObjBase);
TVM_REGISTER_OBJECT_TYPE(ObjA);
TVM_REGISTER_OBJECT_TYPE(ObjB);
TVM_REGISTER_OBJECT_TYPE(ObjAA);
} // namespace test
} // namespace tvm
TEST(ObjectHierachy, Basic) {
using namespace tvm::runtime;
using namespace tvm::test;
ObjectRef refA(make_object<ObjA>());
CHECK_EQ(refA->type_index(), ObjA::type_index());
CHECK(refA.as<Object>() != nullptr);
CHECK(refA.as<ObjA>() != nullptr);
CHECK(refA.as<ObjBase>() != nullptr);
CHECK(refA.as<ObjB>() == nullptr);
CHECK(refA.as<ObjAA>() == nullptr);
ObjectRef refAA(make_object<ObjAA>());
CHECK_EQ(refAA->type_index(), ObjAA::type_index());
CHECK(refAA.as<Object>() != nullptr);
CHECK(refAA.as<ObjBase>() != nullptr);
CHECK(refAA.as<ObjA>() != nullptr);
CHECK(refAA.as<ObjAA>() != nullptr);
CHECK(refAA.as<ObjB>() == nullptr);
ObjectRef refB(make_object<ObjB>());
CHECK_EQ(refB->type_index(), ObjB::type_index());
CHECK(refB.as<Object>() != nullptr);
CHECK(refB.as<ObjBase>() != nullptr);
CHECK(refB.as<ObjA>() == nullptr);
CHECK(refB.as<ObjAA>() == nullptr);
CHECK(refB.as<ObjB>() != nullptr);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
...@@ -582,7 +582,7 @@ def test_set_params(): ...@@ -582,7 +582,7 @@ def test_set_params():
mod["main"] = relay.Function([x, w, b], y) mod["main"] = relay.Function([x, w, b], y)
vm = relay.vm.compile(mod, 'llvm') vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu()) vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32') x_np = np.random.uniform(size=(10, 5)).astype('float32')
w_np = np.random.uniform(size=(6, 5)).astype('float32') w_np = np.random.uniform(size=(6, 5)).astype('float32')
b_np = np.random.uniform(size=(6,)).astype('float32') b_np = np.random.uniform(size=(6,)).astype('float32')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment