Commit 6819145a by tqchen

checkin domain

parent bda95817
@@ -5,3 +5,4 @@ from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util


class Range(object):
    """Represent a range in one dimension."""
    def __init__(self, begin, end=None):
        if end is None:
            end = begin
            begin = _expr.const(0)
        self.begin = _expr._symbol(begin)
        self.end = _expr._symbol(end)
        self.extent = _expr_util.simplify(end - begin)

    def __str__(self):
        return "(%s, %s)" % (
            _expr_util.format_str(self.begin),
            _expr_util.format_str(self.end))

    def __repr__(self):
        return self.__str__()


class RDom(object):
    """Reduction domain."""
    def __init__(self, domain):
        if isinstance(domain, Range):
            domain = [domain]
        self.index = []
        self.domain = domain
        for i in range(len(domain)):
            self.index.append(_expr.Var("rd_index_%d_" % i))


"""Use a list of ranges as a domain."""
Domain = list
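
For orientation, a minimal usage sketch of the new Range/RDom classes (not part of the commit). It assumes the tvm package re-exports Var, format_str, Range and RDom as in the __init__.py change above, and that Expr carries the arithmetic operator overloads wired up through op.py; the name n is illustrative only.

# Sketch only: build a one-dimensional reduction domain of symbolic extent n.
import tvm

n = tvm.Var("n")                      # illustrative symbolic extent
rd = tvm.RDom(tvm.Range(n))           # Range(n) is shorthand for Range(0, n)
print(rd.domain[0])                   # should print "(0, n)" via Range.__str__
print(tvm.format_str(rd.index[0]))    # auto-generated reduction index, "rd_index_0_"
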
@@ -108,7 +108,27 @@ class UnaryOpExpr(Expr):
        self.src = _symbol(src)

    def children(self):
        return (self.src,)


class ReduceExpr(Expr):
    """Reduction expression: apply op to src over the reduction domain rdom."""
    def __init__(self, op, src, rdom):
        self.op = op
        self.src = src
        self.rdom = rdom

    def children(self):
        return (self.src,)


class TensorReadExpr(Expr):
    """Tensor read expression, tensor[indices]."""
    def __init__(self, tensor, indices):
        self.tensor = tensor
        self.indices = indices

    def children(self):
        return self.indices
def const(value):
...
@@ -2,7 +2,6 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import op as _op
from . import tensor as _tensor


def expr_with_new_children(e, children):
    """Returns same expr as e but with new children
@@ -50,10 +49,27 @@ def transform(e, f):
    result : return value of f
        The final result of transformation.
    """
    if not isinstance(e, _expr.Expr):
        raise TypeError("Cannot handle type %s" % type(e))
    return f(e, [transform(c, f) for c in e.children()])


def visit(e, f):
    """Apply f to every sub-expression of e, children first (post-order).

    Parameters
    ----------
    e : Expr
        The input expression.

    f : function with signature (e)
    """
    assert isinstance(e, _expr.Expr)
    for c in e.children():
        visit(c, f)
    f(e)
def format_str(expr):
    """change expression to string.
@@ -76,12 +92,15 @@ def format_str(expr):
            return str(e.value)
        elif isinstance(e, _expr.Var):
            return e.name
        elif isinstance(e, _expr.TensorReadExpr):
            return "%s(%s)" % (e.tensor.name, ','.join(result_children))
        elif isinstance(e, _expr.ReduceExpr):
            return e.op.format_reduce_str(result_children[0], e.rdom.domain)
        else:
            raise TypeError("Do not know how to handle type " + str(type(e)))
    return transform(expr, make_str)


def simplify(expr):
    """simplify expression
...
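
As a quick illustration of the new visit helper (a post-order walk that calls f on every sub-expression, including e itself), here is a hedged sketch that collects the variables of a small expression; it is not part of the commit and assumes Expr provides the * and + overloads hooked up via __mulop__/__addop__ in op.py, with leaf nodes returning empty children().

# Sketch only: collect Var nodes with expr_util.visit.
import tvm
from tvm import expr as _expr
from tvm import expr_util as _expr_util

x = tvm.Var("x")
y = tvm.Var("y")
e = x * y + x

seen = []
def collect_vars(node):
    # Called once per sub-expression, children before parents.
    if isinstance(node, _expr.Var):
        seen.append(node)

_expr_util.visit(e, collect_vars)
print(len(seen))   # 3: x, y, and the second occurrence of x
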
@@ -22,15 +22,20 @@ def canonical_to_expr(c):
    else:
        return _expr.const(0)


class BinaryOp(object):
    """Base class of binary operator"""
    def __call__(self, lhs, rhs):
        return _expr.BinaryOpExpr(self, lhs, rhs)


class AddOp(BinaryOp):
    def format_str(self, lhs, rhs):
        return '(%s + %s)' % (lhs, rhs)

    def format_reduce_str(self, src, rd):
        return "reduce_sum(%s, rdom=%s)" % (src, str(rd))

    def canonical(self, lhs, rhs):
        lhs = lhs.copy()
        for k, v in rhs.items():
@@ -40,6 +45,7 @@ class AddOp(BinaryOp):
            lhs[k] = v
        return lhs


class SubOp(BinaryOp):
    def format_str(self, lhs, rhs):
        return '(%s - %s)' % (lhs, rhs)
@@ -53,6 +59,7 @@ class SubOp(BinaryOp):
            lhs[k] = -v
        return lhs


class MulOp(BinaryOp):
    def format_str(self, lhs, rhs):
        return '(%s * %s)' % (lhs, rhs)
@@ -72,6 +79,7 @@ class MulOp(BinaryOp):
            return rhs
        return {elhs * erhs: 1}


class DivOp(BinaryOp):
    def format_str(self, lhs, rhs):
        return '(%s / %s)' % (lhs, rhs)
@@ -86,6 +94,7 @@ class DivOp(BinaryOp):
        elhs = canonical_to_expr(lhs)
        return {elhs / erhs: 1}


class MaxOp(BinaryOp):
    def format_str(self, lhs, rhs):
        return 'max(%s, %s)' % (lhs, rhs)
@@ -97,6 +106,7 @@ class MaxOp(BinaryOp):
            return lhs if ediff.value >= 0 else rhs
        return {MaxOp()(lhs, rhs): 1}


class MinOp(BinaryOp):
    def format_str(self, lhs, rhs):
        return 'min(%s, %s)' % (lhs, rhs)
@@ -120,3 +130,16 @@ _expr.__addop__ = add
_expr.__subop__ = sub
_expr.__mulop__ = mul
_expr.__divop__ = div


def reduce_sum(expr, rdom):
    return _expr.ReduceExpr(add, expr, rdom)


def reduce_prod(expr, rdom):
    return _expr.ReduceExpr(mul, expr, rdom)


def reduce_min(expr, rdom):
    return _expr.ReduceExpr(min, expr, rdom)


def reduce_max(expr, rdom):
    return _expr.ReduceExpr(max, expr, rdom)
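
A minimal sketch (not part of the commit) of how these reduction helpers compose with the RDom/Range classes added above; format_str renders the result through AddOp.format_reduce_str. The names x and n are illustrative, and the snippet assumes the tvm package re-exports Var, format_str, Range, RDom and reduce_sum as in the __init__.py change.

# Sketch only: build and print a summation over a symbolic range.
import tvm

x = tvm.Var("x")
n = tvm.Var("n")
rd = tvm.RDom(tvm.Range(n))
s = tvm.reduce_sum(x * rd.index[0], rdom=rd)
print(tvm.format_str(s))   # should print "reduce_sum((x * rd_index_0_), rdom=[(0, n)])"
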
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util


class TensorReadExpr(_expr.Expr):
    def __init__(self, tensor, indices):
        self.tensor = tensor
        self.indices = indices

    def children(self):
        return self.indices


class Tensor(object):
    def __init__(self, ndim, fcompute=None, name=None, shape=None):
        self.ndim = ndim
        if fcompute:
            arg_names = fcompute.func_code.co_varnames
            assert(len(arg_names) == ndim)
            self.dim_index = [_expr.Var(n) for n in arg_names]
            self.expr = fcompute(*self.dim_index)
            if shape is None:
                raise ValueError("argument shape needs to be given for an intermediate tensor")
            self.shape = shape
        else:
            self.expr = None
            self.dim_index = None
            shape_name = '_shape'
            if name: shape_name = name + shape_name
            self.shape = shape if shape else tuple(
                _expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
        self.name = name if name else "TensorObj"
        self.inputs = None

    def __call__(self, *indices):
        if len(indices) != self.ndim:
            raise ValueError("Need to provide %d indices in tensor slice" % self.ndim)
        return _expr.TensorReadExpr(self, indices)
    def input_tensors(self):
        """List of input tensors to this tensor.

        Returns
        -------
        inputs : list of input tensors
        """
        if self.inputs is not None:
            return self.inputs
        self.inputs = []
        if self.expr:
            def collect(e):
                if isinstance(e, _expr.TensorReadExpr):
                    self.inputs.append(e.tensor)
            _expr_util.visit(self.expr, collect)
        return self.inputs
    def infer_input_domains(self, out_domain):
        """Infer the domain of each input tensor given the output domain.

        Parameters
        ----------
        out_domain : list of Range
            Domain of each output dimension.

        Returns
        -------
        in_domains : dict Tensor -> Domain
        """
        assert self.expr
        assert len(out_domain) == len(self.dim_index)
        index_domains = {
            self.dim_index[i]: out_domain[i] for i in range(len(out_domain))
        }
        def collect(e):
            if isinstance(e, _expr.TensorReadExpr):
                self.inputs.append(e.tensor)
        _expr_util.visit(self.expr, collect)
@@ -3,8 +3,27 @@ import tvm
def test_tensor():
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    print(tvm.format_str(T.expr))


def test_tensor_inputs():
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    assert(T.input_tensors() == [A, B])


def test_tensor_reduce():
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    rd = tvm.RDom(tvm.Range(A.shape[1]))
    C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
                   shape=(A.shape[0], B.shape[0]))
    print(tvm.format_str(C.expr))


if __name__ == "__main__":
    test_tensor_inputs()
    test_tensor_reduce()