Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
9595a9c1
Commit
9595a9c1
authored
Oct 19, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Expose array to python
parent
de2be97e
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
82 additions
and
1 deletions
+82
-1
python/tvm/cpp/__init__.py
+1
-0
python/tvm/cpp/domain.py
+11
-0
python/tvm/cpp/expr.py
+21
-0
python/tvm/cpp/function.py
+3
-0
src/c_api/c_api_function.cc
+35
-0
src/c_api/c_api_registry.h
+1
-1
tests/python/test_cpp.py
+10
-0
No files found.
python/tvm/cpp/__init__.py
View file @
9595a9c1
...
...
@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from
.function
import
*
from
._ctypes._api
import
register_node
from
.
import
expr
from
.
import
domain
python/tvm/cpp/domain.py
View file @
9595a9c1
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
_function_internal
@register_node
(
"RangeNode"
)
class
Range
(
NodeBase
):
pass
@register_node
(
"ArrayNode"
)
class
Array
(
NodeBase
):
def
__getitem__
(
self
,
i
):
return
_function_internal
.
_ArrayGetItem
(
self
,
i
)
def
__len__
(
self
):
return
_function_internal
.
_ArraySize
(
self
)
python/tvm/cpp/expr.py
View file @
9595a9c1
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.function
import
binary_op
...
...
@@ -40,6 +41,26 @@ class Expr(NodeBase):
class
Var
(
Expr
):
pass
@register_node
(
"IntNode"
)
class
IntExpr
(
Expr
):
pass
@register_node
(
"FloatNode"
)
class
FloatExpr
(
Expr
):
pass
@register_node
(
"UnaryOpNode"
)
class
UnaryOpExpr
(
Expr
):
pass
@register_node
(
"BinaryOpNode"
)
class
BinaryOpExpr
(
Expr
):
pass
@register_node
(
"ReduceNode"
)
class
ReduceExpr
(
Expr
):
pass
@register_node
(
"TensorReadNode"
)
class
TensorReadExpr
(
Expr
):
pass
python/tvm/cpp/function.py
View file @
9595a9c1
...
...
@@ -24,6 +24,9 @@ def _symbol(value):
"""Convert a value to expression."""
if
isinstance
(
value
,
_Number
):
return
constant
(
value
)
elif
isinstance
(
value
,
list
):
value
=
[
_symbol
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
else
:
return
value
...
...
src/c_api/c_api_function.cc
View file @
9595a9c1
...
...
@@ -61,6 +61,41 @@ TVM_REGISTER_API(Range)
.
add_argument
(
"begin"
,
"Expr"
,
"beginning of the range."
)
.
add_argument
(
"end"
,
"Expr"
,
"end of the range"
);
TVM_REGISTER_API
(
_Array
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
data
;
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
CHECK
(
args
.
at
(
i
).
type_id
==
kNodeHandle
);
data
.
push_back
(
args
.
at
(
i
).
sptr
);
}
auto
node
=
std
::
make_shared
<
ArrayNode
>
();
node
->
data
=
std
::
move
(
data
);
ret
->
type_id
=
kNodeHandle
;
ret
->
sptr
=
node
;
});
TVM_REGISTER_API
(
_ArrayGetItem
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
int64_t
i
=
args
.
at
(
1
);
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
auto
*
n
=
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
());
CHECK_LT
(
static_cast
<
size_t
>
(
i
),
n
->
data
.
size
())
<<
"out of bound of array"
;
ret
->
sptr
=
n
->
data
[
i
];
ret
->
type_id
=
kNodeHandle
;
});
TVM_REGISTER_API
(
_ArraySize
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
CHECK
(
sptr
->
is_type
<
ArrayNode
>
());
*
ret
=
static_cast
<
int64_t
>
(
static_cast
<
const
ArrayNode
*>
(
sptr
.
get
())
->
data
.
size
());
});
TVM_REGISTER_API
(
_TensorInput
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Tensor
(
...
...
src/c_api/c_api_registry.h
View file @
9595a9c1
...
...
@@ -57,7 +57,7 @@ struct APIVariantValue {
return
*
this
;
}
template
<
typename
T
,
typename
=
typename
std
::
enable_if
<
std
::
is_base_of
<
NodeRef
,
T
>::
value
>::
type
>
typename
=
typename
std
::
enable_if
<
std
::
is_base_of
<
NodeRef
,
T
>::
value
>::
type
>
inline
operator
T
()
const
{
if
(
type_id
==
kNull
)
return
T
();
CHECK_EQ
(
type_id
,
kNodeHandle
);
...
...
tests/python/test_cpp.py
View file @
9595a9c1
...
...
@@ -9,5 +9,15 @@ def test_basic():
assert
c
.
dtype
==
tvm
.
int32
assert
tvm
.
format_str
(
c
)
==
'(
%
s +
%
s)'
%
(
a
.
name
,
b
.
name
)
def
test_array
():
a
=
tvm
.
Var
(
'a'
)
x
=
tvm
.
function
.
_symbol
([
1
,
2
,
a
])
print
type
(
x
)
print
len
(
x
)
print
x
[
4
]
if
__name__
==
"__main__"
:
test_basic
()
test_array
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment