Commit dcddd208 by tqchen

finish tensor dom infer

parent 6819145a
...@@ -5,4 +5,4 @@ from .op import * ...@@ -5,4 +5,4 @@ from .op import *
from .expr import Var, const from .expr import Var, const
from .expr_util import * from .expr_util import *
from .tensor import Tensor from .tensor import Tensor
from .domain import RDom, Range from .domain import RDom, Range, infer_range
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import expr as _expr from . import expr as _expr
from . import expr_util as _expr_util from . import expr_util as _expr_util
from . import op as _op
class Range(object): class Range(object):
"""Represent a range in one dimension. """Represent a range in one dimension.
...@@ -10,10 +10,15 @@ class Range(object): ...@@ -10,10 +10,15 @@ class Range(object):
if end is None: if end is None:
end = begin end = begin
begin = _expr.const(0) begin = _expr.const(0)
self.begin = _expr._symbol(begin) begin = _expr_util.simplify(_expr._symbol(begin))
self.end = _expr._symbol(end) end = _expr_util.simplify(_expr._symbol(end))
self.begin = begin
self.end = end
self.extent = _expr_util.simplify(end - begin) self.extent = _expr_util.simplify(end - begin)
def is_value(self):
return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1
def __str__(self): def __str__(self):
return "(%s, %s)" % ( return "(%s, %s)" % (
_expr_util.format_str(self.begin), _expr_util.format_str(self.begin),
...@@ -22,9 +27,13 @@ class Range(object): ...@@ -22,9 +27,13 @@ class Range(object):
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
class RangeInferError(ValueError):
pass
class RDom(object): class RDom(object):
"""reduction Domain """Reduction Domain."""
"""
def __init__(self, domain): def __init__(self, domain):
if isinstance(domain, Range): if isinstance(domain, Range):
domain = [domain] domain = [domain]
...@@ -36,3 +45,63 @@ class RDom(object): ...@@ -36,3 +45,63 @@ class RDom(object):
"""Use list of ranges as domain""" """Use list of ranges as domain"""
Domain = list Domain = list
def _combine_range_binary_op(op, lhs, rhs):
if op == _op.add:
return Range(lhs.begin + rhs.begin, lhs.end + rhs.end - 1)
elif op == _op.sub:
return Range(lhs.begin - rhs.end + 1, lhs.end - rhs.begin)
elif op == _op.mul:
v = None
if lhs.is_value():
v = lhs.begin.value
e = rhs
elif rhs.is_value():
v = rhs.begin.value
e = lhs
if v == -1:
return Range(-e.end, -e.begin)
raise InferRangeError("donot know how to infer range for %s" % type(op))
def infer_range(e, range_dict, allow_unbind_var=True):
"""Infer the range of result e given range of variables.
Parameters
----------
expr : Expr
Input expression
range_dict : dict of Var->Range
The variables to be replaced.
allow_unbind_var: bool
Whether allow unbinded variables
"""
def combine_range(e, result_children):
if isinstance(e, _expr.ConstExpr):
return Range(e, e + 1)
elif isinstance(e, _expr.BinaryOpExpr):
return _combine_range_binary_op(e.op, result_children[0], result_children[1])
elif isinstance(e, _expr.Var):
if e in range_dict:
return range_dict[e]
else:
if allow_unbind_var:
return Range(e, e + 1)
else:
raise ValueError("Cannot find var %s in range_dict" % e.name)
else:
raise InferRangeError("cannot infer range for %s" % _expr_util.format_str(e))
return _expr_util.transform(e, combine_range)
def union_range(lhs, rhs):
if lhs is None:
return rhs
if rhs is None:
return lhs
begin = _op.min(lhs.begin, rhs.begin)
end = _op.max(rhs.end, lhs.end)
return Range(begin, end)
...@@ -22,7 +22,6 @@ def canonical_to_expr(c): ...@@ -22,7 +22,6 @@ def canonical_to_expr(c):
else: else:
return _expr.const(0) return _expr.const(0)
class BinaryOp(object): class BinaryOp(object):
"""Base class of binary operator""" """Base class of binary operator"""
def __call__(self, lhs, rhs): def __call__(self, lhs, rhs):
...@@ -45,7 +44,6 @@ class AddOp(BinaryOp): ...@@ -45,7 +44,6 @@ class AddOp(BinaryOp):
lhs[k] = v lhs[k] = v
return lhs return lhs
class SubOp(BinaryOp): class SubOp(BinaryOp):
def format_str(self, lhs, rhs): def format_str(self, lhs, rhs):
return '(%s - %s)' % (lhs, rhs) return '(%s - %s)' % (lhs, rhs)
......
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from . import expr as _expr from . import expr as _expr
from . import expr_util as _expr_util from . import expr_util as _expr_util
from . import domain as _dom
class Tensor(object): class Tensor(object):
...@@ -39,16 +40,17 @@ class Tensor(object): ...@@ -39,16 +40,17 @@ class Tensor(object):
""" """
if self.inputs is not None: if self.inputs is not None:
return self.inputs return self.inputs
self.inputs = [] inputs = []
if self.expr: if self.expr:
def collect(e): def collect(e):
if isinstance(e, _expr.TensorReadExpr): if isinstance(e, _expr.TensorReadExpr):
self.inputs.append(e.tensor) inputs.append(e.tensor)
_expr_util.visit(self.expr, collect) _expr_util.visit(self.expr, collect)
self.inputs = set(inputs)
return self.inputs return self.inputs
def infer_input_domains(self, out_domain): def infer_input_domains(self, out_domain, inputs):
"""Infer the input domains of each domain given output domains """Infer the input domains of each domain in given inputs list.
Parameters Parameters
---------- ----------
...@@ -64,7 +66,26 @@ class Tensor(object): ...@@ -64,7 +66,26 @@ class Tensor(object):
index_domains = { index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain)) self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
} }
def collect(e): iset = {}
if isinstance(e, _expr.TensorReadExpr): for t in inputs:
self.inputs.append(e.tensor) assert t in self.input_tensors()
_expr_util.visit(self.expr, collect) iset[t] = []
def prepare(e):
if isinstance(e, _expr.ReduceExpr):
rd = e.rdom
for i in range(len(rd.domain)):
index_domains[rd.index[i]] = rd.domain[i]
elif isinstance(e, _expr.TensorReadExpr):
if e.tensor in iset:
iset[e.tensor].append(e)
_expr_util.visit(self.expr, prepare)
result = {}
for k, v in iset.items():
dm = [None] * len(v[0].indices)
for e in v:
for i, idx in enumerate(e.indices):
dm[i] = _dom.union_range(
dm[i], _dom.infer_range(idx, index_domains, allow_unbind_var=False))
result[k] = dm
return result
import tvm
def test_range_infer():
x = tvm.Var('x')
y = tvm.Var('y')
t = tvm.Var('t')
z = x + y + t
zr = tvm.infer_range(z, {x: tvm.Range(10, 20), y : tvm.Range(10, 11)})
assert str(zr) == "((t0 + 20), (t0 + 30))"
def test_tensor_dom_infer():
A = tvm.Tensor(2, name='A')
B = tvm.Tensor(2, name='B')
T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
shape=(A.shape[0], B.shape[0], A.shape[1]))
rd = tvm.RDom(tvm.Range(A.shape[1]))
C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
shape=(A.shape[0], B.shape[0]))
cdom = [tvm.Range(0, 10), tvm.Range(1, 11)]
tdom = C.infer_input_domains(cdom, inputs=[T])[T]
assert str(tdom[0]) == "(0, 10)"
if __name__ == "__main__":
test_range_infer()
test_tensor_dom_infer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment