Commit bda95817 by tqchen

checkin tensor

parent fc4ba796
...@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs ...@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .op import * from .op import *
from .expr import Var, const from .expr import Var, const
from .expr_util import * from .expr_util import *
from .tensor import Tensor
...@@ -79,7 +79,7 @@ class Var(Expr): ...@@ -79,7 +79,7 @@ class Var(Expr):
optional name to the var. optional name to the var.
""" """
def __init__(self, name=None): def __init__(self, name=None):
if name is None: name = 'i' if name is None: name = 'index'
self.name = _name.NameManager.current.get(name) self.name = _name.NameManager.current.get(name)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import expr as _expr from . import expr as _expr
from . import op as _op from . import op as _op
from . import tensor as _tensor
def expr_with_new_children(e, children): def expr_with_new_children(e, children):
"""Returns same expr as e but with new children """Returns same expr as e but with new children
...@@ -75,6 +76,8 @@ def format_str(expr): ...@@ -75,6 +76,8 @@ def format_str(expr):
return str(e.value) return str(e.value)
elif isinstance(e, _expr.Var): elif isinstance(e, _expr.Var):
return e.name return e.name
elif isinstance(e, _tensor.TensorReadExpr):
return "%s(%s)" % (e.tensor.name, ','.join(result_children))
else: else:
raise TypeError("Do not know how to handle type " + str(type(e))) raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str) return transform(expr, make_str)
......
from __future__ import absolute_import as _abs
from . import expr as _expr
class TensorReadExpr(_expr.Expr):
def __init__(self, tensor, indices):
self.tensor = tensor
self.indices = indices
def children(self):
return self.indices
class Tensor(object):
def __init__(self, ndim, fcompute=None, name=None):
self.ndim = ndim
if fcompute:
arg_names = fcompute.func_code.co_varnames
assert(len(arg_names) == ndim)
self.dim_index = [_expr.Var(n) for n in arg_names]
self.expr = fcompute(*self.dim_index)
else:
self.expr = None
self.dim_index = None
shape_name = '_shape'
if name: shape_name = name + shape_name
self.shape = tuple(_expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
self.name = name if name else "TensorObj"
def __call__(self, *indices):
if len(indices) != self.ndim:
raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
return TensorReadExpr(self, indices)
import tvm import tvm
from tvm import expr
def test_bind(): def test_bind():
x = tvm.Var('x') x = tvm.Var('x')
......
import tvm
def test_tensor():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k))
print(tvm.format_str(T.expr))
if __name__ == "__main__":
test_tensor()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment