Commit 6819145a by tqchen

checkin domain

parent bda95817
......@@ -5,3 +5,4 @@ from .op import *
from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import expr_util as _expr_util
class Range(object):
    """Represent a range in one dimension with symbolic bounds.

    Parameters
    ----------
    begin : Expr or int
        Start of the range. When ``end`` is omitted, this value is
        reinterpreted as the end and the range starts at constant 0.
    end : Expr or int, optional
        End of the range; ``extent`` is ``end - begin``.
    """
    def __init__(self, begin, end=None):
        if end is None:
            # Range(n) means a range of extent n starting at 0.
            end = begin
            begin = _expr.const(0)
        self.begin = _expr._symbol(begin)
        self.end = _expr._symbol(end)
        # Compute the extent from the symbolized bounds so that both
        # operands are Expr nodes even when raw ints were passed in
        # (the original subtracted the raw arguments before conversion).
        self.extent = _expr_util.simplify(self.end - self.begin)

    def __str__(self):
        return "(%s, %s)" % (
            _expr_util.format_str(self.begin),
            _expr_util.format_str(self.end))

    def __repr__(self):
        return self.__str__()
class RDom(object):
    """Reduction domain: one generated index Var per Range in `domain`."""
    def __init__(self, domain):
        # A single Range is promoted to a one-dimensional domain list.
        if isinstance(domain, Range):
            domain = [domain]
        self.domain = domain
        self.index = [_expr.Var("rd_index_%d_" % dim)
                      for dim in range(len(domain))]
"""Use list of ranges as domain"""
Domain = list
......@@ -108,7 +108,27 @@ class UnaryOpExpr(Expr):
self.src = _symbol(src)
def children(self):
    """Return the child expressions of this node as a tuple.

    A unary op has exactly one child: its source expression.
    """
    # Must be a one-element tuple: a bare `(self.src)` is just
    # `self.src`, which breaks callers that iterate over children().
    return (self.src,)
class ReduceExpr(Expr):
    """Expression node applying reduction operator `op` to `src` over `rdom`."""
    def __init__(self, op, src, rdom):
        self.op = op      # reduction operator object (e.g. add for a sum)
        self.src = src    # expression being reduced
        self.rdom = rdom  # RDom describing the reduction axes

    def children(self):
        # op and rdom are metadata, not child expressions.
        return (self.src,)
class TensorReadExpr(Expr):
    """Tensor read expression, tensor[indices]."""
    def __init__(self, tensor, indices):
        self.tensor, self.indices = tensor, indices

    def children(self):
        # Each index expression is a child of this node.
        return self.indices
def const(value):
......
......@@ -2,7 +2,6 @@
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import op as _op
from . import tensor as _tensor
def expr_with_new_children(e, children):
"""Returns same expr as e but with new children
......@@ -50,10 +49,27 @@ def transform(e, f):
result : return value of f
The final result of transformation.
"""
assert isinstance(e, _expr.Expr)
if not isinstance(e, _expr.Expr):
raise TypeError("Cannot handle type %s" % type(e))
return f(e , [transform(c, f) for c in e.children()])
def visit(e, f):
    """Apply f to each expression node of e, children first (post-order).

    Parameters
    ----------
    e : Expr
        The input expression.
    f : function with signature (e)
        Callback invoked once per node; its return value is ignored.

    Raises
    ------
    TypeError
        If e is not an Expr.
    """
    # Raise TypeError like transform() does, instead of an assert that
    # disappears under `python -O`.
    if not isinstance(e, _expr.Expr):
        raise TypeError("Cannot handle type %s" % type(e))
    for c in e.children():
        visit(c, f)
    f(e)
def format_str(expr):
"""change expression to string.
......@@ -76,12 +92,15 @@ def format_str(expr):
return str(e.value)
elif isinstance(e, _expr.Var):
return e.name
elif isinstance(e, _tensor.TensorReadExpr):
elif isinstance(e, _expr.TensorReadExpr):
return "%s(%s)" % (e.tensor.name, ','.join(result_children))
elif isinstance(e, _expr.ReduceExpr):
return e.op.format_reduce_str(result_children[0], e.rdom.domain)
else:
raise TypeError("Do not know how to handle type " + str(type(e)))
return transform(expr, make_str)
def simplify(expr):
"""simplify expression
......
......@@ -22,15 +22,20 @@ def canonical_to_expr(c):
else:
return _expr.const(0)
class BinaryOp(object):
    """Base class of binary operator."""
    def __call__(self, lhs, rhs):
        # Calling an operator instance builds the corresponding AST node,
        # carrying this operator object along as the node's op.
        node = _expr.BinaryOpExpr(self, lhs, rhs)
        return node
class AddOp(BinaryOp):
def format_str(self, lhs, rhs):
    """Render the addition of two already-formatted operand strings."""
    return "({0} + {1})".format(lhs, rhs)
def format_reduce_str(self, src, rd):
    """Render a sum-reduction of formatted source `src` over domain `rd`."""
    return "reduce_sum({0}, rdom={1})".format(src, str(rd))
def canonical(self, lhs, rhs):
lhs = lhs.copy()
for k, v in rhs.items():
......@@ -40,6 +45,7 @@ class AddOp(BinaryOp):
lhs[k] = v
return lhs
class SubOp(BinaryOp):
def format_str(self, lhs, rhs):
    """Render the subtraction of two already-formatted operand strings."""
    return "({0} - {1})".format(lhs, rhs)
......@@ -53,6 +59,7 @@ class SubOp(BinaryOp):
lhs[k] = -v
return lhs
class MulOp(BinaryOp):
def format_str(self, lhs, rhs):
    """Render the product of two already-formatted operand strings."""
    return "({0} * {1})".format(lhs, rhs)
......@@ -72,6 +79,7 @@ class MulOp(BinaryOp):
return rhs
return {elhs * erhs: 1}
class DivOp(BinaryOp):
def format_str(self, lhs, rhs):
    """Render the quotient of two already-formatted operand strings."""
    return "({0} / {1})".format(lhs, rhs)
......@@ -86,6 +94,7 @@ class DivOp(BinaryOp):
elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1}
class MaxOp(BinaryOp):
def format_str(self, lhs, rhs):
    """Render the max of two already-formatted operand strings."""
    return "max({0}, {1})".format(lhs, rhs)
......@@ -97,6 +106,7 @@ class MaxOp(BinaryOp):
return lhs if ediff.value >= 0 else rhs
return {MaxOp()(lhs, rhs): 1}
class MinOp(BinaryOp):
def format_str(self, lhs, rhs):
    """Render the min of two already-formatted operand strings."""
    return "min({0}, {1})".format(lhs, rhs)
......@@ -120,3 +130,16 @@ _expr.__addop__ = add
# Install the shared operator singletons as hooks on the expr module
# (presumably consumed by Expr's arithmetic overloads -- confirm in expr.py).
_expr.__subop__ = sub
_expr.__mulop__ = mul
_expr.__divop__ = div
def reduce_sum(expr, rdom):
    """Build a sum-reduction of `expr` over reduction domain `rdom`."""
    return _expr.ReduceExpr(add, expr, rdom)
def reduce_prod(expr, rdom):
    """Build a product-reduction of `expr` over reduction domain `rdom`."""
    return _expr.ReduceExpr(mul, expr, rdom)
def reduce_min(expr, rdom):
    """Build a min-reduction of `expr` over reduction domain `rdom`.

    NOTE(review): `min` must resolve to a module-level MinOp instance
    (like `add`/`mul` do for their ops); otherwise this captures the
    Python builtin, which has no format_reduce_str -- confirm.
    """
    return _expr.ReduceExpr(min, expr, rdom)
def reduce_max(expr, rdom):
    """Build a max-reduction of `expr` over reduction domain `rdom`.

    NOTE(review): `max` must resolve to a module-level MaxOp instance;
    otherwise this captures the Python builtin -- confirm.
    """
    return _expr.ReduceExpr(max, expr, rdom)
from __future__ import absolute_import as _abs
from . import expr as _expr
class TensorReadExpr(_expr.Expr):
    """Expression node for reading a tensor at given indices.

    NOTE(review): duplicates expr.TensorReadExpr; Tensor.__call__ now
    builds the expr.py version, so this copy looks superseded.
    """
    def __init__(self, tensor, indices):
        self.tensor, self.indices = tensor, indices

    def children(self):
        # The index expressions are this node's children.
        return self.indices
from . import expr_util as _expr_util
class Tensor(object):
def __init__(self, ndim, fcompute=None, name=None, shape=None):
    """Create an ndim-dimensional tensor.

    Parameters
    ----------
    ndim : int
        Number of dimensions.
    fcompute : function, optional
        Function of ndim index variables defining this tensor's value;
        when given, `shape` must also be given.
    name : str, optional
        Name of the tensor; also prefixes generated shape variables.
    shape : tuple, optional
        Shape of the tensor. For a placeholder tensor (no fcompute),
        symbolic shape Vars are generated when omitted.

    Raises
    ------
    ValueError
        If fcompute is given without shape.
    """
    self.ndim = ndim
    if fcompute:
        # NOTE(review): func_code is the Python-2 spelling of __code__,
        # and co_varnames also lists locals -- relies on fcompute being
        # a simple lambda with exactly ndim arguments and no locals.
        arg_names = fcompute.func_code.co_varnames
        assert len(arg_names) == ndim
        self.dim_index = [_expr.Var(n) for n in arg_names]
        self.expr = fcompute(*self.dim_index)
        if shape is None:
            raise ValueError("argument shape need to be given for intermediate tensor")
        self.shape = shape
    else:
        self.expr = None
        self.dim_index = None
        # Placeholder tensor: invent symbolic shape vars unless a
        # concrete shape was supplied.
        shape_name = '_shape'
        if name:
            shape_name = name + shape_name
        self.shape = shape if shape else tuple(
            _expr.Var("%s_%d_" % (shape_name, i)) for i in range(ndim))
    self.name = name if name else "TensorObj"
    # Lazily-filled cache for input_tensors().
    self.inputs = None
def __call__(self, *indices):
    """Read this tensor at `indices`, returning a TensorReadExpr node.

    Raises
    ------
    ValueError
        If the number of indices does not match self.ndim.
    """
    if len(indices) != self.ndim:
        raise ValueError("Need to provide %d index in tensor slice" % self.ndim)
    # Build the canonical node type from expr.py, not the superseded
    # duplicate declared in this module.
    return _expr.TensorReadExpr(self, indices)
def input_tensors(self):
    """List of input tensors to this tensor.

    Returns
    -------
    inputs : list of input tensors
        Tensors read by self.expr, in post-order visit order; empty for
        a placeholder tensor. Cached after the first call.
    """
    if self.inputs is None:
        collected = []
        if self.expr:
            def _record(node):
                if isinstance(node, _expr.TensorReadExpr):
                    collected.append(node.tensor)
            _expr_util.visit(self.expr, _record)
        self.inputs = collected
    return self.inputs
def infer_input_domains(self, out_domain):
"""Infer the input domains of each domain given output domains
Parameters
----------
out_domain : list of Range
Domain of each dimension.
Returns
-------
in_domains: dict Tensor->Domain
"""
assert self.expr
assert len(out_domain) == len(self.dim_index)
index_domains = {
self.dim_index[i] : out_domain[i] for i in range(len(out_domain))
}
def collect(e):
if isinstance(e, _expr.TensorReadExpr):
self.inputs.append(e.tensor)
_expr_util.visit(self.expr, collect)
......@@ -3,8 +3,27 @@ import tvm
def test_tensor():
    """Smoke test: build a compute tensor and format its expression."""
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    # shape is required for a tensor constructed with fcompute.
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    print(tvm.format_str(T.expr))
def test_tensor_inputs():
    """T reads from both A and B, so input_tensors() reports [A, B]."""
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    assert T.input_tensors() == [A, B]
def test_tensor_reduce():
    """Build a sum-reduction of T over the shared axis and format it."""
    A = tvm.Tensor(2, name='A')
    B = tvm.Tensor(2, name='B')
    T = tvm.Tensor(3, lambda i, j, k: A(i, k) * B(j, k),
                   shape=(A.shape[0], B.shape[0], A.shape[1]))
    # Reduction domain covers the k axis of length A.shape[1].
    rd = tvm.RDom(tvm.Range(A.shape[1]))
    C = tvm.Tensor(2, lambda i, j: tvm.reduce_sum(T(i, j, rd.index[0]), rdom=rd),
                   shape=(A.shape[0], B.shape[0]))
    print(tvm.format_str(C.expr))
if __name__ == "__main__":
test_tensor()
test_tensor_inputs()
test_tensor_reduce()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment