Commit dd55682d by Jared Roesch Committed by Tianqi Chen

[Relay][Runtime] Add support for virtual machine Objects (#3120)

parent b175319c
Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
Subproject commit a768f2f0627917659a4d7167eee3190469b9d164
......@@ -112,7 +112,8 @@ typedef enum {
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
kExtEnd = 128U,
kObject = 14U,
} TVMTypeCode;
/*!
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/object.h
* \brief A managed object in the TVM runtime.
*/
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_
#include <tvm/runtime/ndarray.h>
#include <memory>
#include <utility>
#include <vector>
namespace tvm {
namespace runtime {
template <typename T>
class ObjectPtr;
class Object;
enum struct ObjectTag {
/*! \brief The tag of a tensor. */
kTensor = 0U,
/*! \brief The tag of a closure. */
kClosure = 1U,
/*! \brief The tag of a structure. */
kDatatype = 2U,
};
std::ostream& operator<<(std::ostream& os, const ObjectTag&);
struct ObjectCell {
public:
/*!
* \brief The type of object deleter.
* \param The self pointer to the ObjectCell.
*/
typedef void (*FDeleter)(ObjectCell* self);
/*! \brief The tag of the object.
*
* Describes which type of value
* is represented by this object.
*/
ObjectTag tag;
/*!
* \brief Increment the reference count.
*/
void IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
/*!
* \brief Decrement the reference count.
*/
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter_ != nullptr) {
(*this->deleter_)(this);
}
}
}
protected:
// default constructor and copy constructor
ObjectCell() {}
explicit ObjectCell(ObjectTag tag) : tag(tag) {}
// override the copy and assign constructors to do nothing.
// This is to make sure only contents, but not deleter and ref_counter
// are copied when a child class copies itself.
ObjectCell(const ObjectCell& other) { // NOLINT(*)
}
ObjectCell(ObjectCell&& other) { // NOLINT(*)
}
ObjectCell& operator=(const ObjectCell& other) { // NOLINT(*)
return *this;
}
ObjectCell& operator=(ObjectCell&& other) { // NOLINT(*)
return *this;
}
private:
/*! \brief Internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
* \brief deleter of this object to enable customized allocation.
* If the deleter is nullptr, no deletion will be performed.
* The creator of the Node must always set the deleter field properly.
*/
FDeleter deleter_ = nullptr;
int use_count() const { return ref_counter_.load(std::memory_order_relaxed); }
// friend declaration
template <typename>
friend class ObjectPtr;
template <typename Y, typename... Args>
friend ObjectPtr<Y> MakeObject(Args&&...);
};
/*!
* \brief A custom smart pointer for Object.
* must be subclass of NodeBase
* \tparam T the content data type.
*/
template <typename T>
class ObjectPtr {
public:
/*! \brief default constructor */
ObjectPtr() {}
/*! \brief default constructor */
ObjectPtr(std::nullptr_t) {} // NOLINT(*)
/*!
* \brief copy constructor
* \param other The value to be moved
*/
ObjectPtr(const ObjectPtr<T>& other) // NOLINT(*)
: ObjectPtr(other.data_) {}
/*!
* \brief copy constructor
* \param other The value to be moved
*/
template <typename U>
ObjectPtr(const ObjectPtr<U>& other) // NOLINT(*)
: ObjectPtr(other.data_) {
static_assert(std::is_base_of<T, U>::value,
"can only assign of child class ObjectPtr to parent");
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
ObjectPtr(ObjectPtr<T>&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
template <typename Y>
ObjectPtr(ObjectPtr<Y>&& other) // NOLINT(*)
: data_(other.data_) {
static_assert(std::is_base_of<T, Y>::value,
"can only assign of child class ObjectPtr to parent");
other.data_ = nullptr;
}
/*! \brief destructor */
~ObjectPtr() { this->reset(); }
/*!
* \brief Swap this array with another Object
* \param other The other Object
*/
void swap(ObjectPtr<T>& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
* \return Get the content of the pointer
*/
T* get() const { return static_cast<T*>(data_); }
/*!
* \return The pointer
*/
T* operator->() const { return get(); }
/*!
* \return The reference
*/
T& operator*() const { // NOLINT(*)
return *get();
}
/*!
* \brief copy assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
ObjectPtr<T>& operator=(const ObjectPtr<T>& other) { // NOLINT(*)
// takes in plane operator to enable copy elison.
// copy-and-swap idiom
ObjectPtr(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
* \brief move assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
ObjectPtr<T>& operator=(ObjectPtr<T>&& other) { // NOLINT(*)
// copy-and-swap idiom
ObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \brief reset the content of ptr to be nullptr */
void reset() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
}
/*! \return The use count of the ptr, for debug purposes */
int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
/*! \return whether the reference is unique */
bool unique() const { return data_ != nullptr && data_->use_count() == 1; }
/*! \return Whether two ObjectPtr do not equal each other */
bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; }
/*! \return Whether two ObjectPtr equals each other */
bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
/*! \return Whether the pointer is nullptr */
bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
/*! \return Whether the pointer is not nullptr */
bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
/* ObjectPtr's support custom allocators.
*
* The below allocator represents the simplest
* possible impl. It can be easily swapped
* for customized executor's, different allocation
* strategies, and so on.
*
* See memory.h for more discussion on NodePtr's
* allocator.
*/
class StdAllocator {
public:
template <typename... Args>
static T* New(Args&&... args) {
return new T(std::forward<Args>(args)...);
}
static ObjectCell::FDeleter Deleter() { return Deleter_; }
private:
static void Deleter_(ObjectCell* ptr) { delete static_cast<T*>(ptr); }
};
template <typename U>
ObjectPtr<U> As() const {
auto ptr = reinterpret_cast<U*>(get());
return ObjectPtr<U>(ptr);
}
private:
/*! \brief internal pointer field */
ObjectCell* data_{nullptr};
/*!
* \brief constructor from NodeBase
* \param data The node base pointer
*/
// TODO(jroesch): NodePtr design doesn't really work here due to the passing.
public:
explicit ObjectPtr(ObjectCell* data) : data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
private:
template <typename Y, typename... Args>
friend ObjectPtr<Y> MakeObject(Args&&...);
template <typename>
friend class ObjectPtr;
friend class NDArray;
friend class TVMPODValue_;
friend class TVMArgValue;
friend class TVMRetValue;
friend class RPCWrappedFunc;
};
struct TensorCell;
struct DatatypeCell;
struct ClosureCell;
/*!
* \brief A managed object in the TVM runtime.
*
* For example a tuple, list, closure, and so on.
*
* Maintains a reference count for the object.
*/
class Object {
public:
ObjectPtr<ObjectCell> ptr_;
explicit Object(ObjectPtr<ObjectCell> ptr) : ptr_(ptr) {}
explicit Object(ObjectCell* ptr) : ptr_(ptr) {}
Object() : ptr_() {}
Object(const Object& obj) : ptr_(obj.ptr_) {}
ObjectCell* operator->() { return this->ptr_.operator->(); }
/*! \brief Construct a tensor object. */
static Object Tensor(const NDArray& data);
/*! \brief Construct a datatype object. */
static Object Datatype(size_t tag, const std::vector<Object>& fields);
/*! \brief Construct a tuple object. */
static Object Tuple(const std::vector<Object>& fields);
/*! \brief Construct a closure object. */
static Object Closure(size_t func_index, const std::vector<Object>& free_vars);
ObjectPtr<TensorCell> AsTensor() const;
ObjectPtr<DatatypeCell> AsDatatype() const;
ObjectPtr<ClosureCell> AsClosure() const;
};
/*! \brief An object containing an NDArray. */
struct TensorCell : public ObjectCell {
/*! \brief The NDArray. */
NDArray data;
explicit TensorCell(const NDArray& data) : ObjectCell(ObjectTag::kTensor), data(data) {}
};
/*! \brief An object representing a structure or enumeration. */
struct DatatypeCell : public ObjectCell {
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<Object> fields;
DatatypeCell(size_t tag, const std::vector<Object>& fields)
: ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {}
};
/*! \brief An object representing a closure. */
struct ClosureCell : public ObjectCell {
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<Object> free_vars;
ClosureCell(size_t func_index, const std::vector<Object>& free_vars)
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
};
/*! \brief Extract the NDArray from a tensor object. */
NDArray ToNDArray(const Object& obj);
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template <typename T, typename... Args>
inline ObjectPtr<T> MakeObject(Args&&... args) {
using Allocator = typename ObjectPtr<T>::StdAllocator;
static_assert(std::is_base_of<ObjectCell, T>::value, "MakeObject can only be used to create ");
T* node = Allocator::New(std::forward<Args>(args)...);
node->deleter_ = Allocator::Deleter();
return ObjectPtr<T>(node);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_
......@@ -39,6 +39,7 @@
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
#include "object.h"
#include "node_base.h"
namespace HalideIR {
......@@ -48,6 +49,7 @@ struct Type;
struct Expr;
}
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
......@@ -470,6 +472,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator Object() const {
if (type_code_ == kNull) return Object();
TVM_CHECK_TYPE_CODE(type_code_, kObject);
return Object(static_cast<ObjectCell*>(value_.v_handle));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
......@@ -542,6 +549,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Object;
// conversion operator.
operator std::string() const {
......@@ -637,6 +645,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Object;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other);
}
......@@ -733,6 +742,13 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(Object other) {
this->Clear();
type_code_ = kObject;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
......@@ -828,6 +844,10 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObject: {
*this = other.operator Object();
break;
}
default: {
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
......@@ -875,6 +895,10 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObject: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
break;
}
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
......@@ -904,6 +928,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObject: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......
......@@ -96,7 +96,7 @@ def config_cython():
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
return cythonize(ret)
return cythonize(ret, compiler_directives={"language_level": 3})
except ImportError:
print("WARNING: Cython is not installed, will compile without cython module")
return []
......
......@@ -162,6 +162,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
elif isinstance(arg, ObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
......@@ -240,12 +243,18 @@ def _handle_return_func(x):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......@@ -255,6 +264,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
......@@ -264,3 +274,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
......@@ -37,6 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObject = 14
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
......@@ -76,6 +77,7 @@ ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle
ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer:
......
......@@ -44,6 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObject or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
......@@ -157,6 +158,9 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObject
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
......@@ -208,6 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObject:
return _CLASS_OBJECT(ctypes_handle(value.v_handle))
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......@@ -304,8 +310,31 @@ cdef class FunctionBase:
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
......@@ -315,3 +344,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
......@@ -30,19 +30,28 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
from ._ctypes.function import ObjectBase as _ObjectBase
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
class Object(_ObjectBase):
# TODO(@jroesch): Eventually add back introspection functionality.
pass
_set_class_object(Object)
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
......
......@@ -42,8 +42,10 @@ class TypeCode(object):
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
OBJECT = 14
EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -92,6 +92,12 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true;
}
}
void Visit(const char* key, runtime::Object* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
};
struct APIAttrDir : public AttrVisitor {
......@@ -127,6 +133,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::Object* value) final {
names->push_back(key);
}
};
class DSLAPIImpl : public DSLAPI {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -53,6 +53,8 @@ inline Type String2Type(std::string s) {
return TVMType2Type(runtime::String2TVMType(s));
}
using runtime::Object;
using runtime::ObjectCell;
// indexer to index all the ndoes
class NodeIndexer : public AttrVisitor {
......@@ -61,6 +63,8 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
std::unordered_map<ObjectCell*, size_t> vm_obj_index;
std::vector<ObjectCell*> vm_obj_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
......@@ -73,6 +77,7 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}
void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return;
......@@ -80,6 +85,15 @@ class NodeIndexer : public AttrVisitor {
tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr);
}
void Visit(const char* key, Object* value) final {
ObjectCell* ptr = value->ptr_.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
}
// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
......@@ -163,6 +177,7 @@ class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<ObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
......@@ -197,6 +212,10 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
void Visit(const char* key, Object* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr_.get()));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
......@@ -250,6 +269,8 @@ class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<Object>* vm_obj_list_;
JSONNode* node_;
std::string GetValue(const char* key) const {
......@@ -304,6 +325,12 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
void Visit(const char* key, Object* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
}
// set node to be current JSONNode
void Set(Node* node) {
if (node == nullptr) return;
......@@ -481,6 +508,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, Object* value) final {
*value = GetAttr(key).operator Object();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
......
......@@ -775,6 +775,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::Object* obj) final {
LOG(FATAL) << "do not allow Object as argument";
}
private:
Doc& doc_;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file object.cc
* \brief A managed object in the TVM runtime.
*/
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <iostream>
namespace tvm {
namespace runtime {
std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) {
switch (tag) {
case ObjectTag::kClosure:
os << "Closure";
break;
case ObjectTag::kDatatype:
os << "Datatype";
break;
case ObjectTag::kTensor:
os << "Tensor";
break;
case ObjectTag::kExternalFunc:
os << "ExternalFunction";
break;
default:
LOG(FATAL) << "Invalid object tag: found " << static_cast<int>(tag);
}
return os;
}
Object Object::Tensor(const NDArray& data) {
ObjectPtr<ObjectCell> ptr = MakeObject<TensorCell>(data);
return Object(ptr);
}
Object Object::Datatype(size_t tag, const std::vector<Object>& fields) {
ObjectPtr<ObjectCell> ptr = MakeObject<DatatypeCell>(tag, fields);
return Object(ptr);
}
Object Object::Tuple(const std::vector<Object>& fields) { return Object::Datatype(0, fields); }
Object Object::Closure(size_t func_index, const std::vector<Object>& free_vars) {
ObjectPtr<ObjectCell> ptr = MakeObject<ClosureCell>(func_index, free_vars);
return Object(ptr);
}
ObjectPtr<TensorCell> Object::AsTensor() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kTensor);
return ptr.As<TensorCell>();
}
ObjectPtr<DatatypeCell> Object::AsDatatype() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kDatatype);
return ptr.As<DatatypeCell>();
}
ObjectPtr<ClosureCell> Object::AsClosure() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kClosure);
return ptr.As<ClosureCell>();
}
NDArray ToNDArray(const Object& obj) {
auto tensor = obj.AsTensor();
return tensor->data;
}
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment