Commit 1a18f08e by tqchen

Fold RTensor into tensor

parent dcddd208
......@@ -25,6 +25,7 @@ class Tensor(object):
self.name = name if name else "TensorObj"
self.inputs = None
self.rdom = None
def __call__(self, *indices):
if len(indices) != self.ndim:
......@@ -49,7 +50,7 @@ class Tensor(object):
self.inputs = set(inputs)
return self.inputs
def infer_input_domains(self, out_domain, inputs):
def infer_input_domains(self, out_domain, inputs, red_domain=None):
"""Infer the input domains of each domain in given inputs list.
Parameters
......@@ -57,6 +58,12 @@ class Tensor(object):
out_domain : list of Range
Domain of each dimension.
red_domain : list of Range
Domain of reduction variables, if this tensor
this can only be specified if
self.expr finishes with an ReduceExpr, and we can schedule
over the last reduction that creates this tensor.
Returns
-------
in_domains: dict Tensor->Domain
......@@ -66,6 +73,17 @@ class Tensor(object):
index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
}
begin_expr = self.expr
if red_domain:
if not isinstance(self.expr, _expr.ReduceExpr):
raise ValueError("red_domain must work with tensor that stores a reduction")
rdom = self.expr.rdom
begin_expr = self.expr.src
assert len(red_domain) == len(rdom.index)
for i in range(len(red_domain)):
index_domains[rdom.index[i]] = red_domain[i]
iset = {}
for t in inputs:
assert t in self.input_tensors()
......@@ -79,7 +97,7 @@ class Tensor(object):
elif isinstance(e, _expr.TensorReadExpr):
if e.tensor in iset:
iset[e.tensor].append(e)
_expr_util.visit(self.expr, prepare)
_expr_util.visit(begin_expr, prepare)
result = {}
for k, v in iset.items():
dm = [None] * len(v[0].indices)
......@@ -89,3 +107,13 @@ class Tensor(object):
dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False))
result[k] = dm
return result
@property
def is_rtensor(self):
"""Whether this tensor is a result of reduction.
Returns
-------
is_rtensor : Whether the tensor is RTensor
"""
return self.expr and isinstance(self.expr, _expr.ReduceExpr)
......@@ -11,14 +11,16 @@ def test_range_infer():
def test_tensor_dom_infer():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
T = tvm.Tensor(2, lambda i, j:
tvm.reduce_sum(A(i, rd.index[0]) * B(j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
C = tvm.Tensor(2, lambda i, j: T(i,j),
shape=(A.shape[0], B.shape[0]))
cdom = [tvm.Range(0, 10), tvm.Range(1, 11)]
tdom = C.infer_input_domains(cdom, inputs=[T])[T]
assert T.is_rtensor
assert str(tdom[0]) == "(0, 10)"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment