Commit dd55682d by Jared Roesch Committed by Tianqi Chen

[Relay][Runtime] Add support for virtual machine Objects (#3120)

parent b175319c
Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748 Subproject commit a768f2f0627917659a4d7167eee3190469b9d164
...@@ -112,7 +112,8 @@ typedef enum { ...@@ -112,7 +112,8 @@ typedef enum {
kNNVMLast = 20U, kNNVMLast = 20U,
// The following section of code is used for non-reserved types. // The following section of code is used for non-reserved types.
kExtReserveEnd = 64U, kExtReserveEnd = 64U,
kExtEnd = 128U kExtEnd = 128U,
kObject = 14U,
} TVMTypeCode; } TVMTypeCode;
/*! /*!
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "module.h" #include "module.h"
#include "ndarray.h" #include "ndarray.h"
#include "object.h"
#include "node_base.h" #include "node_base.h"
namespace HalideIR { namespace HalideIR {
...@@ -48,6 +49,7 @@ struct Type; ...@@ -48,6 +49,7 @@ struct Type;
struct Expr; struct Expr;
} }
// Whether use TVM runtime in header only mode. // Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY #ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0 #define TVM_RUNTIME_HEADER_ONLY 0
...@@ -470,6 +472,11 @@ class TVMPODValue_ { ...@@ -470,6 +472,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer); TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle)); return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
} }
operator Object() const {
if (type_code_ == kNull) return Object();
TVM_CHECK_TYPE_CODE(type_code_, kObject);
return Object(static_cast<ObjectCell*>(value_.v_handle));
}
operator TVMContext() const { operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx; return value_.v_ctx;
...@@ -542,6 +549,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -542,6 +549,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Object;
// conversion operator. // conversion operator.
operator std::string() const { operator std::string() const {
...@@ -637,6 +645,7 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -637,6 +645,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray; using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Object;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other); this->Assign(other);
} }
...@@ -733,6 +742,13 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -733,6 +742,13 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr; other.data_ = nullptr;
return *this; return *this;
} }
TVMRetValue& operator=(Object other) {
this->Clear();
type_code_ = kObject;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) { TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f); this->SwitchToClass(kFuncHandle, f);
return *this; return *this;
...@@ -828,6 +844,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -828,6 +844,10 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >()); kNodeHandle, *other.template ptr<NodePtr<Node> >());
break; break;
} }
case kObject: {
*this = other.operator Object();
break;
}
default: { default: {
if (other.type_code() < kExtBegin) { if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code()); SwitchToPOD(other.type_code());
...@@ -875,6 +895,10 @@ class TVMRetValue : public TVMPODValue_ { ...@@ -875,6 +895,10 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break; break;
} }
case kObject: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
break;
}
} }
if (type_code_ > kExtBegin) { if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY #if TVM_RUNTIME_HEADER_ONLY
...@@ -904,6 +928,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -904,6 +928,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle"; case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle"; case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer"; case kNDArrayContainer: return "NDArrayContainer";
case kObject: return "Object";
default: LOG(FATAL) << "unknown type_code=" default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return ""; << static_cast<int>(type_code); return "";
} }
......
...@@ -96,7 +96,7 @@ def config_cython(): ...@@ -96,7 +96,7 @@ def config_cython():
library_dirs=library_dirs, library_dirs=library_dirs,
libraries=libraries, libraries=libraries,
language="c++")) language="c++"))
return cythonize(ret) return cythonize(ret, compiler_directives={"language_level": 3})
except ImportError: except ImportError:
print("WARNING: Cython is not installed, will compile without cython module") print("WARNING: Cython is not installed, will compile without cython module")
return [] return []
......
...@@ -162,6 +162,9 @@ def _make_tvm_args(args, temp_args): ...@@ -162,6 +162,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, ObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT
else: else:
raise TypeError("Don't know how to handle type %s" % type(arg)) raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args return values, type_codes, num_args
...@@ -240,12 +243,18 @@ def _handle_return_func(x): ...@@ -240,12 +243,18 @@ def _handle_return_func(x):
handle = FunctionHandle(handle) handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False) return _CLASS_FUNCTION(handle, False)
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
# setup return handle for function type # setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__ _node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE) _handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
...@@ -255,6 +264,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl ...@@ -255,6 +264,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
_CLASS_OBJECT = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
...@@ -264,3 +274,7 @@ def _set_class_module(module_class): ...@@ -264,3 +274,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class): def _set_class_function(func_class):
global _CLASS_FUNCTION global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class _CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
...@@ -37,6 +37,7 @@ cdef enum TVMTypeCode: ...@@ -37,6 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11 kStr = 11
kBytes = 12 kBytes = 12
kNDArrayContainer = 13 kNDArrayContainer = 13
kObject = 14
kExtBegin = 15 kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
...@@ -76,6 +77,7 @@ ctypedef DLTensor* DLTensorHandle ...@@ -76,6 +77,7 @@ ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle
ctypedef void* NodeHandle ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer: ctypedef struct TVMNDArrayContainer:
......
...@@ -44,6 +44,7 @@ cdef int tvm_callback(TVMValue* args, ...@@ -44,6 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or if (tcode == kNodeHandle or
tcode == kFuncHandle or tcode == kFuncHandle or
tcode == kModuleHandle or tcode == kModuleHandle or
tcode == kObject or
tcode > kExtBegin): tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, tcode))
...@@ -157,6 +158,9 @@ cdef inline int make_arg(object arg, ...@@ -157,6 +158,9 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, _CLASS_MODULE): elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle) value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObject
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle tcode[0] = kFuncHandle
...@@ -208,6 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode): ...@@ -208,6 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False) fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle (<FunctionBase>fobj).chandle = value.v_handle
return fobj return fobj
elif tcode == kObject:
return _CLASS_OBJECT(ctypes_handle(value.v_handle))
elif tcode in _TVM_EXT_RET: elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle)) return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
...@@ -304,8 +310,31 @@ cdef class FunctionBase: ...@@ -304,8 +310,31 @@ cdef class FunctionBase:
FuncCall(self.chandle, args, &ret_val, &ret_tcode) FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode) return make_ret(ret_val, ret_tcode)
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_OBJECT = None
def _set_class_module(module_class): def _set_class_module(module_class):
"""Initialize the module.""" """Initialize the module."""
...@@ -315,3 +344,7 @@ def _set_class_module(module_class): ...@@ -315,3 +344,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class): def _set_class_function(func_class):
global _CLASS_FUNCTION global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class _CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
...@@ -30,19 +30,28 @@ try: ...@@ -30,19 +30,28 @@ try:
if _FFI_MODE == "ctypes": if _FFI_MODE == "ctypes":
raise ImportError() raise ImportError()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
from ._cy3.core import FunctionBase as _FunctionBase from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func from ._cy3.core import convert_to_tvm_func
else: else:
from ._cy2.core import _set_class_function, _set_class_module from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
from ._cy2.core import FunctionBase as _FunctionBase from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT: except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position # pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
from ._ctypes.function import ObjectBase as _ObjectBase
from ._ctypes.function import FunctionBase as _FunctionBase from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func from ._ctypes.function import convert_to_tvm_func
class Object(_ObjectBase):
# TODO(@jroesch): Eventually add back introspection functionality.
pass
_set_class_object(Object)
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase): class Function(_FunctionBase):
......
...@@ -42,8 +42,10 @@ class TypeCode(object): ...@@ -42,8 +42,10 @@ class TypeCode(object):
STR = 11 STR = 11
BYTES = 12 BYTES = 12
NDARRAY_CONTAINER = 13 NDARRAY_CONTAINER = 13
OBJECT = 14
EXT_BEGIN = 15 EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure): class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array.""" """Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -92,6 +92,12 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -92,6 +92,12 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true; found_ref_object = true;
} }
} }
void Visit(const char* key, runtime::Object* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
}; };
struct APIAttrDir : public AttrVisitor { struct APIAttrDir : public AttrVisitor {
...@@ -127,6 +133,9 @@ struct APIAttrDir : public AttrVisitor { ...@@ -127,6 +133,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, runtime::Object* value) final {
names->push_back(key);
}
}; };
class DSLAPIImpl : public DSLAPI { class DSLAPIImpl : public DSLAPI {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -53,6 +53,8 @@ inline Type String2Type(std::string s) { ...@@ -53,6 +53,8 @@ inline Type String2Type(std::string s) {
return TVMType2Type(runtime::String2TVMType(s)); return TVMType2Type(runtime::String2TVMType(s));
} }
using runtime::Object;
using runtime::ObjectCell;
// indexer to index all the ndoes // indexer to index all the ndoes
class NodeIndexer : public AttrVisitor { class NodeIndexer : public AttrVisitor {
...@@ -61,6 +63,8 @@ class NodeIndexer : public AttrVisitor { ...@@ -61,6 +63,8 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr}; std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index; std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list; std::vector<DLTensor*> tensor_list;
std::unordered_map<ObjectCell*, size_t> vm_obj_index;
std::vector<ObjectCell*> vm_obj_list;
void Visit(const char* key, double* value) final {} void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {} void Visit(const char* key, int64_t* value) final {}
...@@ -73,6 +77,7 @@ class NodeIndexer : public AttrVisitor { ...@@ -73,6 +77,7 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final { void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get()); MakeIndex(value->node_.get());
} }
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->()); DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return; if (tensor_index.count(ptr)) return;
...@@ -80,6 +85,15 @@ class NodeIndexer : public AttrVisitor { ...@@ -80,6 +85,15 @@ class NodeIndexer : public AttrVisitor {
tensor_index[ptr] = tensor_list.size(); tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr); tensor_list.push_back(ptr);
} }
void Visit(const char* key, Object* value) final {
ObjectCell* ptr = value->ptr_.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
}
// make index of all the children of node // make index of all the children of node
void MakeIndex(Node* node) { void MakeIndex(Node* node) {
if (node == nullptr) return; if (node == nullptr) return;
...@@ -163,6 +177,7 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -163,6 +177,7 @@ class JSONAttrGetter : public AttrVisitor {
public: public:
const std::unordered_map<Node*, size_t>* node_index_; const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_; const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<ObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_; JSONNode* node_;
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
...@@ -197,6 +212,10 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -197,6 +212,10 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string( node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->()))); tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
} }
void Visit(const char* key, Object* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr_.get()));
}
// Get the node // Get the node
void Get(Node* node) { void Get(Node* node) {
if (node == nullptr) { if (node == nullptr) {
...@@ -250,6 +269,8 @@ class JSONAttrSetter : public AttrVisitor { ...@@ -250,6 +269,8 @@ class JSONAttrSetter : public AttrVisitor {
public: public:
const std::vector<NodePtr<Node> >* node_list_; const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_; const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<Object>* vm_obj_list_;
JSONNode* node_; JSONNode* node_;
std::string GetValue(const char* key) const { std::string GetValue(const char* key) const {
...@@ -304,6 +325,12 @@ class JSONAttrSetter : public AttrVisitor { ...@@ -304,6 +325,12 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size()); CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index); *value = tensor_list_->at(index);
} }
void Visit(const char* key, Object* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
}
// set node to be current JSONNode // set node to be current JSONNode
void Set(Node* node) { void Set(Node* node) {
if (node == nullptr) return; if (node == nullptr) return;
...@@ -481,6 +508,9 @@ class NodeAttrSetter : public AttrVisitor { ...@@ -481,6 +508,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray(); *value = GetAttr(key).operator runtime::NDArray();
} }
void Visit(const char* key, Object* value) final {
*value = GetAttr(key).operator Object();
}
private: private:
runtime::TVMArgValue GetAttr(const char* key) { runtime::TVMArgValue GetAttr(const char* key) {
......
...@@ -775,6 +775,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { ...@@ -775,6 +775,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument"; LOG(FATAL) << "do not allow NDarray as argument";
} }
void Visit(const char* key, runtime::Object* obj) final {
LOG(FATAL) << "do not allow Object as argument";
}
private: private:
Doc& doc_; Doc& doc_;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file object.cc
* \brief A managed object in the TVM runtime.
*/
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <iostream>
namespace tvm {
namespace runtime {
std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) {
switch (tag) {
case ObjectTag::kClosure:
os << "Closure";
break;
case ObjectTag::kDatatype:
os << "Datatype";
break;
case ObjectTag::kTensor:
os << "Tensor";
break;
case ObjectTag::kExternalFunc:
os << "ExternalFunction";
break;
default:
LOG(FATAL) << "Invalid object tag: found " << static_cast<int>(tag);
}
return os;
}
Object Object::Tensor(const NDArray& data) {
ObjectPtr<ObjectCell> ptr = MakeObject<TensorCell>(data);
return Object(ptr);
}
Object Object::Datatype(size_t tag, const std::vector<Object>& fields) {
ObjectPtr<ObjectCell> ptr = MakeObject<DatatypeCell>(tag, fields);
return Object(ptr);
}
Object Object::Tuple(const std::vector<Object>& fields) { return Object::Datatype(0, fields); }
Object Object::Closure(size_t func_index, const std::vector<Object>& free_vars) {
ObjectPtr<ObjectCell> ptr = MakeObject<ClosureCell>(func_index, free_vars);
return Object(ptr);
}
ObjectPtr<TensorCell> Object::AsTensor() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kTensor);
return ptr.As<TensorCell>();
}
ObjectPtr<DatatypeCell> Object::AsDatatype() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kDatatype);
return ptr.As<DatatypeCell>();
}
ObjectPtr<ClosureCell> Object::AsClosure() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kClosure);
return ptr.As<ClosureCell>();
}
NDArray ToNDArray(const Object& obj) {
auto tensor = obj.AsTensor();
return tensor->data;
}
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment