Commit 05e871d4 by tqchen

Merge branch 'master' of ssh://github.com/tqchen/tvm

parents c41d9d23 f03483bf
......@@ -6,3 +6,4 @@ from .expr import Var, const
from .expr_util import *
from .tensor import Tensor
from .domain import RDom, Range, infer_range
from .split import Split
......@@ -17,7 +17,7 @@ class Range(object):
self.extent = _expr_util.simplify(end - begin)
def is_value(self):
return isinstance(self.extent, _expr.ConstExpr) and self.extend.value == 1
return isinstance(self.extent, _expr.ConstExpr) and self.extent.value == 1
def __str__(self):
return "(%s, %s)" % (
......
......@@ -6,7 +6,7 @@ constant_canonical_key = '__constant__'
def canonical_to_expr(c):
elements = []
for k, v in sorted(c.items()):
if k == constant_canonical_key:
if k == constant_canonical_key and v != 0:
elements.append(_expr.const(v))
elif v == 0:
continue
......@@ -87,7 +87,7 @@ class DivOp(BinaryOp):
if isinstance(erhs, _expr.ConstExpr):
lhs = lhs.copy()
for k, v in lhs.items():
lhs[k] /= erhs.value
lhs[k] /= float(erhs.value)
return lhs
elhs = canonical_to_expr(lhs)
return {elhs / erhs: 1}
......
from __future__ import absolute_import as _abs
from . import expr as _expr
from . import domain as _dom
from . import tensor as _tensor
class Split(object):
def __init__(self, dim, factor):
self.dim = dim
self.factor = factor
self.loop_index = _expr.Var('loop_index_%d_' % dim)
def infer_inner_domain(self, domain):
if isinstance(domain, _dom.RDom):
domain = domain.domain
assert self.dim < len(domain)
inner_domain = domain[:]
dim_out_range = domain[self.dim]
dim_inner_begin = dim_out_range.begin + self.loop_index * self.factor
inner_domain[self.dim] = _dom.Range(dim_inner_begin, dim_inner_begin + self.factor)
return inner_domain
......@@ -25,7 +25,6 @@ class Tensor(object):
self.name = name if name else "TensorObj"
self.inputs = None
self.rdom = None
def __call__(self, *indices):
if len(indices) != self.ndim:
......
......@@ -26,6 +26,6 @@ def test_simplify():
assert tvm.format_str(tvm.simplify(e4)) == '0'
if __name__ == "__main__":
test_simplify()
test_basic()
test_bind()
test_simplify()
import tvm
def test_split_dom_infer():
A = tvm.Tensor(2, name='A')
rd = tvm.RDom(tvm.Range(A.shape[1]))
split1 = tvm.Split(0, 64)
split2 = tvm.Split(1, 64)
split3 = tvm.Split(0, 8)
dom = [tvm.Range(A.shape[0]), tvm.Range(A.shape[1])]
dom1 = split1.infer_inner_domain(dom)
dom2 = split2.infer_inner_domain(dom1)
dom3 = split3.infer_inner_domain(dom2)
dom4 = split3.infer_inner_domain(rd)
i1 = split1.loop_index.name
i2 = split2.loop_index.name
i3 = split3.loop_index.name
assert str(dom1) == "[((%s * 64), ((%s * 64) + 64)), (0, A_shape_1_0)]" % (i1, i1)
assert str(dom2) == "[((%s * 64), ((%s * 64) + 64)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i1, i2, i2)
assert str(dom3) == "[(((%s * 64) + (%s * 8)), (((%s * 64) + (%s * 8)) + 8)), ((%s * 64), ((%s * 64) + 64))]" % (i1, i3, i1, i3, i2, i2)
assert str(dom4) == "[((%s * 8), ((%s * 8) + 8))]" % (i3, i3)
if __name__ == "__main__":
test_split_dom_infer()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment