Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
dd55682d
Commit
dd55682d
authored
May 02, 2019
by
Jared Roesch
Committed by
Tianqi Chen
May 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Runtime] Add support for virtual machine Objects (#3120)
parent
b175319c
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
619 additions
and
6 deletions
+619
-6
3rdparty/HalideIR
+1
-1
include/tvm/runtime/c_runtime_api.h
+2
-1
include/tvm/runtime/object.h
+391
-0
include/tvm/runtime/packed_func.h
+25
-0
python/setup.py
+1
-1
python/tvm/_ffi/_ctypes/function.py
+14
-0
python/tvm/_ffi/_cython/base.pxi
+2
-0
python/tvm/_ffi/_cython/function.pxi
+33
-0
python/tvm/_ffi/function.py
+12
-3
python/tvm/_ffi/runtime_ctypes.py
+2
-0
src/api/dsl_api.cc
+9
-0
src/lang/reflection.cc
+30
-0
src/relay/ir/pretty_printer.cc
+3
-0
src/runtime/vm/object.cc
+94
-0
No files found.
HalideIR
@
a768f2f0
Subproject commit
55ba1778fd264c7507953552d8e51212ed11f748
Subproject commit
a768f2f0627917659a4d7167eee3190469b9d164
include/tvm/runtime/c_runtime_api.h
View file @
dd55682d
...
@@ -112,7 +112,8 @@ typedef enum {
...
@@ -112,7 +112,8 @@ typedef enum {
kNNVMLast
=
20U
,
kNNVMLast
=
20U
,
// The following section of code is used for non-reserved types.
// The following section of code is used for non-reserved types.
kExtReserveEnd
=
64U
,
kExtReserveEnd
=
64U
,
kExtEnd
=
128U
kExtEnd
=
128U
,
kObject
=
14U
,
}
TVMTypeCode
;
}
TVMTypeCode
;
/*!
/*!
...
...
include/tvm/runtime/object.h
0 → 100644
View file @
dd55682d
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file tvm/runtime/object.h
* \brief A managed object in the TVM runtime.
*/
#ifndef TVM_RUNTIME_OBJECT_H_
#define TVM_RUNTIME_OBJECT_H_
#include <tvm/runtime/ndarray.h>
#include <memory>
#include <utility>
#include <vector>
namespace
tvm
{
namespace
runtime
{
template
<
typename
T
>
class
ObjectPtr
;
class
Object
;
enum
struct
ObjectTag
{
/*! \brief The tag of a tensor. */
kTensor
=
0U
,
/*! \brief The tag of a closure. */
kClosure
=
1U
,
/*! \brief The tag of a structure. */
kDatatype
=
2U
,
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ObjectTag
&
);
struct
ObjectCell
{
public
:
/*!
* \brief The type of object deleter.
* \param The self pointer to the ObjectCell.
*/
typedef
void
(
*
FDeleter
)(
ObjectCell
*
self
);
/*! \brief The tag of the object.
*
* Describes which type of value
* is represented by this object.
*/
ObjectTag
tag
;
/*!
* \brief Increment the reference count.
*/
void
IncRef
()
{
ref_counter_
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
}
/*!
* \brief Decrement the reference count.
*/
void
DecRef
()
{
if
(
ref_counter_
.
fetch_sub
(
1
,
std
::
memory_order_release
)
==
1
)
{
std
::
atomic_thread_fence
(
std
::
memory_order_acquire
);
if
(
this
->
deleter_
!=
nullptr
)
{
(
*
this
->
deleter_
)(
this
);
}
}
}
protected
:
// default constructor and copy constructor
ObjectCell
()
{}
explicit
ObjectCell
(
ObjectTag
tag
)
:
tag
(
tag
)
{}
// override the copy and assign constructors to do nothing.
// This is to make sure only contents, but not deleter and ref_counter
// are copied when a child class copies itself.
ObjectCell
(
const
ObjectCell
&
other
)
{
// NOLINT(*)
}
ObjectCell
(
ObjectCell
&&
other
)
{
// NOLINT(*)
}
ObjectCell
&
operator
=
(
const
ObjectCell
&
other
)
{
// NOLINT(*)
return
*
this
;
}
ObjectCell
&
operator
=
(
ObjectCell
&&
other
)
{
// NOLINT(*)
return
*
this
;
}
private
:
/*! \brief Internal reference counter */
std
::
atomic
<
int
>
ref_counter_
{
0
};
/*!
* \brief deleter of this object to enable customized allocation.
* If the deleter is nullptr, no deletion will be performed.
* The creator of the Node must always set the deleter field properly.
*/
FDeleter
deleter_
=
nullptr
;
int
use_count
()
const
{
return
ref_counter_
.
load
(
std
::
memory_order_relaxed
);
}
// friend declaration
template
<
typename
>
friend
class
ObjectPtr
;
template
<
typename
Y
,
typename
...
Args
>
friend
ObjectPtr
<
Y
>
MakeObject
(
Args
&&
...);
};
/*!
* \brief A custom smart pointer for Object.
* must be subclass of NodeBase
* \tparam T the content data type.
*/
template
<
typename
T
>
class
ObjectPtr
{
public
:
/*! \brief default constructor */
ObjectPtr
()
{}
/*! \brief default constructor */
ObjectPtr
(
std
::
nullptr_t
)
{}
// NOLINT(*)
/*!
* \brief copy constructor
* \param other The value to be moved
*/
ObjectPtr
(
const
ObjectPtr
<
T
>&
other
)
// NOLINT(*)
:
ObjectPtr
(
other
.
data_
)
{}
/*!
* \brief copy constructor
* \param other The value to be moved
*/
template
<
typename
U
>
ObjectPtr
(
const
ObjectPtr
<
U
>&
other
)
// NOLINT(*)
:
ObjectPtr
(
other
.
data_
)
{
static_assert
(
std
::
is_base_of
<
T
,
U
>::
value
,
"can only assign of child class ObjectPtr to parent"
);
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
ObjectPtr
(
ObjectPtr
<
T
>&&
other
)
// NOLINT(*)
:
data_
(
other
.
data_
)
{
other
.
data_
=
nullptr
;
}
/*!
* \brief move constructor
* \param other The value to be moved
*/
template
<
typename
Y
>
ObjectPtr
(
ObjectPtr
<
Y
>&&
other
)
// NOLINT(*)
:
data_
(
other
.
data_
)
{
static_assert
(
std
::
is_base_of
<
T
,
Y
>::
value
,
"can only assign of child class ObjectPtr to parent"
);
other
.
data_
=
nullptr
;
}
/*! \brief destructor */
~
ObjectPtr
()
{
this
->
reset
();
}
/*!
* \brief Swap this array with another Object
* \param other The other Object
*/
void
swap
(
ObjectPtr
<
T
>&
other
)
{
// NOLINT(*)
std
::
swap
(
data_
,
other
.
data_
);
}
/*!
* \return Get the content of the pointer
*/
T
*
get
()
const
{
return
static_cast
<
T
*>
(
data_
);
}
/*!
* \return The pointer
*/
T
*
operator
->
()
const
{
return
get
();
}
/*!
* \return The reference
*/
T
&
operator
*
()
const
{
// NOLINT(*)
return
*
get
();
}
/*!
* \brief copy assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
ObjectPtr
<
T
>&
operator
=
(
const
ObjectPtr
<
T
>&
other
)
{
// NOLINT(*)
// takes in plane operator to enable copy elison.
// copy-and-swap idiom
ObjectPtr
(
other
).
swap
(
*
this
);
// NOLINT(*)
return
*
this
;
}
/*!
* \brief move assignmemt
* \param other The value to be assigned.
* \return reference to self.
*/
ObjectPtr
<
T
>&
operator
=
(
ObjectPtr
<
T
>&&
other
)
{
// NOLINT(*)
// copy-and-swap idiom
ObjectPtr
(
std
::
move
(
other
)).
swap
(
*
this
);
// NOLINT(*)
return
*
this
;
}
/*! \brief reset the content of ptr to be nullptr */
void
reset
()
{
if
(
data_
!=
nullptr
)
{
data_
->
DecRef
();
data_
=
nullptr
;
}
}
/*! \return The use count of the ptr, for debug purposes */
int
use_count
()
const
{
return
data_
!=
nullptr
?
data_
->
use_count
()
:
0
;
}
/*! \return whether the reference is unique */
bool
unique
()
const
{
return
data_
!=
nullptr
&&
data_
->
use_count
()
==
1
;
}
/*! \return Whether two ObjectPtr do not equal each other */
bool
operator
==
(
const
ObjectPtr
<
T
>&
other
)
const
{
return
data_
==
other
.
data_
;
}
/*! \return Whether two ObjectPtr equals each other */
bool
operator
!=
(
const
ObjectPtr
<
T
>&
other
)
const
{
return
data_
!=
other
.
data_
;
}
/*! \return Whether the pointer is nullptr */
bool
operator
==
(
std
::
nullptr_t
null
)
const
{
return
data_
==
nullptr
;
}
/*! \return Whether the pointer is not nullptr */
bool
operator
!=
(
std
::
nullptr_t
null
)
const
{
return
data_
!=
nullptr
;
}
/* ObjectPtr's support custom allocators.
*
* The below allocator represents the simplest
* possible impl. It can be easily swapped
* for customized executor's, different allocation
* strategies, and so on.
*
* See memory.h for more discussion on NodePtr's
* allocator.
*/
class
StdAllocator
{
public
:
template
<
typename
...
Args
>
static
T
*
New
(
Args
&&
...
args
)
{
return
new
T
(
std
::
forward
<
Args
>
(
args
)...);
}
static
ObjectCell
::
FDeleter
Deleter
()
{
return
Deleter_
;
}
private
:
static
void
Deleter_
(
ObjectCell
*
ptr
)
{
delete
static_cast
<
T
*>
(
ptr
);
}
};
template
<
typename
U
>
ObjectPtr
<
U
>
As
()
const
{
auto
ptr
=
reinterpret_cast
<
U
*>
(
get
());
return
ObjectPtr
<
U
>
(
ptr
);
}
private
:
/*! \brief internal pointer field */
ObjectCell
*
data_
{
nullptr
};
/*!
* \brief constructor from NodeBase
* \param data The node base pointer
*/
// TODO(jroesch): NodePtr design doesn't really work here due to the passing.
public:
explicit
ObjectPtr
(
ObjectCell
*
data
)
:
data_
(
data
)
{
if
(
data
!=
nullptr
)
{
data_
->
IncRef
();
}
}
private
:
template
<
typename
Y
,
typename
...
Args
>
friend
ObjectPtr
<
Y
>
MakeObject
(
Args
&&
...);
template
<
typename
>
friend
class
ObjectPtr
;
friend
class
NDArray
;
friend
class
TVMPODValue_
;
friend
class
TVMArgValue
;
friend
class
TVMRetValue
;
friend
class
RPCWrappedFunc
;
};
struct
TensorCell
;
struct
DatatypeCell
;
struct
ClosureCell
;
/*!
* \brief A managed object in the TVM runtime.
*
* For example a tuple, list, closure, and so on.
*
* Maintains a reference count for the object.
*/
class
Object
{
public
:
ObjectPtr
<
ObjectCell
>
ptr_
;
explicit
Object
(
ObjectPtr
<
ObjectCell
>
ptr
)
:
ptr_
(
ptr
)
{}
explicit
Object
(
ObjectCell
*
ptr
)
:
ptr_
(
ptr
)
{}
Object
()
:
ptr_
()
{}
Object
(
const
Object
&
obj
)
:
ptr_
(
obj
.
ptr_
)
{}
ObjectCell
*
operator
->
()
{
return
this
->
ptr_
.
operator
->
();
}
/*! \brief Construct a tensor object. */
static
Object
Tensor
(
const
NDArray
&
data
);
/*! \brief Construct a datatype object. */
static
Object
Datatype
(
size_t
tag
,
const
std
::
vector
<
Object
>&
fields
);
/*! \brief Construct a tuple object. */
static
Object
Tuple
(
const
std
::
vector
<
Object
>&
fields
);
/*! \brief Construct a closure object. */
static
Object
Closure
(
size_t
func_index
,
const
std
::
vector
<
Object
>&
free_vars
);
ObjectPtr
<
TensorCell
>
AsTensor
()
const
;
ObjectPtr
<
DatatypeCell
>
AsDatatype
()
const
;
ObjectPtr
<
ClosureCell
>
AsClosure
()
const
;
};
/*! \brief An object containing an NDArray. */
struct
TensorCell
:
public
ObjectCell
{
/*! \brief The NDArray. */
NDArray
data
;
explicit
TensorCell
(
const
NDArray
&
data
)
:
ObjectCell
(
ObjectTag
::
kTensor
),
data
(
data
)
{}
};
/*! \brief An object representing a structure or enumeration. */
struct
DatatypeCell
:
public
ObjectCell
{
/*! \brief The tag representing the constructor used. */
size_t
tag
;
/*! \brief The fields of the structure. */
std
::
vector
<
Object
>
fields
;
DatatypeCell
(
size_t
tag
,
const
std
::
vector
<
Object
>&
fields
)
:
ObjectCell
(
ObjectTag
::
kDatatype
),
tag
(
tag
),
fields
(
fields
)
{}
};
/*! \brief An object representing a closure. */
struct
ClosureCell
:
public
ObjectCell
{
/*! \brief The index into the VM function table. */
size_t
func_index
;
/*! \brief The free variables of the closure. */
std
::
vector
<
Object
>
free_vars
;
ClosureCell
(
size_t
func_index
,
const
std
::
vector
<
Object
>&
free_vars
)
:
ObjectCell
(
ObjectTag
::
kClosure
),
func_index
(
func_index
),
free_vars
(
free_vars
)
{}
};
/*! \brief Extract the NDArray from a tensor object. */
NDArray
ToNDArray
(
const
Object
&
obj
);
/*!
* \brief Allocate a node object.
* \param args arguments to the constructor.
* \tparam T the node type.
* \return The NodePtr to the allocated object.
*/
template
<
typename
T
,
typename
...
Args
>
inline
ObjectPtr
<
T
>
MakeObject
(
Args
&&
...
args
)
{
using
Allocator
=
typename
ObjectPtr
<
T
>::
StdAllocator
;
static_assert
(
std
::
is_base_of
<
ObjectCell
,
T
>::
value
,
"MakeObject can only be used to create "
);
T
*
node
=
Allocator
::
New
(
std
::
forward
<
Args
>
(
args
)...);
node
->
deleter_
=
Allocator
::
Deleter
();
return
ObjectPtr
<
T
>
(
node
);
}
}
// namespace runtime
}
// namespace tvm
#endif // TVM_RUNTIME_OBJECT_H_
include/tvm/runtime/packed_func.h
View file @
dd55682d
...
@@ -39,6 +39,7 @@
...
@@ -39,6 +39,7 @@
#include "c_runtime_api.h"
#include "c_runtime_api.h"
#include "module.h"
#include "module.h"
#include "ndarray.h"
#include "ndarray.h"
#include "object.h"
#include "node_base.h"
#include "node_base.h"
namespace
HalideIR
{
namespace
HalideIR
{
...
@@ -48,6 +49,7 @@ struct Type;
...
@@ -48,6 +49,7 @@ struct Type;
struct
Expr
;
struct
Expr
;
}
}
// Whether use TVM runtime in header only mode.
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#define TVM_RUNTIME_HEADER_ONLY 0
...
@@ -470,6 +472,11 @@ class TVMPODValue_ {
...
@@ -470,6 +472,11 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE
(
type_code_
,
kNDArrayContainer
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
kNDArrayContainer
);
return
NDArray
(
static_cast
<
NDArray
::
Container
*>
(
value_
.
v_handle
));
return
NDArray
(
static_cast
<
NDArray
::
Container
*>
(
value_
.
v_handle
));
}
}
operator
Object
()
const
{
if
(
type_code_
==
kNull
)
return
Object
();
TVM_CHECK_TYPE_CODE
(
type_code_
,
kObject
);
return
Object
(
static_cast
<
ObjectCell
*>
(
value_
.
v_handle
));
}
operator
TVMContext
()
const
{
operator
TVMContext
()
const
{
TVM_CHECK_TYPE_CODE
(
type_code_
,
kTVMContext
);
TVM_CHECK_TYPE_CODE
(
type_code_
,
kTVMContext
);
return
value_
.
v_ctx
;
return
value_
.
v_ctx
;
...
@@ -542,6 +549,7 @@ class TVMArgValue : public TVMPODValue_ {
...
@@ -542,6 +549,7 @@ class TVMArgValue : public TVMPODValue_ {
using
TVMPODValue_
::
operator
DLTensor
*
;
using
TVMPODValue_
::
operator
DLTensor
*
;
using
TVMPODValue_
::
operator
NDArray
;
using
TVMPODValue_
::
operator
NDArray
;
using
TVMPODValue_
::
operator
TVMContext
;
using
TVMPODValue_
::
operator
TVMContext
;
using
TVMPODValue_
::
operator
Object
;
// conversion operator.
// conversion operator.
operator
std
::
string
()
const
{
operator
std
::
string
()
const
{
...
@@ -637,6 +645,7 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -637,6 +645,7 @@ class TVMRetValue : public TVMPODValue_ {
using
TVMPODValue_
::
operator
DLTensor
*
;
using
TVMPODValue_
::
operator
DLTensor
*
;
using
TVMPODValue_
::
operator
TVMContext
;
using
TVMPODValue_
::
operator
TVMContext
;
using
TVMPODValue_
::
operator
NDArray
;
using
TVMPODValue_
::
operator
NDArray
;
using
TVMPODValue_
::
operator
Object
;
TVMRetValue
(
const
TVMRetValue
&
other
)
:
TVMPODValue_
()
{
TVMRetValue
(
const
TVMRetValue
&
other
)
:
TVMPODValue_
()
{
this
->
Assign
(
other
);
this
->
Assign
(
other
);
}
}
...
@@ -733,6 +742,13 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -733,6 +742,13 @@ class TVMRetValue : public TVMPODValue_ {
other
.
data_
=
nullptr
;
other
.
data_
=
nullptr
;
return
*
this
;
return
*
this
;
}
}
TVMRetValue
&
operator
=
(
Object
other
)
{
this
->
Clear
();
type_code_
=
kObject
;
value_
.
v_handle
=
other
.
ptr_
.
data_
;
other
.
ptr_
.
data_
=
nullptr
;
return
*
this
;
}
TVMRetValue
&
operator
=
(
PackedFunc
f
)
{
TVMRetValue
&
operator
=
(
PackedFunc
f
)
{
this
->
SwitchToClass
(
kFuncHandle
,
f
);
this
->
SwitchToClass
(
kFuncHandle
,
f
);
return
*
this
;
return
*
this
;
...
@@ -828,6 +844,10 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -828,6 +844,10 @@ class TVMRetValue : public TVMPODValue_ {
kNodeHandle
,
*
other
.
template
ptr
<
NodePtr
<
Node
>
>
());
kNodeHandle
,
*
other
.
template
ptr
<
NodePtr
<
Node
>
>
());
break
;
break
;
}
}
case
kObject
:
{
*
this
=
other
.
operator
Object
();
break
;
}
default
:
{
default
:
{
if
(
other
.
type_code
()
<
kExtBegin
)
{
if
(
other
.
type_code
()
<
kExtBegin
)
{
SwitchToPOD
(
other
.
type_code
());
SwitchToPOD
(
other
.
type_code
());
...
@@ -875,6 +895,10 @@ class TVMRetValue : public TVMPODValue_ {
...
@@ -875,6 +895,10 @@ class TVMRetValue : public TVMPODValue_ {
static_cast
<
NDArray
::
Container
*>
(
value_
.
v_handle
)
->
DecRef
();
static_cast
<
NDArray
::
Container
*>
(
value_
.
v_handle
)
->
DecRef
();
break
;
break
;
}
}
case
kObject
:
{
static_cast
<
ObjectCell
*>
(
value_
.
v_handle
)
->
DecRef
();
break
;
}
}
}
if
(
type_code_
>
kExtBegin
)
{
if
(
type_code_
>
kExtBegin
)
{
#if TVM_RUNTIME_HEADER_ONLY
#if TVM_RUNTIME_HEADER_ONLY
...
@@ -904,6 +928,7 @@ inline const char* TypeCode2Str(int type_code) {
...
@@ -904,6 +928,7 @@ inline const char* TypeCode2Str(int type_code) {
case
kFuncHandle
:
return
"FunctionHandle"
;
case
kFuncHandle
:
return
"FunctionHandle"
;
case
kModuleHandle
:
return
"ModuleHandle"
;
case
kModuleHandle
:
return
"ModuleHandle"
;
case
kNDArrayContainer
:
return
"NDArrayContainer"
;
case
kNDArrayContainer
:
return
"NDArrayContainer"
;
case
kObject
:
return
"Object"
;
default:
LOG
(
FATAL
)
<<
"unknown type_code="
default:
LOG
(
FATAL
)
<<
"unknown type_code="
<<
static_cast
<
int
>
(
type_code
);
return
""
;
<<
static_cast
<
int
>
(
type_code
);
return
""
;
}
}
...
...
python/setup.py
View file @
dd55682d
...
@@ -96,7 +96,7 @@ def config_cython():
...
@@ -96,7 +96,7 @@ def config_cython():
library_dirs
=
library_dirs
,
library_dirs
=
library_dirs
,
libraries
=
libraries
,
libraries
=
libraries
,
language
=
"c++"
))
language
=
"c++"
))
return
cythonize
(
ret
)
return
cythonize
(
ret
,
compiler_directives
=
{
"language_level"
:
3
}
)
except
ImportError
:
except
ImportError
:
print
(
"WARNING: Cython is not installed, will compile without cython module"
)
print
(
"WARNING: Cython is not installed, will compile without cython module"
)
return
[]
return
[]
...
...
python/tvm/_ffi/_ctypes/function.py
View file @
dd55682d
...
@@ -162,6 +162,9 @@ def _make_tvm_args(args, temp_args):
...
@@ -162,6 +162,9 @@ def _make_tvm_args(args, temp_args):
values
[
i
]
.
v_handle
=
arg
.
handle
values
[
i
]
.
v_handle
=
arg
.
handle
type_codes
[
i
]
=
TypeCode
.
FUNC_HANDLE
type_codes
[
i
]
=
TypeCode
.
FUNC_HANDLE
temp_args
.
append
(
arg
)
temp_args
.
append
(
arg
)
elif
isinstance
(
arg
,
ObjectBase
):
values
[
i
]
.
v_handle
=
arg
.
handle
type_codes
[
i
]
=
TypeCode
.
OBJECT
else
:
else
:
raise
TypeError
(
"Don't know how to handle type
%
s"
%
type
(
arg
))
raise
TypeError
(
"Don't know how to handle type
%
s"
%
type
(
arg
))
return
values
,
type_codes
,
num_args
return
values
,
type_codes
,
num_args
...
@@ -240,12 +243,18 @@ def _handle_return_func(x):
...
@@ -240,12 +243,18 @@ def _handle_return_func(x):
handle
=
FunctionHandle
(
handle
)
handle
=
FunctionHandle
(
handle
)
return
_CLASS_FUNCTION
(
handle
,
False
)
return
_CLASS_FUNCTION
(
handle
,
False
)
class
ObjectBase
(
object
):
__slots__
=
[
"handle"
]
def
__init__
(
self
,
handle
):
self
.
handle
=
handle
# setup return handle for function type
# setup return handle for function type
_node
.
__init_by_constructor__
=
__init_handle_by_constructor__
_node
.
__init_by_constructor__
=
__init_handle_by_constructor__
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_handle_return_func
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_return_module
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
,
True
)
RETURN_SWITCH
[
TypeCode
.
NDARRAY_CONTAINER
]
=
lambda
x
:
_make_array
(
x
.
v_handle
,
False
,
True
)
RETURN_SWITCH
[
TypeCode
.
OBJECT
]
=
lambda
x
:
_CLASS_OBJECT
(
x
.
v_handle
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_wrap_arg_func
(
C_TO_PY_ARG_SWITCH
[
TypeCode
.
FUNC_HANDLE
]
=
_wrap_arg_func
(
_handle_return_func
,
TypeCode
.
FUNC_HANDLE
)
_handle_return_func
,
TypeCode
.
FUNC_HANDLE
)
C_TO_PY_ARG_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_wrap_arg_func
(
C_TO_PY_ARG_SWITCH
[
TypeCode
.
MODULE_HANDLE
]
=
_wrap_arg_func
(
...
@@ -255,6 +264,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
...
@@ -255,6 +264,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handl
_CLASS_MODULE
=
None
_CLASS_MODULE
=
None
_CLASS_FUNCTION
=
None
_CLASS_FUNCTION
=
None
_CLASS_OBJECT
=
None
def
_set_class_module
(
module_class
):
def
_set_class_module
(
module_class
):
"""Initialize the module."""
"""Initialize the module."""
...
@@ -264,3 +274,7 @@ def _set_class_module(module_class):
...
@@ -264,3 +274,7 @@ def _set_class_module(module_class):
def
_set_class_function
(
func_class
):
def
_set_class_function
(
func_class
):
global
_CLASS_FUNCTION
global
_CLASS_FUNCTION
_CLASS_FUNCTION
=
func_class
_CLASS_FUNCTION
=
func_class
def
_set_class_object
(
obj_class
):
global
_CLASS_OBJECT
_CLASS_OBJECT
=
obj_class
python/tvm/_ffi/_cython/base.pxi
View file @
dd55682d
...
@@ -37,6 +37,7 @@ cdef enum TVMTypeCode:
...
@@ -37,6 +37,7 @@ cdef enum TVMTypeCode:
kStr = 11
kStr = 11
kBytes = 12
kBytes = 12
kNDArrayContainer = 13
kNDArrayContainer = 13
kObject = 14
kExtBegin = 15
kExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h":
cdef extern from "tvm/runtime/c_runtime_api.h":
...
@@ -76,6 +77,7 @@ ctypedef DLTensor* DLTensorHandle
...
@@ -76,6 +77,7 @@ ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* TVMFunctionHandle
ctypedef void* ObjectHandle
ctypedef void* NodeHandle
ctypedef void* NodeHandle
ctypedef struct TVMNDArrayContainer:
ctypedef struct TVMNDArrayContainer:
...
...
python/tvm/_ffi/_cython/function.pxi
View file @
dd55682d
...
@@ -44,6 +44,7 @@ cdef int tvm_callback(TVMValue* args,
...
@@ -44,6 +44,7 @@ cdef int tvm_callback(TVMValue* args,
if (tcode == kNodeHandle or
if (tcode == kNodeHandle or
tcode == kFuncHandle or
tcode == kFuncHandle or
tcode == kModuleHandle or
tcode == kModuleHandle or
tcode == kObject or
tcode > kExtBegin):
tcode > kExtBegin):
CALL(TVMCbArgToReturn(&value, tcode))
CALL(TVMCbArgToReturn(&value, tcode))
...
@@ -157,6 +158,9 @@ cdef inline int make_arg(object arg,
...
@@ -157,6 +158,9 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, _CLASS_MODULE):
elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle)
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle
tcode[0] = kModuleHandle
elif isinstance(arg, _CLASS_OBJECT):
value[0].v_handle = c_handle(arg.handle)
tcode[0] = kObject
elif isinstance(arg, FunctionBase):
elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle
value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle
tcode[0] = kFuncHandle
...
@@ -208,6 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
...
@@ -208,6 +212,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
fobj = _CLASS_FUNCTION(None, False)
fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle
(<FunctionBase>fobj).chandle = value.v_handle
return fobj
return fobj
elif tcode == kObject:
return _CLASS_OBJECT(ctypes_handle(value.v_handle))
elif tcode in _TVM_EXT_RET:
elif tcode in _TVM_EXT_RET:
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
...
@@ -304,8 +310,31 @@ cdef class FunctionBase:
...
@@ -304,8 +310,31 @@ cdef class FunctionBase:
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
return make_ret(ret_val, ret_tcode)
cdef class ObjectBase:
cdef ObjectHandle chandle
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
self.chandle = c_handle(handle)
property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(<unsigned long long>self.chandle, ctypes.c_void_p)
def __set__(self, value):
self._set_handle(value)
def __init__(self, handle):
self._set_handle(handle)
_CLASS_FUNCTION = None
_CLASS_FUNCTION = None
_CLASS_MODULE = None
_CLASS_MODULE = None
_CLASS_OBJECT = None
def _set_class_module(module_class):
def _set_class_module(module_class):
"""Initialize the module."""
"""Initialize the module."""
...
@@ -315,3 +344,7 @@ def _set_class_module(module_class):
...
@@ -315,3 +344,7 @@ def _set_class_module(module_class):
def _set_class_function(func_class):
def _set_class_function(func_class):
global _CLASS_FUNCTION
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
_CLASS_FUNCTION = func_class
def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
python/tvm/_ffi/function.py
View file @
dd55682d
...
@@ -30,19 +30,28 @@ try:
...
@@ -30,19 +30,28 @@ try:
if
_FFI_MODE
==
"ctypes"
:
if
_FFI_MODE
==
"ctypes"
:
raise
ImportError
()
raise
ImportError
()
if
sys
.
version_info
>=
(
3
,
0
):
if
sys
.
version_info
>=
(
3
,
0
):
from
._cy3.core
import
_set_class_function
,
_set_class_module
from
._cy3.core
import
_set_class_function
,
_set_class_module
,
_set_class_object
from
._cy3.core
import
FunctionBase
as
_FunctionBase
from
._cy3.core
import
FunctionBase
as
_FunctionBase
from
._cy3.core
import
ObjectBase
as
_ObjectBase
from
._cy3.core
import
convert_to_tvm_func
from
._cy3.core
import
convert_to_tvm_func
else
:
else
:
from
._cy2.core
import
_set_class_function
,
_set_class_module
from
._cy2.core
import
_set_class_function
,
_set_class_module
,
_set_class_object
from
._cy2.core
import
FunctionBase
as
_FunctionBase
from
._cy2.core
import
FunctionBase
as
_FunctionBase
from
._cy2.core
import
ObjectBase
as
_ObjectBase
from
._cy2.core
import
convert_to_tvm_func
from
._cy2.core
import
convert_to_tvm_func
except
IMPORT_EXCEPT
:
except
IMPORT_EXCEPT
:
# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position
from
._ctypes.function
import
_set_class_function
,
_set_class_module
from
._ctypes.function
import
_set_class_function
,
_set_class_module
,
_set_class_object
from
._ctypes.function
import
ObjectBase
as
_ObjectBase
from
._ctypes.function
import
FunctionBase
as
_FunctionBase
from
._ctypes.function
import
FunctionBase
as
_FunctionBase
from
._ctypes.function
import
convert_to_tvm_func
from
._ctypes.function
import
convert_to_tvm_func
class
Object
(
_ObjectBase
):
# TODO(@jroesch): Eventually add back introspection functionality.
pass
_set_class_object
(
Object
)
FunctionHandle
=
ctypes
.
c_void_p
FunctionHandle
=
ctypes
.
c_void_p
class
Function
(
_FunctionBase
):
class
Function
(
_FunctionBase
):
...
...
python/tvm/_ffi/runtime_ctypes.py
View file @
dd55682d
...
@@ -42,8 +42,10 @@ class TypeCode(object):
...
@@ -42,8 +42,10 @@ class TypeCode(object):
STR
=
11
STR
=
11
BYTES
=
12
BYTES
=
12
NDARRAY_CONTAINER
=
13
NDARRAY_CONTAINER
=
13
OBJECT
=
14
EXT_BEGIN
=
15
EXT_BEGIN
=
15
class
TVMByteArray
(
ctypes
.
Structure
):
class
TVMByteArray
(
ctypes
.
Structure
):
"""Temp data structure for byte array."""
"""Temp data structure for byte array."""
_fields_
=
[(
"data"
,
ctypes
.
POINTER
(
ctypes
.
c_byte
)),
_fields_
=
[(
"data"
,
ctypes
.
POINTER
(
ctypes
.
c_byte
)),
...
...
src/api/dsl_api.cc
View file @
dd55682d
...
@@ -92,6 +92,12 @@ struct APIAttrGetter : public AttrVisitor {
...
@@ -92,6 +92,12 @@ struct APIAttrGetter : public AttrVisitor {
found_ref_object
=
true
;
found_ref_object
=
true
;
}
}
}
}
void
Visit
(
const
char
*
key
,
runtime
::
Object
*
value
)
final
{
if
(
skey
==
key
)
{
*
ret
=
value
[
0
];
found_ref_object
=
true
;
}
}
};
};
struct
APIAttrDir
:
public
AttrVisitor
{
struct
APIAttrDir
:
public
AttrVisitor
{
...
@@ -127,6 +133,9 @@ struct APIAttrDir : public AttrVisitor {
...
@@ -127,6 +133,9 @@ struct APIAttrDir : public AttrVisitor {
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
names
->
push_back
(
key
);
names
->
push_back
(
key
);
}
}
void
Visit
(
const
char
*
key
,
runtime
::
Object
*
value
)
final
{
names
->
push_back
(
key
);
}
};
};
class
DSLAPIImpl
:
public
DSLAPI
{
class
DSLAPIImpl
:
public
DSLAPI
{
...
...
src/lang/reflection.cc
View file @
dd55682d
...
@@ -53,6 +53,8 @@ inline Type String2Type(std::string s) {
...
@@ -53,6 +53,8 @@ inline Type String2Type(std::string s) {
return
TVMType2Type
(
runtime
::
String2TVMType
(
s
));
return
TVMType2Type
(
runtime
::
String2TVMType
(
s
));
}
}
using
runtime
::
Object
;
using
runtime
::
ObjectCell
;
// indexer to index all the ndoes
// indexer to index all the ndoes
class
NodeIndexer
:
public
AttrVisitor
{
class
NodeIndexer
:
public
AttrVisitor
{
...
@@ -61,6 +63,8 @@ class NodeIndexer : public AttrVisitor {
...
@@ -61,6 +63,8 @@ class NodeIndexer : public AttrVisitor {
std
::
vector
<
Node
*>
node_list
{
nullptr
};
std
::
vector
<
Node
*>
node_list
{
nullptr
};
std
::
unordered_map
<
DLTensor
*
,
size_t
>
tensor_index
;
std
::
unordered_map
<
DLTensor
*
,
size_t
>
tensor_index
;
std
::
vector
<
DLTensor
*>
tensor_list
;
std
::
vector
<
DLTensor
*>
tensor_list
;
std
::
unordered_map
<
ObjectCell
*
,
size_t
>
vm_obj_index
;
std
::
vector
<
ObjectCell
*>
vm_obj_list
;
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{}
...
@@ -73,6 +77,7 @@ class NodeIndexer : public AttrVisitor {
...
@@ -73,6 +77,7 @@ class NodeIndexer : public AttrVisitor {
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
MakeIndex
(
value
->
node_
.
get
());
MakeIndex
(
value
->
node_
.
get
());
}
}
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
DLTensor
*
ptr
=
const_cast
<
DLTensor
*>
((
*
value
).
operator
->
());
DLTensor
*
ptr
=
const_cast
<
DLTensor
*>
((
*
value
).
operator
->
());
if
(
tensor_index
.
count
(
ptr
))
return
;
if
(
tensor_index
.
count
(
ptr
))
return
;
...
@@ -80,6 +85,15 @@ class NodeIndexer : public AttrVisitor {
...
@@ -80,6 +85,15 @@ class NodeIndexer : public AttrVisitor {
tensor_index
[
ptr
]
=
tensor_list
.
size
();
tensor_index
[
ptr
]
=
tensor_list
.
size
();
tensor_list
.
push_back
(
ptr
);
tensor_list
.
push_back
(
ptr
);
}
}
void
Visit
(
const
char
*
key
,
Object
*
value
)
final
{
ObjectCell
*
ptr
=
value
->
ptr_
.
get
();
if
(
vm_obj_index
.
count
(
ptr
))
return
;
CHECK_EQ
(
vm_obj_index
.
size
(),
vm_obj_list
.
size
());
vm_obj_index
[
ptr
]
=
vm_obj_list
.
size
();
vm_obj_list
.
push_back
(
ptr
);
}
// make index of all the children of node
// make index of all the children of node
void
MakeIndex
(
Node
*
node
)
{
void
MakeIndex
(
Node
*
node
)
{
if
(
node
==
nullptr
)
return
;
if
(
node
==
nullptr
)
return
;
...
@@ -163,6 +177,7 @@ class JSONAttrGetter : public AttrVisitor {
...
@@ -163,6 +177,7 @@ class JSONAttrGetter : public AttrVisitor {
public
:
public
:
const
std
::
unordered_map
<
Node
*
,
size_t
>*
node_index_
;
const
std
::
unordered_map
<
Node
*
,
size_t
>*
node_index_
;
const
std
::
unordered_map
<
DLTensor
*
,
size_t
>*
tensor_index_
;
const
std
::
unordered_map
<
DLTensor
*
,
size_t
>*
tensor_index_
;
const
std
::
unordered_map
<
ObjectCell
*
,
size_t
>*
vm_obj_index_
;
JSONNode
*
node_
;
JSONNode
*
node_
;
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
...
@@ -197,6 +212,10 @@ class JSONAttrGetter : public AttrVisitor {
...
@@ -197,6 +212,10 @@ class JSONAttrGetter : public AttrVisitor {
node_
->
attrs
[
key
]
=
std
::
to_string
(
node_
->
attrs
[
key
]
=
std
::
to_string
(
tensor_index_
->
at
(
const_cast
<
DLTensor
*>
((
*
value
).
operator
->
())));
tensor_index_
->
at
(
const_cast
<
DLTensor
*>
((
*
value
).
operator
->
())));
}
}
void
Visit
(
const
char
*
key
,
Object
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
vm_obj_index_
->
at
(
value
->
ptr_
.
get
()));
}
// Get the node
// Get the node
void
Get
(
Node
*
node
)
{
void
Get
(
Node
*
node
)
{
if
(
node
==
nullptr
)
{
if
(
node
==
nullptr
)
{
...
@@ -250,6 +269,8 @@ class JSONAttrSetter : public AttrVisitor {
...
@@ -250,6 +269,8 @@ class JSONAttrSetter : public AttrVisitor {
public
:
public
:
const
std
::
vector
<
NodePtr
<
Node
>
>*
node_list_
;
const
std
::
vector
<
NodePtr
<
Node
>
>*
node_list_
;
const
std
::
vector
<
runtime
::
NDArray
>*
tensor_list_
;
const
std
::
vector
<
runtime
::
NDArray
>*
tensor_list_
;
const
std
::
vector
<
Object
>*
vm_obj_list_
;
JSONNode
*
node_
;
JSONNode
*
node_
;
std
::
string
GetValue
(
const
char
*
key
)
const
{
std
::
string
GetValue
(
const
char
*
key
)
const
{
...
@@ -304,6 +325,12 @@ class JSONAttrSetter : public AttrVisitor {
...
@@ -304,6 +325,12 @@ class JSONAttrSetter : public AttrVisitor {
CHECK_LE
(
index
,
tensor_list_
->
size
());
CHECK_LE
(
index
,
tensor_list_
->
size
());
*
value
=
tensor_list_
->
at
(
index
);
*
value
=
tensor_list_
->
at
(
index
);
}
}
void
Visit
(
const
char
*
key
,
Object
*
value
)
final
{
size_t
index
;
ParseValue
(
key
,
&
index
);
CHECK_LE
(
index
,
vm_obj_list_
->
size
());
*
value
=
vm_obj_list_
->
at
(
index
);
}
// set node to be current JSONNode
// set node to be current JSONNode
void
Set
(
Node
*
node
)
{
void
Set
(
Node
*
node
)
{
if
(
node
==
nullptr
)
return
;
if
(
node
==
nullptr
)
return
;
...
@@ -481,6 +508,9 @@ class NodeAttrSetter : public AttrVisitor {
...
@@ -481,6 +508,9 @@ class NodeAttrSetter : public AttrVisitor {
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
*
value
=
GetAttr
(
key
).
operator
runtime
::
NDArray
();
*
value
=
GetAttr
(
key
).
operator
runtime
::
NDArray
();
}
}
void
Visit
(
const
char
*
key
,
Object
*
value
)
final
{
*
value
=
GetAttr
(
key
).
operator
Object
();
}
private
:
private
:
runtime
::
TVMArgValue
GetAttr
(
const
char
*
key
)
{
runtime
::
TVMArgValue
GetAttr
(
const
char
*
key
)
{
...
...
src/relay/ir/pretty_printer.cc
View file @
dd55682d
...
@@ -775,6 +775,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
...
@@ -775,6 +775,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
void
Visit
(
const
char
*
key
,
runtime
::
NDArray
*
value
)
final
{
LOG
(
FATAL
)
<<
"do not allow NDarray as argument"
;
LOG
(
FATAL
)
<<
"do not allow NDarray as argument"
;
}
}
void
Visit
(
const
char
*
key
,
runtime
::
Object
*
obj
)
final
{
LOG
(
FATAL
)
<<
"do not allow Object as argument"
;
}
private
:
private
:
Doc
&
doc_
;
Doc
&
doc_
;
...
...
src/runtime/vm/object.cc
0 → 100644
View file @
dd55682d
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file object.cc
* \brief A managed object in the TVM runtime.
*/
#include <tvm/logging.h>
#include <tvm/runtime/object.h>
#include <iostream>
namespace
tvm
{
namespace
runtime
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ObjectTag
&
tag
)
{
switch
(
tag
)
{
case
ObjectTag
:
:
kClosure
:
os
<<
"Closure"
;
break
;
case
ObjectTag
:
:
kDatatype
:
os
<<
"Datatype"
;
break
;
case
ObjectTag
:
:
kTensor
:
os
<<
"Tensor"
;
break
;
case
ObjectTag
:
:
kExternalFunc
:
os
<<
"ExternalFunction"
;
break
;
default
:
LOG
(
FATAL
)
<<
"Invalid object tag: found "
<<
static_cast
<
int
>
(
tag
);
}
return
os
;
}
Object
Object
::
Tensor
(
const
NDArray
&
data
)
{
ObjectPtr
<
ObjectCell
>
ptr
=
MakeObject
<
TensorCell
>
(
data
);
return
Object
(
ptr
);
}
Object
Object
::
Datatype
(
size_t
tag
,
const
std
::
vector
<
Object
>&
fields
)
{
ObjectPtr
<
ObjectCell
>
ptr
=
MakeObject
<
DatatypeCell
>
(
tag
,
fields
);
return
Object
(
ptr
);
}
Object
Object
::
Tuple
(
const
std
::
vector
<
Object
>&
fields
)
{
return
Object
::
Datatype
(
0
,
fields
);
}
Object
Object
::
Closure
(
size_t
func_index
,
const
std
::
vector
<
Object
>&
free_vars
)
{
ObjectPtr
<
ObjectCell
>
ptr
=
MakeObject
<
ClosureCell
>
(
func_index
,
free_vars
);
return
Object
(
ptr
);
}
ObjectPtr
<
TensorCell
>
Object
::
AsTensor
()
const
{
CHECK
(
ptr
.
get
());
CHECK
(
ptr
.
get
()
->
tag
==
ObjectTag
::
kTensor
);
return
ptr
.
As
<
TensorCell
>
();
}
ObjectPtr
<
DatatypeCell
>
Object
::
AsDatatype
()
const
{
CHECK
(
ptr
.
get
());
CHECK
(
ptr
.
get
()
->
tag
==
ObjectTag
::
kDatatype
);
return
ptr
.
As
<
DatatypeCell
>
();
}
ObjectPtr
<
ClosureCell
>
Object
::
AsClosure
()
const
{
CHECK
(
ptr
.
get
());
CHECK
(
ptr
.
get
()
->
tag
==
ObjectTag
::
kClosure
);
return
ptr
.
As
<
ClosureCell
>
();
}
NDArray
ToNDArray
(
const
Object
&
obj
)
{
auto
tensor
=
obj
.
AsTensor
();
return
tensor
->
data
;
}
}
// namespace runtime
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment