Commit a0bd3786 by Tianqi Chen Committed by Zhi

[RFC][RUNTIME] Introduce new object protocol. (#4115)

* [RUNTIME] Introduce new object protocol.

This PR introduces a new object protocol to unify the node and object.
We also updated the existing runtime::vm code to make use of the new system.

Update to the node will be done in a follow up PR.

Other changes:

- Remove object related code in json serializer as that code logic was not complete
  and we have a separate serializer for VM, can revisit later.

* address review  comment

* Fix the child slot logic
parent 68472596
......@@ -70,7 +70,7 @@ cpplint:
python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src
python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include;
python3 3rdparty/dmlc-core/scripts/lint.py nnvm cpp nnvm/include nnvm/src;
python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src verilog\
python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \
examples/extension/src examples/graph_executor/src
pylint:
......
......@@ -42,7 +42,7 @@ namespace runtime {
// forward declaration
class NDArray;
// forward declaration
class Object;
class ObjectRef;
} // namespace runtime
/*!
......@@ -63,7 +63,7 @@ class TVM_DLL AttrVisitor {
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
virtual void Visit(const char* key, runtime::NDArray* value) = 0;
virtual void Visit(const char* key, runtime::Object* value) = 0;
virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/memory.h
* \brief Runtime memory management.
*/
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_
#include <utility>
#include <type_traits>
#include "object.h"
namespace tvm {
namespace runtime {
/*!
* \brief Allocate an object using default allocator.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
// Detail implementations after this
//
// The current design allows swapping the
// allocator pattern when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that gives ownership of memory to arena (deleter_= nullptr)
// - Thread-local object pools: one pool per size and alignment requirement.
// - Can specialize by type of object to give the specific allocator to each object.
/*!
* \brief Base class of object allocators that implements make.
* Use curiously recurring template pattern.
*
* \tparam Derived The derived class.
*/
template<typename Derived>
class ObjAllocatorBase {
public:
/*!
* \tparam T The type to be allocated.
* \tparam Args The constructor signature.
* \param args The arguments.
*/
template<typename T, typename... Args>
inline ObjectPtr<T> make(Args&&... args) {
using Handler = typename Derived::template Handler<T>;
static_assert(std::is_base_of<Object, T>::value,
"make_node can only be used to create NodeBase");
T* ptr = Handler::New(static_cast<Derived*>(this),
std::forward<Args>(args)...);
ptr->type_index_ = T::type_index();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
}
};
// Simple allocator that uses new/delete.
class SimpleObjAllocator :
public ObjAllocatorBase<SimpleObjAllocator> {
public:
template<typename T>
class Handler {
public:
template<typename... Args>
static T* New(SimpleObjAllocator*, Args&&... args) {
// NOTE: the first argument is not needed for SimpleObjAllocator
// It is reserved for special allocators that needs to recycle
// the object to itself (e.g. in the case of object pool).
//
// In the case of an object pool, an allocator needs to create
// a special chunk memory that hides reference to the allocator
// and call allocator's release function in the deleter.
return new T(std::forward<Args>(args)...);
}
static Object::FDeleter Deleter() {
return Deleter_;
}
private:
static void Deleter_(Object* ptr) {
delete static_cast<T*>(ptr);
}
};
};
template<typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
return SimpleObjAllocator().make<T>(std::forward<Args>(args)...);
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_MEMORY_H_
......@@ -489,10 +489,10 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator Object() const {
if (type_code_ == kNull) return Object();
operator ObjectRef() const {
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
return Object(static_cast<ObjectCell*>(value_.v_handle));
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
......@@ -566,7 +566,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator Object;
using TVMPODValue_::operator ObjectRef;
// conversion operator.
operator std::string() const {
......@@ -662,7 +662,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Object;
using TVMPODValue_::operator ObjectRef;
TVMRetValue(const TVMRetValue& other) : TVMPODValue_() {
this->Assign(other);
}
......@@ -759,11 +759,12 @@ class TVMRetValue : public TVMPODValue_ {
other.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(Object other) {
TVMRetValue& operator=(ObjectRef other) {
this->Clear();
type_code_ = kObjectCell;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
// move the handle out
value_.v_handle = other.data_.data_;
other.data_.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
......@@ -862,7 +863,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kObjectCell: {
*this = other.operator Object();
*this = other.operator ObjectRef();
break;
}
default: {
......@@ -913,7 +914,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kObjectCell: {
static_cast<ObjectCell*>(value_.v_handle)->DecRef();
static_cast<Object*>(value_.v_handle)->DecRef();
break;
}
}
......@@ -1161,6 +1162,10 @@ class TVMArgsSetter {
values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer;
}
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
values_[i].v_handle = value.data_.data_;
type_codes_[i] = kObjectCell;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
......
......@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/runtime/vm.h
* \brief A virtual machine for executing Relay programs.
*/
......@@ -36,6 +35,75 @@ namespace tvm {
namespace runtime {
namespace vm {
/*! \brief An object containing an NDArray. */
class TensorObj : public Object {
public:
/*! \brief The NDArray. */
NDArray data;
static constexpr const uint32_t _type_index = TypeIndex::kVMTensor;
static constexpr const char* _type_key = "vm.Tensor";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorObj, Object);
};
/*! \brief reference to tensor. */
class Tensor : public ObjectRef {
public:
explicit Tensor(NDArray data);
TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
};
/*! \brief An object representing a structure or enumeration. */
class DatatypeObj : public Object {
public:
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<ObjectRef> fields;
static constexpr const uint32_t _type_index = TypeIndex::kVMDatatype;
static constexpr const char* _type_key = "vm.Datatype";
TVM_DECLARE_FINAL_OBJECT_INFO(DatatypeObj, Object);
};
/*! \brief reference to data type. */
class Datatype : public ObjectRef {
public:
Datatype(size_t tag, std::vector<ObjectRef> fields);
/*!
* \brief construct a tuple object.
* \param fields The fields of the tuple.
* \return The constructed tuple type.
*/
static Datatype Tuple(std::vector<ObjectRef> fields);
TVM_DEFINE_OBJECT_REF_METHODS(Datatype, ObjectRef, DatatypeObj);
};
/*! \brief An object representing a closure. */
class ClosureObj : public Object {
public:
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<ObjectRef> free_vars;
static constexpr const uint32_t _type_index = TypeIndex::kVMClosure;
static constexpr const char* _type_key = "vm.Closure";
TVM_DECLARE_FINAL_OBJECT_INFO(ClosureObj, Object);
};
/*! \brief reference to closure. */
class Closure : public ObjectRef {
public:
Closure(size_t func_index, std::vector<ObjectRef> free_vars);
TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
};
/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
......@@ -193,7 +261,7 @@ struct Instruction {
static Instruction Ret(RegName return_reg);
/*! \brief Construct a fatal instruction.
* \return The fatal instruction.
* */
* */
static Instruction Fatal();
/*! \brief Construct a invoke packed instruction.
* \param packed_index The index of the packed function.
......@@ -348,7 +416,7 @@ struct VMFrame {
const Instruction* code;
/*! \brief Statically allocated space for objects */
std::vector<Object> register_file;
std::vector<ObjectRef> register_file;
/*! \brief Register in caller's frame to put return value */
RegName caller_return_register;
......@@ -406,8 +474,11 @@ class VirtualMachine : public runtime::ModuleNode {
*
* \note The return value will be stored in the last output_size slots of args.
*/
virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<Object>& args);
virtual void InvokePacked(Index packed_index,
const PackedFunc& func,
Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args);
virtual ~VirtualMachine() {}
......@@ -424,7 +495,7 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The current stack of call frames. */
std::vector<VMFrame> frames;
/*! \brief The global constant pool. */
std::vector<Object> constants;
std::vector<ObjectRef> constants;
/*! \brief The fuction table index of the current function. */
Index func_index;
/*! \brief The current pointer to the code section. */
......@@ -433,7 +504,7 @@ class VirtualMachine : public runtime::ModuleNode {
Index pc;
/*! \brief The special return register. */
Object return_register;
ObjectRef return_register;
/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs;
......@@ -449,13 +520,13 @@ class VirtualMachine : public runtime::ModuleNode {
* \param reg The register to write to.
* \param obj The object to write to.
*/
inline void WriteRegister(RegName reg, const Object& obj);
inline void WriteRegister(RegName reg, const ObjectRef& obj);
/*! \brief Read a VM register.
* \param reg The register to read from.
* \return The read object.
*/
inline Object ReadRegister(RegName reg) const;
inline ObjectRef ReadRegister(RegName reg) const;
/*! \brief Read a VM register and cast it to int32_t
* \param reg The register to read from.
......@@ -468,15 +539,16 @@ class VirtualMachine : public runtime::ModuleNode {
* \param args The arguments to the function.
* \return The object representing the result.
*/
Object Invoke(const VMFunction& func, const std::vector<Object>& args);
ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& args);
// TODO(@jroesch): I really would like this to be a global variable.
/*! \brief Invoke a VM function by name.
/*!
* \brief Invoke a VM function by name.
* \param name The function's name.
* \param args The arguments to the function.
* \return The object representing the result.
*/
Object Invoke(const std::string& name, const std::vector<Object>& args);
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
VirtualMachine() : functions(), frames(), func_index(0), code(nullptr), pc(0) {}
......@@ -513,11 +585,10 @@ class VirtualMachine : public runtime::ModuleNode {
*
* This does not begin execution of the VM.
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);
void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
std::unordered_map<std::string, ObjectRef> params_;
};
} // namespace vm
......
......@@ -44,9 +44,9 @@ except IMPORT_EXCEPT:
class ObjectTag(object):
"""Type code used in API calls"""
TENSOR = 0
CLOSURE = 1
DATATYPE = 2
TENSOR = 1
CLOSURE = 2
DATATYPE = 3
class Object(_ObjectBase):
......
......@@ -92,7 +92,7 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object = true;
}
}
void Visit(const char* key, runtime::Object* value) final {
void Visit(const char* key, runtime::ObjectRef* value) final {
if (skey == key) {
*ret = value[0];
found_ref_object = true;
......@@ -133,7 +133,7 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
names->push_back(key);
}
void Visit(const char* key, runtime::Object* value) final {
void Visit(const char* key, runtime::ObjectRef* value) final {
names->push_back(key);
}
};
......
......@@ -54,7 +54,7 @@ inline Type String2Type(std::string s) {
}
using runtime::Object;
using runtime::ObjectCell;
using runtime::ObjectRef;
// indexer to index all the ndoes
class NodeIndexer : public AttrVisitor {
......@@ -63,8 +63,6 @@ class NodeIndexer : public AttrVisitor {
std::vector<Node*> node_list{nullptr};
std::unordered_map<DLTensor*, size_t> tensor_index;
std::vector<DLTensor*> tensor_list;
std::unordered_map<ObjectCell*, size_t> vm_obj_index;
std::vector<ObjectCell*> vm_obj_list;
void Visit(const char* key, double* value) final {}
void Visit(const char* key, int64_t* value) final {}
......@@ -86,12 +84,8 @@ class NodeIndexer : public AttrVisitor {
tensor_list.push_back(ptr);
}
void Visit(const char* key, Object* value) final {
ObjectCell* ptr = value->ptr_.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
vm_obj_list.push_back(ptr);
void Visit(const char* key, ObjectRef* value) final {
LOG(FATAL) << "Do not support json serialize non-node object";
}
// make index of all the children of node
......@@ -177,7 +171,6 @@ class JSONAttrGetter : public AttrVisitor {
public:
const std::unordered_map<Node*, size_t>* node_index_;
const std::unordered_map<DLTensor*, size_t>* tensor_index_;
const std::unordered_map<ObjectCell*, size_t>* vm_obj_index_;
JSONNode* node_;
void Visit(const char* key, double* value) final {
......@@ -212,9 +205,8 @@ class JSONAttrGetter : public AttrVisitor {
node_->attrs[key] = std::to_string(
tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
}
void Visit(const char* key, Object* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr_.get()));
void Visit(const char* key, ObjectRef* value) final {
LOG(FATAL) << "Do not support json serialize non-node object";
}
// Get the node
void Get(Node* node) {
......@@ -269,7 +261,6 @@ class JSONAttrSetter : public AttrVisitor {
public:
const std::vector<NodePtr<Node> >* node_list_;
const std::vector<runtime::NDArray>* tensor_list_;
const std::vector<Object>* vm_obj_list_;
JSONNode* node_;
......@@ -325,11 +316,8 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE(index, tensor_list_->size());
*value = tensor_list_->at(index);
}
void Visit(const char* key, Object* value) final {
size_t index;
ParseValue(key, &index);
CHECK_LE(index, vm_obj_list_->size());
*value = vm_obj_list_->at(index);
void Visit(const char* key, ObjectRef* value) final {
LOG(FATAL) << "Do not support json serialize non-node object";
}
// set node to be current JSONNode
void Set(Node* node) {
......@@ -508,8 +496,8 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
*value = GetAttr(key).operator runtime::NDArray();
}
void Visit(const char* key, Object* value) final {
*value = GetAttr(key).operator Object();
void Visit(const char* key, ObjectRef* value) final {
*value = GetAttr(key).operator ObjectRef();
}
private:
......
......@@ -885,7 +885,7 @@ void VMCompiler::Compile(Module mod,
// populate constants
for (auto data : context_.constants) {
vm_->constants.push_back(Object::Tensor(data));
vm_->constants.push_back(runtime::vm::Tensor(data));
}
LibraryCodegen();
......
......@@ -102,7 +102,7 @@ void Deserializer::DeserializeConstantSection() {
for (size_t i = 0; i < size; i++) {
runtime::NDArray constant;
STREAM_CHECK(constant.Load(strm_), "constant");
runtime::Object obj = runtime::Object::Tensor(constant);
runtime::ObjectRef obj = runtime::vm::Tensor(constant);
vm_->constants.push_back(obj);
}
}
......
......@@ -98,8 +98,8 @@ std::string Serializer::Stats() const {
// Get the number of constants and the shape of each of them.
oss << " Constant shapes (# " << vm_->constants.size() << "): [";
for (const auto& it : vm_->constants) {
auto cell = it.AsTensor();
CHECK(cell.operator->());
auto* cell = it.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data;
const auto& shape = data.Shape();
......@@ -175,7 +175,8 @@ void Serializer::SerializeGlobalSection() {
void Serializer::SerializeConstantSection() {
std::vector<DLTensor*> arrays;
for (const auto& obj : vm_->constants) {
auto cell = obj.AsTensor();
const auto* cell = obj.as<runtime::vm::TensorObj>();
CHECK(cell != nullptr);
runtime::NDArray data = cell->data;
arrays.push_back(const_cast<DLTensor*>(data.operator->()));
}
......
......@@ -930,7 +930,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::Object* obj) final {
void Visit(const char* key, runtime::ObjectRef* obj) final {
LOG(FATAL) << "do not allow Object as argument";
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/runtime/object.cc
* \brief Object type management system.
*/
#include <dmlc/logging.h>
#include <tvm/runtime/object.h>
#include <mutex>
#include <string>
#include <vector>
#include <unordered_map>
namespace tvm {
namespace runtime {
/*! \brief Type information */
struct TypeInfo {
/*! \brief The current index. */
uint32_t index{0};
/*! \brief Index of the parent in the type hierachy */
uint32_t parent_index{0};
// NOTE: the indices in [index, index + num_reserved_slots) are
// reserved for the child-class of this type.
/*! \brief Total number of slots reserved for the type and its children. */
uint32_t num_slots{0};
/*! \brief number of allocated child slots. */
uint32_t allocated_slots{0};
/*! \brief Whether child can overflow. */
bool child_slots_can_overflow{true};
/*! \brief name of the type. */
std::string name;
};
/*!
* \brief Type context that manages the type hierachy information.
*/
class TypeContext {
public:
// NOTE: this is a relatively slow path for child checking
// Most types are already checked by the fast-path via reserved slot checking.
bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
// invariance: child's type index is always bigger than its parent.
if (child_tindex < parent_tindex) return false;
if (child_tindex == parent_tindex) return true;
{
std::lock_guard<std::mutex> lock(mutex_);
CHECK_LT(child_tindex, type_table_.size());
while (child_tindex > parent_tindex) {
child_tindex = type_table_[child_tindex].parent_index;
}
}
return child_tindex == parent_tindex;
}
uint32_t GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
std::lock_guard<std::mutex> lock(mutex_);
std::string skey = key;
auto it = type_key2index_.find(skey);
if (it != type_key2index_.end()) {
return it->second;
}
// try to allocate from parent's type table.
CHECK_LT(parent_tindex, type_table_.size());
TypeInfo& pinfo = type_table_[parent_tindex];
CHECK_EQ(pinfo.index, parent_tindex);
// if parent cannot overflow, then this class cannot.
if (!pinfo.child_slots_can_overflow) {
child_slots_can_overflow = false;
}
// total number of slots include the type itself.
uint32_t num_slots = num_child_slots + 1;
uint32_t allocated_tindex;
if (static_tindex != TypeIndex::kDynamic) {
// statically assigned type
allocated_tindex = static_tindex;
CHECK_LT(static_tindex, type_table_.size());
CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
<< "Conflicting static index " << static_tindex
<< " between " << type_table_[allocated_tindex].name
<< " and "
<< key;
} else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) {
// allocate the slot from parent's reserved pool
allocated_tindex = parent_tindex + pinfo.allocated_slots;
// update parent's state
pinfo.allocated_slots += num_slots;
} else {
CHECK(pinfo.child_slots_can_overflow)
<< "Reach maximum number of sub-classes for " << pinfo.name;
// allocate new entries.
allocated_tindex = type_counter_;
type_counter_ += num_slots;
CHECK_LE(type_table_.size(), allocated_tindex);
type_table_.resize(allocated_tindex + 1, TypeInfo());
}
CHECK_GT(allocated_tindex, parent_tindex);
// initialize the slot.
type_table_[allocated_tindex].index = allocated_tindex;
type_table_[allocated_tindex].parent_index = parent_tindex;
type_table_[allocated_tindex].num_slots = num_slots;
type_table_[allocated_tindex].allocated_slots = 1;
type_table_[allocated_tindex].child_slots_can_overflow =
child_slots_can_overflow;
type_table_[allocated_tindex].name = skey;
// update the key2index mapping.
type_key2index_[skey] = allocated_tindex;
return allocated_tindex;
}
std::string TypeIndex2Key(uint32_t tindex) {
std::lock_guard<std::mutex> lock(mutex_);
CHECK(tindex < type_table_.size() &&
type_table_[tindex].allocated_slots != 0)
<< "Unknown type index " << tindex;
return type_table_[tindex].name;
}
uint32_t TypeKey2Index(const char* key) {
std::string skey = key;
auto it = type_key2index_.find(skey);
CHECK(it != type_key2index_.end())
<< "Cannot find type " << key;
return it->second;
}
static TypeContext* Global() {
static TypeContext inst;
return &inst;
}
private:
TypeContext() {
type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
}
// mutex to avoid registration from multiple threads.
std::mutex mutex_;
std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
std::vector<TypeInfo> type_table_;
std::unordered_map<std::string, uint32_t> type_key2index_;
};
uint32_t Object::GetOrAllocRuntimeTypeIndex(const char* key,
uint32_t static_tindex,
uint32_t parent_tindex,
uint32_t num_child_slots,
bool child_slots_can_overflow) {
return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
}
bool Object::DerivedFrom(uint32_t parent_tindex) const {
return TypeContext::Global()->DerivedFrom(
this->type_index_, parent_tindex);
}
std::string Object::TypeIndex2Key(uint32_t tindex) {
return TypeContext::Global()->TypeIndex2Key(tindex);
}
uint32_t Object::TypeKey2Index(const char* key) {
return TypeContext::Global()->TypeKey2Index(key);
}
} // namespace runtime
} // namespace tvm
......@@ -18,134 +18,110 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file object.cc
* \brief A managed object in the TVM runtime.
* \file src/runtime/vm/object.cc
* \brief VM related objects.
*/
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <iostream>
#include "../runtime_base.h"
namespace tvm {
namespace runtime {
namespace vm {
std::ostream& operator<<(std::ostream& os, const ObjectTag& tag) {
switch (tag) {
case ObjectTag::kClosure:
os << "Closure";
break;
case ObjectTag::kDatatype:
os << "Datatype";
break;
case ObjectTag::kTensor:
os << "Tensor";
break;
default:
LOG(FATAL) << "Invalid object tag: found " << static_cast<int>(tag);
}
return os;
}
Object Object::Tensor(const NDArray& data) {
ObjectPtr<ObjectCell> ptr = MakeObject<TensorCell>(data);
return Object(ptr);
}
Object Object::Datatype(size_t tag, const std::vector<Object>& fields) {
ObjectPtr<ObjectCell> ptr = MakeObject<DatatypeCell>(tag, fields);
return Object(ptr);
Tensor::Tensor(NDArray data) {
auto ptr = make_object<TensorObj>();
ptr->data = std::move(data);
data_ = std::move(ptr);
}
Object Object::Tuple(const std::vector<Object>& fields) { return Object::Datatype(0, fields); }
Object Object::Closure(size_t func_index, const std::vector<Object>& free_vars) {
ObjectPtr<ObjectCell> ptr = MakeObject<ClosureCell>(func_index, free_vars);
return Object(ptr);
Datatype::Datatype(size_t tag, std::vector<ObjectRef> fields) {
auto ptr = make_object<DatatypeObj>();
ptr->tag = tag;
ptr->fields = std::move(fields);
data_ = std::move(ptr);
}
ObjectPtr<TensorCell> Object::AsTensor() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kTensor);
return ptr_.As<TensorCell>();
Datatype Datatype::Tuple(std::vector<ObjectRef> fields) {
return Datatype(0, fields);
}
ObjectPtr<DatatypeCell> Object::AsDatatype() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kDatatype);
return ptr_.As<DatatypeCell>();
Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<ClosureObj>();
ptr->func_index = func_index;
ptr->free_vars = std::move(free_vars);
data_ = std::move(ptr);
}
ObjectPtr<ClosureCell> Object::AsClosure() const {
CHECK(ptr_.get());
CHECK(ptr_.get()->tag == ObjectTag::kClosure);
return ptr_.As<ClosureCell>();
}
NDArray ToNDArray(const Object& obj) {
auto tensor = obj.AsTensor();
return tensor->data;
}
TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsTensor();
ObjectRef obj = args[0];
const auto* cell = obj.as<TensorObj>();
CHECK(cell != nullptr);
*rv = cell->data;
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeTag")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->tag);
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->tag);
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeNumberOfFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
auto cell = obj.AsDatatype();
*rv = static_cast<int>(cell->fields.size());
ObjectRef obj = args[0];
const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
*rv = static_cast<int64_t>(cell->fields.size());
});
TVM_REGISTER_GLOBAL("_vmobj.GetDatatypeFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Object obj = args[0];
ObjectRef obj = args[0];
int idx = args[1];
auto cell = obj.AsDatatype();
const auto* cell = obj.as<DatatypeObj>();
CHECK(cell != nullptr);
CHECK_LT(idx, cell->fields.size());
*rv = cell->fields[idx];
});
TVM_REGISTER_GLOBAL("_vmobj.Tensor")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = Object::Tensor(args[0]);
*rv = Tensor(args[0].operator NDArray());
});
TVM_REGISTER_GLOBAL("_vmobj.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<Object> fields;
std::vector<ObjectRef> fields;
for (auto i = 0; i < args.size(); ++i) {
fields.push_back(args[i]);
}
*rv = Object::Tuple(fields);
*rv = Datatype::Tuple(fields);
});
TVM_REGISTER_GLOBAL("_vmobj.Datatype")
.set_body([](TVMArgs args, TVMRetValue* rv) {
int itag = args[0];
size_t tag = static_cast<size_t>(itag);
std::vector<Object> fields;
std::vector<ObjectRef> fields;
for (int i = 1; i < args.size(); i++) {
fields.push_back(args[i]);
}
*rv = Object::Datatype(tag, fields);
*rv = Datatype(tag, fields);
});
TVM_REGISTER_OBJECT_TYPE(TensorObj);
TVM_REGISTER_OBJECT_TYPE(DatatypeObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
} // namespace vm
} // namespace runtime
} // namespace tvm
......@@ -153,6 +129,7 @@ using namespace tvm::runtime;
int TVMGetObjectTag(TVMObjectHandle handle, int* tag) {
API_BEGIN();
*tag = static_cast<int>(static_cast<ObjectCell*>(handle)->tag);
int res = static_cast<int>(static_cast<Object*>(handle)->type_index());
*tag = res;
API_END();
}
......@@ -96,7 +96,7 @@ void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count,
Index output_size,
const std::vector<Object>& args) {
const std::vector<ObjectRef>& args) {
auto ctx = VirtualMachine::GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
......
......@@ -45,7 +45,7 @@ class VirtualMachineDebug : public VirtualMachine {
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<Object>& args) final;
Index output_size, const std::vector<ObjectRef>& args) final;
~VirtualMachineDebug() {}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/memory.h>
namespace tvm {
namespace test {
using namespace tvm::runtime;
class ObjBase : public Object {
public:
// dynamically allocate slow
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 1;
static constexpr const char* _type_key = "test.ObjBase";
TVM_DECLARE_BASE_OBJECT_INFO(ObjBase, Object);
};
class ObjA : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const uint32_t _type_child_slots = 0;
static constexpr const char* _type_key = "test.ObjA";
TVM_DECLARE_BASE_OBJECT_INFO(ObjA, ObjBase);
};
class ObjB : public ObjBase {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.ObjB";
TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase);
};
class ObjAA : public ObjA {
public:
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "test.ObjAA";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA);
};
TVM_REGISTER_OBJECT_TYPE(ObjBase);
TVM_REGISTER_OBJECT_TYPE(ObjA);
TVM_REGISTER_OBJECT_TYPE(ObjB);
TVM_REGISTER_OBJECT_TYPE(ObjAA);
} // namespace test
} // namespace tvm
TEST(ObjectHierachy, Basic) {
using namespace tvm::runtime;
using namespace tvm::test;
ObjectRef refA(make_object<ObjA>());
CHECK_EQ(refA->type_index(), ObjA::type_index());
CHECK(refA.as<Object>() != nullptr);
CHECK(refA.as<ObjA>() != nullptr);
CHECK(refA.as<ObjBase>() != nullptr);
CHECK(refA.as<ObjB>() == nullptr);
CHECK(refA.as<ObjAA>() == nullptr);
ObjectRef refAA(make_object<ObjAA>());
CHECK_EQ(refAA->type_index(), ObjAA::type_index());
CHECK(refAA.as<Object>() != nullptr);
CHECK(refAA.as<ObjBase>() != nullptr);
CHECK(refAA.as<ObjA>() != nullptr);
CHECK(refAA.as<ObjAA>() != nullptr);
CHECK(refAA.as<ObjB>() == nullptr);
ObjectRef refB(make_object<ObjB>());
CHECK_EQ(refB->type_index(), ObjB::type_index());
CHECK(refB.as<Object>() != nullptr);
CHECK(refB.as<ObjBase>() != nullptr);
CHECK(refB.as<ObjA>() == nullptr);
CHECK(refB.as<ObjAA>() == nullptr);
CHECK(refB.as<ObjB>() != nullptr);
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
......@@ -582,7 +582,7 @@ def test_set_params():
mod["main"] = relay.Function([x, w, b], y)
vm = relay.vm.compile(mod, 'llvm')
vm.init(tvm.cpu())
x_np = np.random.uniform(size=(10, 5)).astype('float32')
w_np = np.random.uniform(size=(6, 5)).astype('float32')
b_np = np.random.uniform(size=(6,)).astype('float32')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment