Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2bf5fd2b
Commit
2bf5fd2b
authored
Dec 01, 2019
by
Wei Chen
Committed by
Tianqi Chen
Dec 01, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Runtime] Make ADTObject POD container type (#4346)
parent
2a8c6978
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
498 additions
and
69 deletions
+498
-69
include/tvm/runtime/container.h
+279
-0
include/tvm/runtime/memory.h
+77
-2
include/tvm/runtime/vm.h
+0
-29
src/runtime/vm/object.cc
+8
-21
src/runtime/vm/vm.cc
+8
-13
tests/cpp/container_test.cc
+126
-4
No files found.
include/tvm/runtime/container.h
0 → 100644
View file @
2bf5fd2b
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/runtime/container.h
* \brief Common POD(plain old data) container types.
*/
#ifndef TVM_RUNTIME_CONTAINER_H_
#define TVM_RUNTIME_CONTAINER_H_
#include <dmlc/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <initializer_list>
#include <type_traits>
#include <utility>
#include <vector>
namespace
tvm
{
namespace
runtime
{
/*!
* \brief Base template for classes with array like memory layout.
*
* It provides general methods to access the memory. The memory
* layout is ArrayType + [ElemType]. The alignment of ArrayType
* and ElemType is handled by the memory allocator.
*
* \tparam ArrayType The array header type, contains object specific metadata.
* \tparam ElemType The type of objects stored in the array right after
* ArrayType.
*
* \code
* // Example usage of the template to define a simple array wrapper
* class ArrayObj : public InplaceArrayBase<ArrayObj, Elem> {
* public:
* // Wrap EmplaceInit to initialize the elements
* template <typename Iterator>
* void Init(Iterator begin, Iterator end) {
* size_t num_elems = std::distance(begin, end);
* auto it = begin;
* this->size = 0;
* for (size_t i = 0; i < num_elems; ++i) {
* InplaceArrayBase::EmplaceInit(i, *it++);
* this->size++;
* }
* }
* }
*
* void test_function() {
* vector<Elem> fields;
* auto ptr = make_inplace_array_object<ArrayObj, Elem>(fields.size());
* ptr->Init(fields.begin(), fields.end());
*
* // Access the 0th element in the array.
* assert(ptr->operator[](0) == fields[0]);
* }
*
* \endcode
*/
template
<
typename
ArrayType
,
typename
ElemType
>
class
InplaceArrayBase
{
public
:
/*!
* \brief Access element at index
* \param idx The index of the element.
* \return Const reference to ElemType at the index.
*/
const
ElemType
&
operator
[](
size_t
idx
)
const
{
size_t
size
=
Self
()
->
GetSize
();
CHECK_LT
(
idx
,
size
)
<<
"Index "
<<
idx
<<
" out of bounds "
<<
size
<<
"
\n
"
;
return
*
(
reinterpret_cast
<
ElemType
*>
(
AddressOf
(
idx
)));
}
/*!
* \brief Access element at index
* \param idx The index of the element.
* \return Reference to ElemType at the index.
*/
ElemType
&
operator
[](
size_t
idx
)
{
size_t
size
=
Self
()
->
GetSize
();
CHECK_LT
(
idx
,
size
)
<<
"Index "
<<
idx
<<
" out of bounds "
<<
size
<<
"
\n
"
;
return
*
(
reinterpret_cast
<
ElemType
*>
(
AddressOf
(
idx
)));
}
/*!
* \brief Destroy the Inplace Array Base object
*/
~
InplaceArrayBase
()
{
if
(
!
(
std
::
is_standard_layout
<
ElemType
>::
value
&&
std
::
is_trivial
<
ElemType
>::
value
))
{
size_t
size
=
Self
()
->
GetSize
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
ElemType
*
fp
=
reinterpret_cast
<
ElemType
*>
(
AddressOf
(
i
));
fp
->
ElemType
::~
ElemType
();
}
}
}
protected
:
/*!
* \brief Construct a value in place with the arguments.
*
* \tparam Args Type parameters of the arguments.
* \param idx Index of the element.
* \param args Arguments to construct the new value.
*
* \note Please make sure ArrayType::GetSize returns 0 before first call of
* EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds.
*/
template
<
typename
...
Args
>
void
EmplaceInit
(
size_t
idx
,
Args
&&
...
args
)
{
void
*
field_ptr
=
AddressOf
(
idx
);
new
(
field_ptr
)
ElemType
(
std
::
forward
<
Args
>
(
args
)...);
}
private
:
/*!
* \brief Return the self object for the array.
*
* \return Pointer to ArrayType.
*/
inline
ArrayType
*
Self
()
const
{
return
static_cast
<
ArrayType
*>
(
const_cast
<
InplaceArrayBase
*>
(
this
));
}
/*!
* \brief Return the raw pointer to the element at idx.
*
* \param idx The index of the element.
* \return Raw pointer to the element.
*/
void
*
AddressOf
(
size_t
idx
)
const
{
static_assert
(
alignof
(
ArrayType
)
%
alignof
(
ElemType
)
==
0
&&
sizeof
(
ArrayType
)
%
alignof
(
ElemType
)
==
0
,
"The size and alignment of ArrayType should respect "
"ElemType's alignment."
);
size_t
kDataStart
=
sizeof
(
ArrayType
);
ArrayType
*
self
=
Self
();
char
*
data_start
=
reinterpret_cast
<
char
*>
(
self
)
+
kDataStart
;
return
data_start
+
idx
*
sizeof
(
ElemType
);
}
};
/*! \brief An object representing a structure or enumeration. */
class
ADTObj
:
public
Object
,
public
InplaceArrayBase
<
ADTObj
,
ObjectRef
>
{
public
:
/*! \brief The tag representing the constructor used. */
uint32_t
tag
;
/*! \brief Number of fields in the ADT object. */
uint32_t
size
;
// The fields of the structure follows directly in memory.
static
constexpr
const
uint32_t
_type_index
=
TypeIndex
::
kVMADT
;
static
constexpr
const
char
*
_type_key
=
"vm.ADT"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
ADTObj
,
Object
);
private
:
/*!
* \return The number of elements in the array.
*/
size_t
GetSize
()
const
{
return
size
;
}
/*!
* \brief Initialize the elements in the array.
*
* \tparam Iterator Iterator type of the array.
* \param begin The begin iterator.
* \param end The end iterator.
*/
template
<
typename
Iterator
>
void
Init
(
Iterator
begin
,
Iterator
end
)
{
size_t
num_elems
=
std
::
distance
(
begin
,
end
);
this
->
size
=
0
;
auto
it
=
begin
;
for
(
size_t
i
=
0
;
i
<
num_elems
;
++
i
)
{
InplaceArrayBase
::
EmplaceInit
(
i
,
*
it
++
);
// Only increment size after the initialization succeeds
this
->
size
++
;
}
}
friend
class
ADT
;
friend
class
InplaceArrayBase
;
};
/*! \brief reference to algebraic data type objects. */
class
ADT
:
public
ObjectRef
{
public
:
/*!
* \brief construct an ADT object reference.
* \param tag The tag of the ADT object.
* \param fields The fields of the ADT object.
* \return The constructed ADT object reference.
*/
ADT
(
uint32_t
tag
,
std
::
vector
<
ObjectRef
>
fields
)
:
ADT
(
tag
,
fields
.
begin
(),
fields
.
end
()){};
/*!
* \brief construct an ADT object reference.
* \param tag The tag of the ADT object.
* \param begin The begin iterator to the start of the fields array.
* \param end The end iterator to the end of the fields array.
* \return The constructed ADT object reference.
*/
template
<
typename
Iterator
>
ADT
(
uint32_t
tag
,
Iterator
begin
,
Iterator
end
)
{
size_t
num_elems
=
std
::
distance
(
begin
,
end
);
auto
ptr
=
make_inplace_array_object
<
ADTObj
,
ObjectRef
>
(
num_elems
);
ptr
->
tag
=
tag
;
ptr
->
Init
(
begin
,
end
);
data_
=
std
::
move
(
ptr
);
}
/*!
* \brief construct an ADT object reference.
* \param tag The tag of the ADT object.
* \param init The initializer list of fields.
* \return The constructed ADT object reference.
*/
ADT
(
uint32_t
tag
,
std
::
initializer_list
<
ObjectRef
>
init
)
:
ADT
(
tag
,
init
.
begin
(),
init
.
end
()){};
/*!
* \brief Access element at index.
*
* \param idx The array index
* \return const ObjectRef
*/
const
ObjectRef
&
operator
[](
size_t
idx
)
const
{
return
operator
->
()
->
operator
[](
idx
);
}
/*!
* \brief Return the ADT tag.
*/
size_t
tag
()
const
{
return
operator
->
()
->
tag
;
}
/*!
* \brief Return the number of fields.
*/
size_t
size
()
const
{
return
operator
->
()
->
size
;
}
/*!
* \brief Construct a tuple object.
*
* \tparam Args Type params of tuple feilds.
* \param args Tuple fields.
* \return ADT The tuple object reference.
*/
template
<
typename
...
Args
>
static
ADT
Tuple
(
Args
&&
...
args
)
{
return
ADT
(
0
,
std
::
forward
<
Args
>
(
args
)...);
}
TVM_DEFINE_OBJECT_REF_METHODS
(
ADT
,
ObjectRef
,
ADTObj
);
};
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_CONTAINER_H_
include/tvm/runtime/memory.h
View file @
2bf5fd2b
...
...
@@ -23,6 +23,7 @@
#ifndef TVM_RUNTIME_MEMORY_H_
#define TVM_RUNTIME_MEMORY_H_
#include <cstdlib>
#include <utility>
#include <type_traits>
#include "object.h"
...
...
@@ -33,7 +34,7 @@ namespace runtime {
* \brief Allocate an object using default allocator.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The
Node
Ptr to the allocated object.
* \return The
Object
Ptr to the allocated object.
*/
template
<
typename
T
,
typename
...
Args
>
inline
ObjectPtr
<
T
>
make_object
(
Args
&&
...
args
);
...
...
@@ -67,13 +68,33 @@ class ObjAllocatorBase {
inline
ObjectPtr
<
T
>
make_object
(
Args
&&
...
args
)
{
using
Handler
=
typename
Derived
::
template
Handler
<
T
>
;
static_assert
(
std
::
is_base_of
<
Object
,
T
>::
value
,
"make
_node can only be used to create NodeBase
"
);
"make
can only be used to create Object
"
);
T
*
ptr
=
Handler
::
New
(
static_cast
<
Derived
*>
(
this
),
std
::
forward
<
Args
>
(
args
)...);
ptr
->
type_index_
=
T
::
RuntimeTypeIndex
();
ptr
->
deleter_
=
Handler
::
Deleter
();
return
ObjectPtr
<
T
>
(
ptr
);
}
/*!
* \tparam ArrayType The type to be allocated.
* \tparam ElemType The type of array element.
* \tparam Args The constructor signature.
* \param num_elems The number of array elements.
* \param args The arguments.
*/
template
<
typename
ArrayType
,
typename
ElemType
,
typename
...
Args
>
inline
ObjectPtr
<
ArrayType
>
make_inplace_array
(
size_t
num_elems
,
Args
&&
...
args
)
{
using
Handler
=
typename
Derived
::
template
ArrayHandler
<
ArrayType
,
ElemType
>
;
static_assert
(
std
::
is_base_of
<
Object
,
ArrayType
>::
value
,
"make_inplace_array can only be used to create Object"
);
ArrayType
*
ptr
=
Handler
::
New
(
static_cast
<
Derived
*>
(
this
),
num_elems
,
std
::
forward
<
Args
>
(
args
)...);
ptr
->
type_index_
=
ArrayType
::
RuntimeTypeIndex
();
ptr
->
deleter_
=
Handler
::
Deleter
();
return
ObjectPtr
<
ArrayType
>
(
ptr
);
}
};
// Simple allocator that uses new/delete.
...
...
@@ -123,6 +144,54 @@ class SimpleObjAllocator :
delete
reinterpret_cast
<
StorageType
*>
(
tptr
);
}
};
// Array handler that uses new/delete.
template
<
typename
ArrayType
,
typename
ElemType
>
class
ArrayHandler
{
public
:
using
StorageType
=
typename
std
::
aligned_union
<
sizeof
(
ArrayType
),
ArrayType
,
ElemType
>::
type
;
template
<
typename
...
Args
>
static
ArrayType
*
New
(
SimpleObjAllocator
*
,
size_t
num_elems
,
Args
&&
...
args
)
{
// NOTE: the first argument is not needed for ArrayObjAllocator
// It is reserved for special allocators that needs to recycle
// the object to itself (e.g. in the case of object pool).
//
// In the case of an object pool, an allocator needs to create
// a special chunk memory that hides reference to the allocator
// and call allocator's release function in the deleter.
// NOTE2: Use inplace new to allocate
// This is used to get rid of warning when deleting a virtual
// class with non-virtual destructor.
// We are fine here as we captured the right deleter during construction.
// This is also the right way to get storage type for an object pool.
size_t
factor
=
sizeof
(
ArrayType
)
/
sizeof
(
ElemType
);
num_elems
=
(
num_elems
+
factor
-
1
)
/
factor
;
StorageType
*
data
=
new
StorageType
[
num_elems
+
1
];
new
(
data
)
ArrayType
(
std
::
forward
<
Args
>
(
args
)...);
return
reinterpret_cast
<
ArrayType
*>
(
data
);
}
static
Object
::
FDeleter
Deleter
()
{
return
Deleter_
;
}
private
:
static
void
Deleter_
(
Object
*
objptr
)
{
// NOTE: this is important to cast back to ArrayType*
// because objptr and tptr may not be the same
// depending on how sub-class allocates the space.
ArrayType
*
tptr
=
static_cast
<
ArrayType
*>
(
objptr
);
// It is important to do tptr->ArrayType::~ArrayType(),
// so that we explicitly call the specific destructor
// instead of tptr->~ArrayType(), which could mean the intention
// call a virtual destructor(which may not be available and is not required).
tptr
->
ArrayType
::~
ArrayType
();
StorageType
*
p
=
reinterpret_cast
<
StorageType
*>
(
tptr
);
delete
[]
p
;
}
};
};
template
<
typename
T
,
typename
...
Args
>
...
...
@@ -130,6 +199,12 @@ inline ObjectPtr<T> make_object(Args&&... args) {
return
SimpleObjAllocator
().
make_object
<
T
>
(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
ArrayType
,
typename
ElemType
,
typename
...
Args
>
inline
ObjectPtr
<
ArrayType
>
make_inplace_array_object
(
size_t
num_elems
,
Args
&&
...
args
)
{
return
SimpleObjAllocator
().
make_inplace_array
<
ArrayType
,
ElemType
>
(
num_elems
,
std
::
forward
<
Args
>
(
args
)...);
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_MEMORY_H_
include/tvm/runtime/vm.h
View file @
2bf5fd2b
...
...
@@ -55,35 +55,6 @@ class Tensor : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS
(
Tensor
,
ObjectRef
,
TensorObj
);
};
/*! \brief An object representing a structure or enumeration. */
class
ADTObj
:
public
Object
{
public
:
/*! \brief The tag representing the constructor used. */
size_t
tag
;
/*! \brief The fields of the structure. */
std
::
vector
<
ObjectRef
>
fields
;
static
constexpr
const
uint32_t
_type_index
=
TypeIndex
::
kVMADT
;
static
constexpr
const
char
*
_type_key
=
"vm.ADT"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
ADTObj
,
Object
);
};
/*! \brief reference to algebraic data type objects. */
class
ADT
:
public
ObjectRef
{
public
:
ADT
(
size_t
tag
,
std
::
vector
<
ObjectRef
>
fields
);
/*!
* \brief construct a tuple object.
* \param fields The fields of the tuple.
* \return The constructed tuple type.
*/
static
ADT
Tuple
(
std
::
vector
<
ObjectRef
>
fields
);
TVM_DEFINE_OBJECT_REF_METHODS
(
ADT
,
ObjectRef
,
ADTObj
);
};
/*! \brief An object representing a closure. */
class
ClosureObj
:
public
Object
{
public
:
...
...
src/runtime/vm/object.cc
View file @
2bf5fd2b
...
...
@@ -22,6 +22,7 @@
* \brief VM related objects.
*/
#include <tvm/logging.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
...
...
@@ -39,17 +40,6 @@ Tensor::Tensor(NDArray data) {
data_
=
std
::
move
(
ptr
);
}
ADT
::
ADT
(
size_t
tag
,
std
::
vector
<
ObjectRef
>
fields
)
{
auto
ptr
=
make_object
<
ADTObj
>
();
ptr
->
tag
=
tag
;
ptr
->
fields
=
std
::
move
(
fields
);
data_
=
std
::
move
(
ptr
);
}
ADT
ADT
::
Tuple
(
std
::
vector
<
ObjectRef
>
fields
)
{
return
ADT
(
0
,
fields
);
}
Closure
::
Closure
(
size_t
func_index
,
std
::
vector
<
ObjectRef
>
free_vars
)
{
auto
ptr
=
make_object
<
ClosureObj
>
();
ptr
->
func_index
=
func_index
;
...
...
@@ -69,17 +59,15 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
TVM_REGISTER_GLOBAL
(
"_vmobj.GetADTTag"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
ObjectRef
obj
=
args
[
0
];
const
auto
*
cell
=
obj
.
as
<
ADTObj
>
();
CHECK
(
cell
!=
nullptr
);
*
rv
=
static_cast
<
int64_t
>
(
cell
->
tag
);
const
auto
&
adt
=
Downcast
<
ADT
>
(
obj
);
*
rv
=
static_cast
<
int64_t
>
(
adt
.
tag
());
});
TVM_REGISTER_GLOBAL
(
"_vmobj.GetADTNumberOfFields"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
ObjectRef
obj
=
args
[
0
];
const
auto
*
cell
=
obj
.
as
<
ADTObj
>
();
CHECK
(
cell
!=
nullptr
);
*
rv
=
static_cast
<
int64_t
>
(
cell
->
fields
.
size
());
const
auto
&
adt
=
Downcast
<
ADT
>
(
obj
);
*
rv
=
static_cast
<
int64_t
>
(
adt
.
size
());
});
...
...
@@ -87,10 +75,9 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
ObjectRef
obj
=
args
[
0
];
int
idx
=
args
[
1
];
const
auto
*
cell
=
obj
.
as
<
ADTObj
>
();
CHECK
(
cell
!=
nullptr
);
CHECK_LT
(
idx
,
cell
->
fields
.
size
());
*
rv
=
cell
->
fields
[
idx
];
const
auto
&
adt
=
Downcast
<
ADT
>
(
obj
);
CHECK_LT
(
idx
,
adt
.
size
());
*
rv
=
adt
[
idx
];
});
TVM_REGISTER_GLOBAL
(
"_vmobj.Tensor"
)
...
...
src/runtime/vm/vm.cc
View file @
2bf5fd2b
...
...
@@ -24,6 +24,7 @@
#include <dmlc/memory_io.h>
#include <tvm/logging.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/vm.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
...
...
@@ -755,7 +756,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
size_t
arity
=
0
;
for
(
Index
i
=
0
;
i
<
arg_count
;
i
++
)
{
if
(
const
auto
*
obj
=
args
[
i
].
as
<
ADTObj
>
())
{
arity
+=
obj
->
fields
.
size
()
;
arity
+=
obj
->
size
;
}
else
{
++
arity
;
}
...
...
@@ -767,7 +768,8 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
int
idx
=
0
;
for
(
Index
i
=
0
;
i
<
arg_count
;
i
++
)
{
if
(
const
auto
*
dt_cell
=
args
[
i
].
as
<
ADTObj
>
())
{
for
(
auto
obj
:
dt_cell
->
fields
)
{
for
(
size_t
fi
=
0
;
fi
<
dt_cell
->
size
;
++
fi
)
{
auto
obj
=
(
*
dt_cell
)[
fi
];
const
auto
*
tensor
=
obj
.
as
<
TensorObj
>
();
CHECK
(
tensor
!=
nullptr
);
setter
(
idx
++
,
tensor
->
data
);
...
...
@@ -924,23 +926,16 @@ void VirtualMachine::RunLoop() {
}
case
Opcode
:
:
GetField
:
{
auto
object
=
ReadRegister
(
instr
.
object
);
const
auto
*
tuple
=
object
.
as
<
ADTObj
>
();
CHECK
(
tuple
!=
nullptr
)
<<
"Object is not data type object, register "
<<
instr
.
object
<<
", Object tag "
<<
object
->
type_index
();
auto
field
=
tuple
->
fields
[
instr
.
field_index
];
const
auto
&
tuple
=
Downcast
<
ADT
>
(
object
);
auto
field
=
tuple
[
instr
.
field_index
];
WriteRegister
(
instr
.
dst
,
field
);
pc_
++
;
goto
main_loop
;
}
case
Opcode
:
:
GetTag
:
{
auto
object
=
ReadRegister
(
instr
.
get_tag
.
object
);
const
auto
*
data
=
object
.
as
<
ADTObj
>
();
CHECK
(
data
!=
nullptr
)
<<
"Object is not data type object, register "
<<
instr
.
get_tag
.
object
<<
", Object tag "
<<
object
->
type_index
();
auto
tag
=
data
->
tag
;
const
auto
&
adt
=
Downcast
<
ADT
>
(
object
);
auto
tag
=
adt
.
tag
();
auto
tag_tensor
=
NDArray
::
Empty
({
1
},
{
kDLInt
,
32
,
1
},
{
kDLCPU
,
0
});
reinterpret_cast
<
int32_t
*>
(
tag_tensor
->
data
)[
0
]
=
tag
;
WriteRegister
(
instr
.
dst
,
Tensor
(
tag_tensor
));
...
...
tests/cpp/container_test.cc
View file @
2bf5fd2b
...
...
@@ -17,11 +17,132 @@
* under the License.
*/
#include <vector>
#include <unordered_map>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/container.h>
#include <new>
#include <unordered_map>
#include <vector>
using
namespace
tvm
;
using
namespace
tvm
::
runtime
;
class
TestErrorSwitch
{
public
:
// Need this so that destructor of temporary objects don't interrupt our
// testing.
TestErrorSwitch
(
const
TestErrorSwitch
&
other
)
:
should_fail
(
other
.
should_fail
)
{
const_cast
<
TestErrorSwitch
&>
(
other
).
should_fail
=
false
;
}
TestErrorSwitch
(
bool
fail_flag
)
:
should_fail
{
fail_flag
}
{}
bool
should_fail
{
false
};
~
TestErrorSwitch
()
{
if
(
should_fail
)
{
exit
(
1
);
}
}
};
class
TestArrayObj
:
public
Object
,
public
InplaceArrayBase
<
TestArrayObj
,
TestErrorSwitch
>
{
public
:
static
constexpr
const
uint32_t
_type_index
=
TypeIndex
::
kDynamic
;
static
constexpr
const
char
*
_type_key
=
"test.TestArrayObj"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TestArrayObj
,
Object
);
uint32_t
size
;
size_t
GetSize
()
const
{
return
size
;
}
template
<
typename
Iterator
>
void
Init
(
Iterator
begin
,
Iterator
end
)
{
size_t
num_elems
=
std
::
distance
(
begin
,
end
);
this
->
size
=
0
;
auto
it
=
begin
;
for
(
size_t
i
=
0
;
i
<
num_elems
;
++
i
)
{
InplaceArrayBase
::
EmplaceInit
(
i
,
*
it
++
);
if
(
i
==
1
)
{
throw
std
::
bad_alloc
();
}
// Only increment size after the initialization succeeds
this
->
size
++
;
}
}
template
<
typename
Iterator
>
void
WrongInit
(
Iterator
begin
,
Iterator
end
)
{
size_t
num_elems
=
std
::
distance
(
begin
,
end
);
this
->
size
=
num_elems
;
auto
it
=
begin
;
for
(
size_t
i
=
0
;
i
<
num_elems
;
++
i
)
{
InplaceArrayBase
::
EmplaceInit
(
i
,
*
it
++
);
if
(
i
==
1
)
{
throw
std
::
bad_alloc
();
}
}
}
friend
class
InplaceArrayBase
;
};
TEST
(
ADT
,
Constructor
)
{
std
::
vector
<
ObjectRef
>
fields
;
auto
f1
=
ADT
::
Tuple
(
fields
);
auto
f2
=
ADT
::
Tuple
(
fields
);
ADT
v1
{
1
,
{
f1
,
f2
}};
ASSERT_EQ
(
f1
.
tag
(),
0
);
ASSERT_EQ
(
f2
.
size
(),
0
);
ASSERT_EQ
(
v1
.
tag
(),
1
);
ASSERT_EQ
(
v1
.
size
(),
2
);
ASSERT_EQ
(
Downcast
<
ADT
>
(
v1
[
0
]).
tag
(),
0
);
ASSERT_EQ
(
Downcast
<
ADT
>
(
v1
[
1
]).
size
(),
0
);
}
TEST
(
InplaceArrayBase
,
BadExceptionSafety
)
{
auto
wrong_init
=
[]()
{
TestErrorSwitch
f1
{
false
};
// WrongInit will set size to 3 so it will call destructor at index 1, which
// will exit with error status.
TestErrorSwitch
f2
{
true
};
TestErrorSwitch
f3
{
false
};
std
::
vector
<
TestErrorSwitch
>
fields
{
f1
,
f2
,
f3
};
auto
ptr
=
make_inplace_array_object
<
TestArrayObj
,
TestErrorSwitch
>
(
fields
.
size
());
try
{
ptr
->
WrongInit
(
fields
.
begin
(),
fields
.
end
());
}
catch
(...)
{
}
// Call ~InplaceArrayBase
ptr
.
reset
();
// never reaches here.
exit
(
0
);
};
ASSERT_EXIT
(
wrong_init
(),
::
testing
::
ExitedWithCode
(
1
),
""
);
}
TEST
(
InplaceArrayBase
,
ExceptionSafety
)
{
auto
correct_init
=
[]()
{
TestErrorSwitch
f1
{
false
};
// Init will fail at index 1, so destrucotr at index 1 should not be called
// since it's not initalized.
TestErrorSwitch
f2
{
true
};
std
::
vector
<
TestErrorSwitch
>
fields
{
f1
,
f2
};
auto
ptr
=
make_inplace_array_object
<
TestArrayObj
,
TestErrorSwitch
>
(
fields
.
size
());
try
{
ptr
->
Init
(
fields
.
begin
(),
fields
.
end
());
}
catch
(...)
{
}
// Call ~InplaceArrayBase
ptr
.
reset
();
// Skip the destructors of f1, f2, and fields
exit
(
0
);
};
ASSERT_EXIT
(
correct_init
(),
::
testing
::
ExitedWithCode
(
0
),
""
);
}
TEST
(
Array
,
Expr
)
{
using
namespace
tvm
;
...
...
@@ -99,11 +220,12 @@ TEST(Map, Iterator) {
using
namespace
tvm
;
Expr
a
=
1
,
b
=
2
;
Map
<
Expr
,
Expr
>
map1
{{
a
,
b
}};
std
::
unordered_map
<
Expr
,
Expr
,
NodeHash
,
NodeEqual
>
map2
(
map1
.
begin
(),
map1
.
end
());
std
::
unordered_map
<
Expr
,
Expr
,
NodeHash
,
NodeEqual
>
map2
(
map1
.
begin
(),
map1
.
end
());
CHECK
(
map2
[
a
].
as
<
IntImm
>
()
->
value
==
2
);
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment