Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3a48b323
Commit
3a48b323
authored
Dec 01, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable bracket syntax sugar to get tensor element
parent
5445a936
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
108 additions
and
23 deletions
+108
-23
include/tvm/tensor.h
+77
-0
python/tvm/_ctypes/_api.py
+6
-2
python/tvm/collections.py
+1
-1
python/tvm/expr.py
+2
-2
python/tvm/function.py
+1
-12
python/tvm/tensor.py
+14
-2
tests/cpp/tensor_test.cc
+5
-2
tests/python/test_tensor.py
+2
-2
No files found.
include/tvm/tensor.h
View file @
3a48b323
...
...
@@ -65,6 +65,46 @@ class Tensor : public FunctionRef {
* \return the result expression representing tensor read.
*/
Expr
operator
()(
Array
<
Expr
>
indices
)
const
;
/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
*/
class
Slice
{
public
:
// construct via tensor and indices
Slice
(
const
Tensor
&
tensor
,
std
::
vector
<
Expr
>
indices
)
:
tensor_
(
tensor
),
indices_
(
indices
)
{}
/*!
* \brief get i-th slice from the current slice.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
inline
Slice
operator
[](
Expr
i
)
{
std
::
vector
<
Expr
>
other
=
indices_
;
other
.
emplace_back
(
i
);
return
Slice
(
tensor_
,
other
);
}
/*!
* \brief Convert slice to expression.
* This is only valid when all the coordinates are fully specified.
* \return the corresponding expression of this slice.
*/
inline
operator
Expr
()
const
{
return
tensor_
(
indices_
);
}
private
:
const
Tensor
&
tensor_
;
std
::
vector
<
Expr
>
indices_
;
};
/*!
* \brief get i-th slice from the current Tensor.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
inline
Slice
operator
[](
Expr
i
)
const
{
return
Slice
(
*
this
,
{
i
});
}
/*! \brief specify container node */
using
ContainerType
=
TensorNode
;
};
...
...
@@ -163,5 +203,42 @@ inline size_t Tensor::ndim() const {
return
(
*
this
)
->
shape
.
size
();
}
// macro to turn every operation of slice to expression
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
inline Expr operator Op (const Tensor::Slice& a) { \
return Op a.operator Expr() ; \
}
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
template<typename T> \
inline Expr operator Op (const Tensor::Slice& a, const T& b) { \
return a.operator Expr() Op b; \
} \
template<typename T> \
inline Expr operator Op (const T& a, const Tensor::Slice& b) { \
return a Op b.operator Expr(); \
} \
inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
return a.operator Expr() Op b.operator Expr(); \
}
DEFINE_OVERLOAD_SLICE_UNARY_OP
(
!
);
DEFINE_OVERLOAD_SLICE_UNARY_OP
(
-
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
+
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
-
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
*
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
/
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
%
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
==
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
<=
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
>=
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
!=
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
&&
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
||
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
>>
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
<<
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
>
);
// NOLINT(*)
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
<
);
// NOLINT(*)
}
// namespace tvm
#endif // TVM_TENSOR_H_
python/tvm/_ctypes/_api.py
View file @
3a48b323
...
...
@@ -13,7 +13,6 @@ from .._base import FunctionHandle, NodeHandle
from
.._base
import
check_call
,
ctypes2docstring
from
..
import
_function_internal
class
ArgVariant
(
ctypes
.
Union
):
_fields_
=
[(
"v_long"
,
ctypes
.
c_long
),
(
"v_double"
,
ctypes
.
c_double
),
...
...
@@ -46,6 +45,9 @@ RET_SWITCH = {
kNodeHandle
:
lambda
x
:
NODE_TYPE
.
get
(
_type_key
(
x
),
NodeBase
)(
x
.
v_handle
)
}
class
SliceBase
(
object
):
"""base class of slice object"""
pass
class
NodeBase
(
object
):
"""Symbol is symbolic graph."""
...
...
@@ -113,6 +115,8 @@ def convert(value):
elif
isinstance
(
value
,
(
list
,
tuple
)):
value
=
[
convert
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
elif
isinstance
(
value
,
SliceBase
):
return
value
.
tensor
(
*
value
.
indices
)
else
:
if
not
isinstance
(
value
,
NodeBase
):
raise
ValueError
(
"don't know how to handle type
%
s"
%
type
(
value
))
...
...
@@ -176,7 +180,7 @@ def _make_function(handle, name):
"""TVM function"""
cargs
=
[]
for
x
in
args
:
if
isinstance
(
x
,
(
list
,
tuple
)):
if
isinstance
(
x
,
(
list
,
tuple
,
SliceBase
)):
cargs
.
append
(
convert
(
x
))
else
:
cargs
.
append
(
x
)
...
...
python/tvm/collections.py
View file @
3a48b323
...
...
@@ -24,5 +24,5 @@ class Range(NodeBase):
@register_node
class
IterVar
(
_expr
.
ExprCompatible
):
class
IterVar
(
NodeBase
,
_expr
.
ExprOp
):
pass
python/tvm/expr.py
View file @
3a48b323
...
...
@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
make
as
_make
class
Expr
Compatible
(
NodeBase
):
class
Expr
Op
(
object
):
def
__add__
(
self
,
other
):
return
_make
.
Add
(
self
,
other
)
...
...
@@ -37,7 +37,7 @@ class ExprCompatible(NodeBase):
return
self
.
__mul__
(
-
1
)
class
Expr
(
ExprCompatible
):
class
Expr
(
NodeBase
,
ExprOp
):
pass
class
ConstExpr
(
Expr
):
...
...
python/tvm/function.py
View file @
3a48b323
from
__future__
import
absolute_import
as
_abs
from
numbers
import
Number
as
_Number
,
Integral
as
_Integral
from
._ctypes._api
import
_init_function_module
from
._ctypes._api
import
_init_function_module
,
convert
from
.
import
_function_internal
from
.
import
make
as
_make
from
.
import
expr
as
_expr
...
...
@@ -33,17 +33,6 @@ def Var(name="tindex", dtype=int32):
return
_function_internal
.
_Var
(
name
,
dtype
)
def
convert
(
value
):
"""Convert a value to expression."""
if
isinstance
(
value
,
_Number
):
return
const
(
value
)
elif
isinstance
(
value
,
(
list
,
tuple
)):
value
=
[
convert
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
else
:
return
value
def
placeholder
(
shape
,
dtype
=
None
,
name
=
"TensorObj"
):
"""Construct an empty tensor object.
...
...
python/tvm/tensor.py
View file @
3a48b323
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
,
convert
from
._ctypes._api
import
NodeBase
,
SliceBase
,
register_node
,
convert
from
.
import
collections
as
_collections
from
.
import
make
as
_make
from
.
import
expr
as
_expr
class
TensorSlice
(
SliceBase
,
_expr
.
ExprOp
):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def
__init__
(
self
,
tensor
,
indices
):
self
.
tensor
=
tensor
self
.
indices
=
indices
def
__getitem__
(
self
,
indices
):
return
TensorSlice
(
self
.
tensor
,
self
.
indices
+
indices
)
@register_node
class
Tensor
(
NodeBase
):
"""Tensor object, to construct, see function.Tensor"""
...
...
@@ -13,7 +23,6 @@ class Tensor(NodeBase):
raise
ValueError
(
"Need to provide
%
d index in tensor slice"
%
ndim
)
indices
=
convert
(
indices
)
args
=
[]
for
x
in
indices
:
if
isinstance
(
x
,
_collections
.
IterVar
):
args
.
append
(
x
.
var
)
...
...
@@ -24,6 +33,9 @@ class Tensor(NodeBase):
return
_make
.
Call
(
self
.
dtype
,
self
.
name
,
args
,
_expr
.
Call
.
Halide
,
self
,
0
)
def
__getitem__
(
self
,
indices
):
return
TensorSlice
(
self
,
indices
)
@property
def
ndim
(
self
):
return
len
(
self
.
shape
)
tests/cpp/tensor_test.cc
View file @
3a48b323
...
...
@@ -9,8 +9,11 @@ TEST(Tensor, Basic) {
Tensor
B
({
n
,
l
},
"B"
);
auto
C
=
Compute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
*
B
(
j
,
i
)
;
return
A
[
i
][
j
]
;
},
"C"
);
Tensor
::
Slice
x
=
A
[
n
];
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
}
TEST
(
Tensor
,
Reduce
)
{
...
...
@@ -21,7 +24,7 @@ TEST(Tensor, Reduce) {
IterVar
rv
(
Range
{
0
,
l
},
"k"
);
auto
C
=
Compute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
sum
(
max
(
A
(
i
,
rv
)
*
B
(
j
,
rv
),
1
),
{
rv
});
return
sum
(
max
(
1
+
A
[
i
][
rv
]
+
1
,
B
[
j
][
rv
]
),
{
rv
});
},
"C"
);
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
}
...
...
tests/python/test_tensor.py
View file @
3a48b323
...
...
@@ -6,7 +6,7 @@ def test_tensor():
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
)
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
[
i
,
k
]
*
B
[
j
,
k
]
)
print
(
T
)
print
(
T
.
op
.
body
)
assert
(
tuple
(
T
.
shape
)
==
(
m
,
n
,
l
))
...
...
@@ -18,7 +18,7 @@ def test_tensor_reduce():
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
)
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
[
i
,
k
]
*
B
[
j
,
k
]
)
rv
=
tvm
.
IterVar
((
0
,
A
.
shape
[
1
]),
name
=
"k"
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
rv
+
1
),
rdom
=
rv
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment