Commit f03483bf by Haichen Shen

checked split

parent 1a18f08e
...@@ -6,3 +6,4 @@ from .expr import Var, const ...@@ -6,3 +6,4 @@ from .expr import Var, const
from .expr_util import * from .expr_util import *
from .tensor import Tensor from .tensor import Tensor
from .domain import RDom, Range, infer_range from .domain import RDom, Range, infer_range
from .split import Split
...@@ -17,7 +17,7 @@ class Range(object): ...@@ -17,7 +17,7 @@ class Range(object):
self.extent = _expr_util.simplify(end - begin) self.extent = _expr_util.simplify(end - begin)
def is_value(self): def is_value(self):
return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1 return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1
def __str__(self): def __str__(self):
return "(%s, %s)" % ( return "(%s, %s)" % (
......
...@@ -6,7 +6,7 @@ constant_canonical_key = '__constant__' ...@@ -6,7 +6,7 @@ constant_canonical_key = '__constant__'
def canonical_to_expr(c): def canonical_to_expr(c):
elements = [] elements = []
for k, v in sorted(c.items()): for k, v in sorted(c.items()):
if k == constant_canonical_key: if k == constant_canonical_key and v != 0:
elements.append(_expr.const(v)) elements.append(_expr.const(v))
elif v == 0: elif v == 0:
continue continue
...@@ -87,7 +87,7 @@ class DivOp(BinaryOp): ...@@ -87,7 +87,7 @@ class DivOp(BinaryOp):
if isinstance(erhs, _expr.ConstExpr): if isinstance(erhs, _expr.ConstExpr):
lhs = lhs.copy() lhs = lhs.copy()
for k, v in lhs.items(): for k, v in lhs.items():
lhs[k] /= erhs.value lhs[k] /= float(erhs.value)
return lhs return lhs
elhs = canonical_to_expr(lhs) elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1} return {elhs / erhs: 1}
......
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import domain as _dom
from . import tensor as _tensor
class Split(object):
def __init__(self, dim, factor):
self.dim = dim
self.factor = factor
self.loop_index = _expr.Var('loop_index_%d_' % dim)
def infer_inner_domain(self, domain):
if isinstance(domain, _dom.RDom):
domain = domain.domain
assert self.dim < len(domain)
inner_domain = domain[:]
dim_out_range = domain[self.dim]
dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
return inner_domain
...@@ -25,7 +25,6 @@ class Tensor(object): ...@@ -25,7 +25,6 @@ class Tensor(object):
self.name = name if name else "TensorObj" self.name = name if name else "TensorObj"
self.inputs = None self.inputs = None
self.rdom = None
def __call__(self, *indices): def __call__(self, *indices):
if len(indices) != self.ndim: if len(indices) != self.ndim:
......
...@@ -26,6 +26,6 @@ def test_simplify(): ...@@ -26,6 +26,6 @@ def test_simplify():
assert tvm.format_str(tvm.simplify(e4)) == '0' assert tvm.format_str(tvm.simplify(e4)) == '0'
if __name__ == "__main__": if __name__ == "__main__":
test_simplify()
test_basic() test_basic()
test_bind() test_bind()
test_simplify()
import tvm
def test_split_dom_infer():
A = tvm.Tensor(2, name='A')
rd = tvm.RDom(tvm.Range(A.shape[1]))
split1 = tvm.Split(0, 64)
split2 = tvm.Split(1, 64)
split3 = tvm.Split(0, 8)
dom = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]
dom1 = split1.infer_inner_domain(dom)
dom2 = split2.infer_inner_domain(dom1)
dom3 = split3.infer_inner_domain(dom2)
dom4 = split3.infer_inner_domain(rd)
i1 = split1.loop_index.name
i2 = split2.loop_index.name
i3 = split3.loop_index.name
assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3)
if __name__ == "__main__":
test_split_dom_infer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment