Commit 1a18f08e by tqchen

Fold RTensor into tensor

parent dcddd208
...@@ -25,6 +25,7 @@ class Tensor(object): ...@@ -25,6 +25,7 @@ class Tensor(object):
self.name = name if name else "TensorObj" self.name = name if name else "TensorObj"
self.inputs = None self.inputs = None
self.rdom = None
def __call__(self, *indices): def __call__(self, *indices):
if len(indices) != self.ndim: if len(indices) != self.ndim:
...@@ -49,7 +50,7 @@ class Tensor(object): ...@@ -49,7 +50,7 @@ class Tensor(object):
self.inputs = set(inputs) self.inputs = set(inputs)
return self.inputs return self.inputs
def infer_input_domains(self, out_domain, inputs): def infer_input_domains(self, out_domain, inputs, red_domain=None):
"""Infer the input domains of each domain in given inputs list. """Infer the input domains of each domain in given inputs list.
Parameters Parameters
...@@ -57,6 +58,12 @@ class Tensor(object): ...@@ -57,6 +58,12 @@ class Tensor(object):
out_domain : list of Range out_domain : list of Range
Domain of each dimension. Domain of each dimension.
red_domain : list of Range
Domain of reduction variables, if this tensor
this can only be specified if
self.expr finishes with an ReduceExpr, and we can schedule
over the last reduction that creates this tensor.
Returns Returns
------- -------
in_domains: dict Tensor->Domain in_domains: dict Tensor->Domain
...@@ -66,6 +73,17 @@ class Tensor(object): ...@@ -66,6 +73,17 @@ class Tensor(object):
index_domains = { index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain)) self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
} }
begin_expr = self.expr
if red_domain:
if not isinstance(self.expr, _expr.ReduceExpr):
raise ValueError("red_domain must work with tensor that stores a reduction")
rdom = self.expr.rdom
begin_expr = self.expr.src
assert len(red_domain) == len(rdom.index)
for i in range(len(red_domain)):
index_domains[rdom.index[i]] = red_domain[i]
iset = {} iset = {}
for t in inputs: for t in inputs:
assert t in self.input_tensors() assert t in self.input_tensors()
...@@ -79,7 +97,7 @@ class Tensor(object): ...@@ -79,7 +97,7 @@ class Tensor(object):
elif isinstance(e, _expr.TensorReadExpr): elif isinstance(e, _expr.TensorReadExpr):
if e.tensor in iset: if e.tensor in iset:
iset[e.tensor].append(e) iset[e.tensor].append(e)
_expr_util.visit(self.expr, prepare) _expr_util.visit(begin_expr, prepare)
result = {} result = {}
for k, v in iset.items(): for k, v in iset.items():
dm = [None] * len(v[0].indices) dm = [None] * len(v[0].indices)
...@@ -89,3 +107,13 @@ class Tensor(object): ...@@ -89,3 +107,13 @@ class Tensor(object):
dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False)) dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False))
result[k] = dm result[k] = dm
return result return result
@property
def is_rtensor(self):
"""Whether this tensor is a result of reduction.
Returns
-------
is_rtensor : Whether the tensor is RTensor
"""
return self.expr and isinstance(self.expr, _expr.ReduceExpr)
...@@ -11,14 +11,16 @@ def test_range_infer(): ...@@ -11,14 +11,16 @@ def test_range_infer():
def test_tensor_dom_infer(): def test_tensor_dom_infer():
A = tvm.Tensor(2, name='A') A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B') B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1])) rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd), T = tvm.Tensor(2, lambda i, j:
tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
C = tvm.Tensor(2, lambda i, j: T(i,j),
shape=(A.shape[0], B.shape[0])) shape=(A.shape[0], B.shape[0]))
cdom = [tvm.Range(0, 10), tvm.Range(1, 11)] cdom = [tvm.Range(0, 10), tvm.Range(1, 11)]
tdom = C.infer_input_domains(cdom, inputs=[T])[T] tdom = C.infer_input_domains(cdom, inputs=[T])[T]
assert T.is_rtensor
assert str(tdom[0]) == "(0, 10)" assert str(tdom[0]) == "(0, 10)"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment