Commit dd55682d by Jared Roesch Committed by Tianqi Chen

[Relay][Runtime] Add support for virtual machine Objects (#3120)

parent b175319c
Subproject commit 55ba1778fd264c7507953552d8e51212ed11f748
Subproject commit a768f2f0627917659a4d7167eee3190469b9d164
......@@ -112,7 +112,8 @@ typedef enum {
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
kExtEnd = 128U,
kObject = 14U,
} TVMTypeCode;
/*!
......
......@@ -39,6 +39,7 @@
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
#include "object.h"
#include "node_base.h"
namespace HalideIR {
......@@ -48,6 +49,7 @@ struct Type;
struct Expr;
}
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
......@@ -470,6 +472,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator Object() const {
if (type_code_ == kNull) return Object();
TVM_CHECK_TYPE_CODE(type_code_, kObject);
return Object(static_cast<ObjectCell*>(value_.v_handle));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
......@@ -542,6 +549,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Object;
// conversion operator.
operator std::string() const {
......@@ -637,6 +645,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Object;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other);
}
......@@ -733,6 +742,13 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(Object other) {
this->Clear();
type_code_ = kObject;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
......@@ -828,6 +844,10 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle, *other.template ptr<NodePtr<Node> >());
break;
}
case kObject: {
*this = other.operator Object();
break;
}
default: {
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
......@@ -875,6 +895,10 @@ class TVMRetValue : public TVMPODValue_ {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
case kObject: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
break;
}
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
......@@ -904,6 +928,7 @@ inline const char* TypeCode2Str(int type_code) {
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
case kObject: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......
......@@ -96,7 +96,7 @@ def config_cython():
library_dirs=library_dirs,
libraries=libraries,
language="c++"))
return cythonize(ret)
return cythonize(ret, compiler_directives={"language_level": 3})
except ImportError:
print("WARNING: Cython is not installed, will compile without cython module")
return []
......
......@@ -162,6 +162,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
temp_args.append(arg)
elif isinstance(arg, ObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
......@@ -240,12 +243,18 @@ def _handle_return_func(x):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
class ObjectBase(object):
__slots__ = ["handle"]
def __init__(self, handle):
self.handle = handle
# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)
RETURN_SWITCH[TypeCode.OBJECT] = lambda x: _CLASS_OBJECT(x.v_handle)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
......@@ -255,6 +264,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
......@@ -264,3 +274,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
......@@ -37,6 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kObject = 14
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
......@@ -76,6 +77,7 @@ ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle
ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer:
......
......@@ -44,6 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kObject or
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
......@@ -157,6 +158,9 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObject
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
......@@ -208,6 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
elif tcode == kObject:
return _CLASS_OBJECT(ctypes_handle(value.v_handle))
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
......@@ -304,8 +310,31 @@ cdef class FunctionBase:
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
def _set_class_module(module_class):
"""Initialize the module."""
......@@ -315,3 +344,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
......@@ -30,19 +30,28 @@ try:
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _set_class_function, _set_class_module
from ._cy3.core import _set_class_function, _set_class_module, _set_class_object
from ._cy3.core import FunctionBase as _FunctionBase
from ._cy3.core import ObjectBase as _ObjectBase
from ._cy3.core import convert_to_tvm_func
else:
from ._cy2.core import _set_class_function, _set_class_module
from ._cy2.core import _set_class_function, _set_class_module, _set_class_object
from ._cy2.core import FunctionBase as _FunctionBase
from ._cy2.core import ObjectBase as _ObjectBase
from ._cy2.core import convert_to_tvm_func
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.function import _set_class_function, _set_class_module
from ._ctypes.function import _set_class_function, _set_class_module, _set_class_object
from ._ctypes.function import ObjectBase as _ObjectBase
from ._ctypes.function import FunctionBase as _FunctionBase
from ._ctypes.function import convert_to_tvm_func
class Object(_ObjectBase):
# TODO(@jroesch): Eventually add back introspection functionality.
pass
_set_class_object(Object)
FunctionHandle = ctypes.c_void_p
class Function(_FunctionBase):
......
......@@ -42,8 +42,10 @@ class TypeCode(object):
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
OBJECT = 14
EXT_BEGIN = 15
class TVMByteArray(ctypes.Structure):
"""Temp data structure for byte array."""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -92,6 +92,12 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true;
}
}
void Visit(const char* key, runtime::Object* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
}
}
};
struct APIAttrDir : public AttrVisitor {
......@@ -127,6 +133,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::Object* value) final {
names->push_back(key);
}
};
class DSLAPIImpl : public DSLAPI {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -53,6 +53,8 @@ inline Type String2Type(std::string s) {
return TVMType2Type(runtime::String2TVMType(s));
}
using runtime::Object;
using runtime::ObjectCell;
// indexer to index all the ndoes
class NodeIndexer : public AttrVisitor {
......@@ -61,6 +63,8 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
std::unordered_map<ObjectCell*, size_t> vm_obj_index;
std::vector<ObjectCell*> vm_obj_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
......@@ -73,6 +77,7 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
}
void Visit(const char* key, runtime::NDArray* value) final {
DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
if (tensor_index.count(ptr)) return;
......@@ -80,6 +85,15 @@ class NodeIndexer : public AttrVisitor {
tensor_index[ptr] = tensor_list.size();
tensor_list.push_back(ptr);
}
void Visit(const char* key, Object* value) final {
ObjectCell* ptr = value->ptr_.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
}
// make index of all the children of node
void MakeIndex(Node* node) {
if (node == nullptr) return;
......@@ -163,6 +177,7 @@ class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<ObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
......@@ -197,6 +212,10 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
void Visit(const char* key, Object* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr_.get()));
}
// Get the node
void Get(Node* node) {
if (node == nullptr) {
......@@ -250,6 +269,8 @@ class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<Object>* vm_obj_list_;
JSONNode* node_;
std::string GetValue(const char* key) const {
......@@ -304,6 +325,12 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
void Visit(const char* key, Object* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
}
// set node to be current JSONNode
void Set(Node* node) {
if (node == nullptr) return;
......@@ -481,6 +508,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, Object* value) final {
*value = GetAttr(key).operator Object();
}
private:
runtime::TVMArgValue GetAttr(const char* key) {
......
......@@ -775,6 +775,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::Object* obj) final {
LOG(FATAL) << "do not allow Object as argument";
}
private:
Doc& doc_;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file object.cc
* \brief A managed object in the TVM runtime.
*/
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <iostream>
namespace tvm {
namespace runtime {
std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) {
switch (tag) {
case ObjectTag::kClosure:
os << "Closure";
break;
case ObjectTag::kDatatype:
os << "Datatype";
break;
case ObjectTag::kTensor:
os << "Tensor";
break;
case ObjectTag::kExternalFunc:
os << "ExternalFunction";
break;
default:
LOG(FATAL) << "Invalid object tag: found " << static_cast<int>(tag);
}
return os;
}
Object Object::Tensor(const NDArray& data) {
ObjectPtr<ObjectCell> ptr = MakeObject<TensorCell>(data);
return Object(ptr);
}
Object Object::Datatype(size_t tag, const std::vector<Object>& fields) {
ObjectPtr<ObjectCell> ptr = MakeObject<DatatypeCell>(tag, fields);
return Object(ptr);
}
Object Object::Tuple(const std::vector<Object>& fields) { return Object::Datatype(0, fields); }
Object Object::Closure(size_t func_index, const std::vector<Object>& free_vars) {
ObjectPtr<ObjectCell> ptr = MakeObject<ClosureCell>(func_index, free_vars);
return Object(ptr);
}
ObjectPtr<TensorCell> Object::AsTensor() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kTensor);
return ptr.As<TensorCell>();
}
ObjectPtr<DatatypeCell> Object::AsDatatype() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kDatatype);
return ptr.As<DatatypeCell>();
}
ObjectPtr<ClosureCell> Object::AsClosure() const {
CHECK(ptr.get());
CHECK(ptr.get()->tag == ObjectTag::kClosure);
return ptr.As<ClosureCell>();
}
NDArray ToNDArray(const Object& obj) {
auto tensor = obj.AsTensor();
return tensor->data;
}
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment