Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3a48b323
Commit
3a48b323
authored
8 years ago
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable bracket syntax sugar to get tensor element
parent
5445a936
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
108 additions
and
23 deletions
+108
-23
include/tvm/tensor.h
+77
-0
python/tvm/_ctypes/_api.py
+6
-2
python/tvm/collections.py
+1
-1
python/tvm/expr.py
+2
-2
python/tvm/function.py
+1
-12
python/tvm/tensor.py
+14
-2
tests/cpp/tensor_test.cc
+5
-2
tests/python/test_tensor.py
+2
-2
No files found.
include/tvm/tensor.h
View file @
3a48b323
...
...
@@ -65,6 +65,46 @@ class Tensor : public FunctionRef {
* \return the result expression representing tensor read.
*/
Expr
operator
()(
Array
<
Expr
>
indices
)
const
;
/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
*/
class
Slice
{
public
:
// construct via tensor and indices
Slice
(
const
Tensor
&
tensor
,
std
::
vector
<
Expr
>
indices
)
:
tensor_
(
tensor
),
indices_
(
indices
)
{}
/*!
* \brief get i-th slice from the current slice.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
inline
Slice
operator
[](
Expr
i
)
{
std
::
vector
<
Expr
>
other
=
indices_
;
other
.
emplace_back
(
i
);
return
Slice
(
tensor_
,
other
);
}
/*!
* \brief Convert slice to expression.
* This is only valid when all the coordinates are fully specified.
* \return the corresponding expression of this slice.
*/
inline
operator
Expr
()
const
{
return
tensor_
(
indices_
);
}
private
:
const
Tensor
&
tensor_
;
std
::
vector
<
Expr
>
indices_
;
};
/*!
* \brief get i-th slice from the current Tensor.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
inline
Slice
operator
[](
Expr
i
)
const
{
return
Slice
(
*
this
,
{
i
});
}
/*! \brief specify container node */
using
ContainerType
=
TensorNode
;
};
...
...
@@ -163,5 +203,42 @@ inline size_t Tensor::ndim() const {
return
(
*
this
)
->
shape
.
size
();
}
// macro to turn every operation of slice to expression
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
inline Expr operator Op (const Tensor::Slice& a) { \
return Op a.operator Expr() ; \
}
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
template<typename T> \
inline Expr operator Op (const Tensor::Slice& a, const T& b) { \
return a.operator Expr() Op b; \
} \
template<typename T> \
inline Expr operator Op (const T& a, const Tensor::Slice& b) { \
return a Op b.operator Expr(); \
} \
inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
return a.operator Expr() Op b.operator Expr(); \
}
DEFINE_OVERLOAD_SLICE_UNARY_OP
(
!
);
DEFINE_OVERLOAD_SLICE_UNARY_OP
(
-
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
+
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
-
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
*
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
/
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
%
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
==
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
<=
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
>=
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
!=
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
&&
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
||
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
>>
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
<<
);
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
>
);
// NOLINT(*)
DEFINE_OVERLOAD_SLICE_BINARY_OP
(
<
);
// NOLINT(*)
}
// namespace tvm
#endif // TVM_TENSOR_H_
This diff is collapsed.
Click to expand it.
python/tvm/_ctypes/_api.py
View file @
3a48b323
...
...
@@ -13,7 +13,6 @@ from .._base import FunctionHandle, NodeHandle
from
.._base
import
check_call
,
ctypes2docstring
from
..
import
_function_internal
class
ArgVariant
(
ctypes
.
Union
):
_fields_
=
[(
"v_long"
,
ctypes
.
c_long
),
(
"v_double"
,
ctypes
.
c_double
),
...
...
@@ -46,6 +45,9 @@ RET_SWITCH = {
kNodeHandle
:
lambda
x
:
NODE_TYPE
.
get
(
_type_key
(
x
),
NodeBase
)(
x
.
v_handle
)
}
class
SliceBase
(
object
):
"""base class of slice object"""
pass
class
NodeBase
(
object
):
"""Symbol is symbolic graph."""
...
...
@@ -113,6 +115,8 @@ def convert(value):
elif
isinstance
(
value
,
(
list
,
tuple
)):
value
=
[
convert
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
elif
isinstance
(
value
,
SliceBase
):
return
value
.
tensor
(
*
value
.
indices
)
else
:
if
not
isinstance
(
value
,
NodeBase
):
raise
ValueError
(
"don't know how to handle type
%
s"
%
type
(
value
))
...
...
@@ -176,7 +180,7 @@ def _make_function(handle, name):
"""TVM function"""
cargs
=
[]
for
x
in
args
:
if
isinstance
(
x
,
(
list
,
tuple
)):
if
isinstance
(
x
,
(
list
,
tuple
,
SliceBase
)):
cargs
.
append
(
convert
(
x
))
else
:
cargs
.
append
(
x
)
...
...
This diff is collapsed.
Click to expand it.
python/tvm/collections.py
View file @
3a48b323
...
...
@@ -24,5 +24,5 @@ class Range(NodeBase):
@register_node
class
IterVar
(
_expr
.
ExprCompatible
):
class
IterVar
(
NodeBase
,
_expr
.
ExprOp
):
pass
This diff is collapsed.
Click to expand it.
python/tvm/expr.py
View file @
3a48b323
...
...
@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
make
as
_make
class
Expr
Compatible
(
NodeBase
):
class
Expr
Op
(
object
):
def
__add__
(
self
,
other
):
return
_make
.
Add
(
self
,
other
)
...
...
@@ -37,7 +37,7 @@ class ExprCompatible(NodeBase):
return
self
.
__mul__
(
-
1
)
class
Expr
(
ExprCompatible
):
class
Expr
(
NodeBase
,
ExprOp
):
pass
class
ConstExpr
(
Expr
):
...
...
This diff is collapsed.
Click to expand it.
python/tvm/function.py
View file @
3a48b323
from
__future__
import
absolute_import
as
_abs
from
numbers
import
Number
as
_Number
,
Integral
as
_Integral
from
._ctypes._api
import
_init_function_module
from
._ctypes._api
import
_init_function_module
,
convert
from
.
import
_function_internal
from
.
import
make
as
_make
from
.
import
expr
as
_expr
...
...
@@ -33,17 +33,6 @@ def Var(name="tindex", dtype=int32):
return
_function_internal
.
_Var
(
name
,
dtype
)
def
convert
(
value
):
"""Convert a value to expression."""
if
isinstance
(
value
,
_Number
):
return
const
(
value
)
elif
isinstance
(
value
,
(
list
,
tuple
)):
value
=
[
convert
(
x
)
for
x
in
value
]
return
_function_internal
.
_Array
(
*
value
)
else
:
return
value
def
placeholder
(
shape
,
dtype
=
None
,
name
=
"TensorObj"
):
"""Construct an empty tensor object.
...
...
This diff is collapsed.
Click to expand it.
python/tvm/tensor.py
View file @
3a48b323
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
,
convert
from
._ctypes._api
import
NodeBase
,
SliceBase
,
register_node
,
convert
from
.
import
collections
as
_collections
from
.
import
make
as
_make
from
.
import
expr
as
_expr
class
TensorSlice
(
SliceBase
,
_expr
.
ExprOp
):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def
__init__
(
self
,
tensor
,
indices
):
self
.
tensor
=
tensor
self
.
indices
=
indices
def
__getitem__
(
self
,
indices
):
return
TensorSlice
(
self
.
tensor
,
self
.
indices
+
indices
)
@register_node
class
Tensor
(
NodeBase
):
"""Tensor object, to construct, see function.Tensor"""
...
...
@@ -13,7 +23,6 @@ class Tensor(NodeBase):
raise
ValueError
(
"Need to provide
%
d index in tensor slice"
%
ndim
)
indices
=
convert
(
indices
)
args
=
[]
for
x
in
indices
:
if
isinstance
(
x
,
_collections
.
IterVar
):
args
.
append
(
x
.
var
)
...
...
@@ -24,6 +33,9 @@ class Tensor(NodeBase):
return
_make
.
Call
(
self
.
dtype
,
self
.
name
,
args
,
_expr
.
Call
.
Halide
,
self
,
0
)
def
__getitem__
(
self
,
indices
):
return
TensorSlice
(
self
,
indices
)
@property
def
ndim
(
self
):
return
len
(
self
.
shape
)
This diff is collapsed.
Click to expand it.
tests/cpp/tensor_test.cc
View file @
3a48b323
...
...
@@ -9,8 +9,11 @@ TEST(Tensor, Basic) {
Tensor
B
({
n
,
l
},
"B"
);
auto
C
=
Compute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
*
B
(
j
,
i
)
;
return
A
[
i
][
j
]
;
},
"C"
);
Tensor
::
Slice
x
=
A
[
n
];
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
}
TEST
(
Tensor
,
Reduce
)
{
...
...
@@ -21,7 +24,7 @@ TEST(Tensor, Reduce) {
IterVar
rv
(
Range
{
0
,
l
},
"k"
);
auto
C
=
Compute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
sum
(
max
(
A
(
i
,
rv
)
*
B
(
j
,
rv
),
1
),
{
rv
});
return
sum
(
max
(
1
+
A
[
i
][
rv
]
+
1
,
B
[
j
][
rv
]
),
{
rv
});
},
"C"
);
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
}
...
...
This diff is collapsed.
Click to expand it.
tests/python/test_tensor.py
View file @
3a48b323
...
...
@@ -6,7 +6,7 @@ def test_tensor():
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
)
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
[
i
,
k
]
*
B
[
j
,
k
]
)
print
(
T
)
print
(
T
.
op
.
body
)
assert
(
tuple
(
T
.
shape
)
==
(
m
,
n
,
l
))
...
...
@@ -18,7 +18,7 @@ def test_tensor_reduce():
l
=
tvm
.
Var
(
'l'
)
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
)
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
[
i
,
k
]
*
B
[
j
,
k
]
)
rv
=
tvm
.
IterVar
((
0
,
A
.
shape
[
1
]),
name
=
"k"
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
rv
+
1
),
rdom
=
rv
))
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment