Commit 3a48b323 by tqchen

Enable bracket syntax sugar to get tensor element

parent 5445a936
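In short: C++ code can now read a tensor element with chained square brackets, `A[i][j]`, instead of `A(i, j)`, and the Python frontend gains the matching `A[i, k]` form through a new TensorSlice class. A minimal usage sketch (hedged: it assumes the `Var`/`Tensor` constructors used by the tests in this commit; the snippet itself is illustrative and not part of the diff):

    #include <tvm/tensor.h>

    using namespace tvm;

    void BracketSugarDemo() {
      Var m("m"), n("n"), i("i"), j("j");
      Tensor A({m, n}, "A");
      Expr before = A(i, j);   // existing call-style element read
      Expr after  = A[i][j];   // new bracket-style read, builds the same expression
    }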
@@ -65,6 +65,46 @@ class Tensor : public FunctionRef {
    * \return the result expression representing tensor read.
    */
   Expr operator()(Array<Expr> indices) const;
+  /*!
+   * \brief Data structure representing a slice that fixes the first k coordinates.
+   *  This is used to enable the syntax sugar Tensor[x][y][z] for getting an element.
+   */
+  class Slice {
+   public:
+    // construct via tensor and indices
+    Slice(const Tensor& tensor, std::vector<Expr> indices)
+        : tensor_(tensor), indices_(indices) {}
+    /*!
+     * \brief Get the i-th slice from the current slice.
+     * \param i the index of the coordinate
+     * \return the subsequent slice.
+     */
+    inline Slice operator[](Expr i) {
+      std::vector<Expr> other = indices_;
+      other.emplace_back(i);
+      return Slice(tensor_, other);
+    }
+    /*!
+     * \brief Convert the slice to an expression.
+     *  This is only valid when all the coordinates are fully specified.
+     * \return the corresponding expression of this slice.
+     */
+    inline operator Expr() const {
+      return tensor_(indices_);
+    }
+
+   private:
+    const Tensor& tensor_;
+    std::vector<Expr> indices_;
+  };
+  /*!
+   * \brief Get the i-th slice from the current Tensor.
+   * \param i the index of the coordinate
+   * \return the subsequent slice.
+   */
+  inline Slice operator[](Expr i) const {
+    return Slice(*this, {i});
+  }
   /*! \brief specify container node */
   using ContainerType = TensorNode;
 };
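To see how the sugar works: each `operator[]` appends one coordinate to a `Tensor::Slice`, and the implicit `operator Expr()` fires only once the slice is used where an `Expr` is expected. Hand-desugared (a sketch; `A`, `x`, `y` are illustrative names, not from the diff):

    Tensor::Slice s1 = A[x];    // Slice(A, {x})      -- first coordinate fixed
    Tensor::Slice s2 = s1[y];   // Slice(A, {x, y})   -- second coordinate fixed
    Expr e = s2;                // operator Expr()    -- emits the read A({x, y})

The conversion is only valid once all coordinates are specified; a partially indexed slice such as `s1` can be stored and indexed further, which is exactly what the updated C++ test below does with `Tensor::Slice x = A[n];`.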
@@ -163,5 +203,42 @@ inline size_t Tensor::ndim() const {
   return (*this)->shape.size();
 }
 
+// macro to turn every operation on a Slice into an Expr
+#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op)                          \
+  inline Expr operator Op (const Tensor::Slice& a) {                \
+    return Op a.operator Expr();                                    \
+  }
+
+#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op)                                  \
+  template<typename T>                                                       \
+  inline Expr operator Op (const Tensor::Slice& a, const T& b) {             \
+    return a.operator Expr() Op b;                                           \
+  }                                                                          \
+  template<typename T>                                                       \
+  inline Expr operator Op (const T& a, const Tensor::Slice& b) {             \
+    return a Op b.operator Expr();                                           \
+  }                                                                          \
+  inline Expr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \
+    return a.operator Expr() Op b.operator Expr();                           \
+  }
+
+DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
+DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(/);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(%);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
+DEFINE_OVERLOAD_SLICE_BINARY_OP(>);  // NOLINT(*)
+DEFINE_OVERLOAD_SLICE_BINARY_OP(<);  // NOLINT(*)
+
 }  // namespace tvm
 #endif  // TVM_TENSOR_H_
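These overload macros let a `Slice` participate directly in arithmetic and comparisons without first being named as an `Expr`. The two one-sided templates forward the non-slice operand unchanged, and the dedicated non-template Slice-Slice overload keeps the two symmetric templates from being ambiguous when both operands are slices (template argument deduction never considers the user-defined `operator Expr()`, so that case needs an exact match). A few expressions this enables, in the spirit of the updated tests below (a sketch with illustrative names):

    Expr a = A[i][j] + 1;           // Slice (+) non-slice operand
    Expr b = 1 - A[i][j];           // non-slice operand (-) Slice
    Expr c = A[i][j] * B[j][i];     // Slice (*) Slice, via the non-template overload
    Expr d = -A[i][j];              // unary overload
    Expr e = A[i][j] < B[j][i];     // comparisons are covered too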
@@ -13,7 +13,6 @@ from .._base import FunctionHandle, NodeHandle
 from .._base import check_call, ctypes2docstring
 from .. import _function_internal
 
-
 class ArgVariant(ctypes.Union):
     _fields_ = [("v_long", ctypes.c_long),
                 ("v_double", ctypes.c_double),
@@ -46,6 +45,9 @@ RET_SWITCH = {
     kNodeHandle: lambda x: NODE_TYPE.get(_type_key(x), NodeBase)(x.v_handle)
 }
 
+class SliceBase(object):
+    """Base class of slice objects."""
+    pass
 
 class NodeBase(object):
     """Symbol is symbolic graph."""
@@ -113,6 +115,8 @@ def convert(value):
     elif isinstance(value, (list, tuple)):
         value = [convert(x) for x in value]
         return _function_internal._Array(*value)
+    elif isinstance(value, SliceBase):
+        return value.tensor(*value.indices)
     else:
         if not isinstance(value, NodeBase):
             raise ValueError("don't know how to handle type %s" % type(value))
@@ -176,7 +180,7 @@ def _make_function(handle, name):
         """TVM function"""
         cargs = []
         for x in args:
-            if isinstance(x, (list, tuple)):
+            if isinstance(x, (list, tuple, SliceBase)):
                 cargs.append(convert(x))
             else:
                 cargs.append(x)
@@ -24,5 +24,5 @@ class Range(NodeBase):
 
 
 @register_node
-class IterVar(_expr.ExprCompatible):
+class IterVar(NodeBase, _expr.ExprOp):
     pass
@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, register_node
 from . import make as _make
 
 
-class ExprCompatible(NodeBase):
+class ExprOp(object):
     def __add__(self, other):
         return _make.Add(self, other)
@@ -37,7 +37,7 @@ class ExprCompatible(NodeBase):
         return self.__mul__(-1)
 
 
-class Expr(ExprCompatible):
+class Expr(NodeBase, ExprOp):
     pass
 
 class ConstExpr(Expr):
 from __future__ import absolute_import as _abs
 from numbers import Number as _Number, Integral as _Integral
-from ._ctypes._api import _init_function_module
+from ._ctypes._api import _init_function_module, convert
 from . import _function_internal
 from . import make as _make
 from . import expr as _expr
@@ -33,17 +33,6 @@ def Var(name="tindex", dtype=int32):
     return _function_internal._Var(name, dtype)
 
 
-def convert(value):
-    """Convert a value to expression."""
-    if isinstance(value, _Number):
-        return const(value)
-    elif isinstance(value, (list, tuple)):
-        value = [convert(x) for x in value]
-        return _function_internal._Array(*value)
-    else:
-        return value
-
-
 def placeholder(shape, dtype=None, name="TensorObj"):
     """Construct an empty tensor object.
 from __future__ import absolute_import as _abs
-from ._ctypes._api import NodeBase, register_node, convert
+from ._ctypes._api import NodeBase, SliceBase, register_node, convert
 from . import collections as _collections
 from . import make as _make
 from . import expr as _expr
 
 
+class TensorSlice(SliceBase, _expr.ExprOp):
+    """Auxiliary data structure to enable slicing syntax on tensors."""
+
+    def __init__(self, tensor, indices):
+        # normalize a single index to a tuple so indices can be concatenated
+        if not isinstance(indices, tuple):
+            indices = (indices,)
+        self.tensor = tensor
+        self.indices = indices
+
+    def __getitem__(self, indices):
+        if not isinstance(indices, tuple):
+            indices = (indices,)
+        return TensorSlice(self.tensor, self.indices + indices)
+
+
 @register_node
 class Tensor(NodeBase):
     """Tensor object, to construct, see function.Tensor"""
@@ -13,7 +23,6 @@ class Tensor(NodeBase):
             raise ValueError("Need to provide %d index in tensor slice" % ndim)
         indices = convert(indices)
         args = []
         for x in indices:
             if isinstance(x, _collections.IterVar):
                 args.append(x.var)
@@ -24,6 +33,9 @@ class Tensor(NodeBase):
         return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
 
+    def __getitem__(self, indices):
+        return TensorSlice(self, indices)
+
     @property
     def ndim(self):
         return len(self.shape)
@@ -9,8 +9,11 @@ TEST(Tensor, Basic) {
   Tensor B({n, l}, "B");
 
   auto C = Compute({m, n}, [&](Var i, Var j) {
-      return A(i, j) * B(j, i);
+      return A[i][j];
     }, "C");
 
+  Tensor::Slice x = A[n];
+
   LOG(INFO) << C->op.as<ComputeOpNode>()->body;
 }
 
 TEST(Tensor, Reduce) {
@@ -21,7 +24,7 @@ TEST(Tensor, Reduce) {
   IterVar rv(Range{0, l}, "k");
 
   auto C = Compute({m, n}, [&](Var i, Var j) {
-      return sum(max(A(i, rv) * B(j, rv), 1), {rv});
+      return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv});
     }, "C");
   LOG(INFO) << C->op.as<ComputeOpNode>()->body;
 }
@@ -6,7 +6,7 @@ def test_tensor():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
-    T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
+    T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
     print(T)
     print(T.op.body)
     assert(tuple(T.shape) == (m, n, l))
@@ -18,7 +18,7 @@ def test_tensor_reduce():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
-    T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
+    T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])
     rv = tvm.IterVar((0, A.shape[1]), name="k")
     C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv))