Unverified Commit 2a7aebe5 by Tianqi Chen Committed by GitHub

[ARITH] More recursive rewrite rule, cleanup simplify tests (#3502)

parent eadc4e38
......@@ -24,7 +24,6 @@ You can use make function to build the IR node.
"""
from __future__ import absolute_import as _abs
from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType
def range_by_min_extent(min_value, extent):
......@@ -48,35 +47,6 @@ def range_by_min_extent(min_value, extent):
return _range_by_min_extent(min_value, extent)
def static_cast(dtype, expr):
"""Cast expr to dtype.
If expr is scalar and dtype is a corresponding vector
type, a Broadcast is generated. Otherwise it is a Cast.
Parameters
----------
dtype : str
The target data type.
expr : Expr
The expression to be casted.
Returns
-------
casted : Expr
The casted expression.
"""
target_type = TVMType(dtype)
src_type = TVMType(expr.dtype)
if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits:
if src_type.lanes == target_type.lanes:
return expr
if src_type.lanes == 1 and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
return Cast(dtype, expr)
def node(type_key, **kwargs):
"""Make a new DSL node by its type key and fields
......
......@@ -1194,9 +1194,9 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y);
TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y);
TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1);
}
return ret;
}
......
......@@ -31,11 +31,24 @@
namespace tvm {
namespace arith {
// statement simplifier
using namespace ir;
class StmtSimplifier : public IRMutator {
public:
using IRMutator::Mutate;
Expr Mutate(Expr expr) final {
return analyzer_.Simplify(expr);
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt);
}
Stmt Mutate_(const For* op, const Stmt& s) final {
Var loop_var(op->loop_var.node_);
analyzer_.Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
......@@ -124,28 +137,12 @@ class StmtSimplifier : public IRMutator {
std::unordered_map<const Variable*, Range> var_dom_;
};
class CanonicalStmtSimplifier : public StmtSimplifier {
public:
using StmtSimplifier::Mutate;
Expr Mutate(Expr expr) final {
return analyzer_.canonical_simplify(expr);
}
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer_.Bind(kv.first, kv.second);
}
return Mutate(stmt);
}
};
} // namespace arith
namespace ir {
Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
return arith::StmtSimplifier().Simplify(
stmt, vrange);
}
......@@ -167,7 +164,7 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
}
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return arith::CanonicalStmtSimplifier().CanonicalSimplify(
return arith::StmtSimplifier().Simplify(
stmt, vrange);
}
} // namespace ir
......
......@@ -81,6 +81,10 @@ def test_canonical_mixed():
z = tvm.const(3, "int32")
ck.verify(x / (z*z) - x / (z*z), 0)
ck.verify(x / (z+z) - x / (z+z), 0)
ck.verify(x - 2 < 3, x < 5)
ck.verify(tvm.max(x, 1) - tvm.max(x, 1), 0)
ck.verify(tvm.min(x, 1) - tvm.min(x, 1), 0)
ck.verify(x * x - x * x, 0)
def test_reduce_combiner_simplify():
......@@ -211,6 +215,8 @@ def test_complex_cases():
ck.verify(res3, ((((x*1024) + y)/256) - (y/256)) - (x*4))
if __name__ == "__main__":
test_simplify_if_then_else()
test_div_simplify()
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
def csimplify(z):
return tvm.ir_pass.CanonicalSimplify(
tvm.make.Evaluate(z)).value
def test_simplify():
x = tvm.var('n')
z = x * 4 - x * 2
zz = csimplify(z)
assert zz.b.value == 2
z = (x / 4) * 2 - (x / 4)
zz = csimplify(z)
assert zz.a == x and zz.b.value == 4
z = (x % 4) * 3 + (x % 4)
zz = csimplify(z)
assert zz.b.value == 4
zz = zz.a
assert zz.a == x and zz.b.value == 4
n = tvm.var('n')
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n % 1), tvm.const(0, "int32"))
assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / 1), n)
tvm.ir_pass.CanonicalSimplify(n / (-1))
# This is not true in the current implementation
# assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(n / (-1)),
# tvm.ir_pass.CanonicalSimplify(-n))
def test_simplify_mod():
ib = tvm.ir_builder.create()
n = tvm.var('n')
A = ib.pointer("float32", name="A")
with ib.for_range(0, 10, name="j") as j:
with ib.for_range(0, 16, name="i") as i:
A[i] = A[(j * 32 + i+1) % 16]
body = ib.get()
stmt = tvm.ir_pass.CanonicalSimplify(body)
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
assert index != j
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
assert index == j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
assert index != j
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
assert index == j
def test_simplify_minmax():
x = tvm.var('x')
e1 = tvm.max(x, 1) - tvm.max(x, 1)
e1s = tvm.ir_pass.CanonicalSimplify(e1)
assert e1s.value == 0
e2 = tvm.min(x, 1) - tvm.min(x, 1)
e2s = tvm.ir_pass.CanonicalSimplify(e2)
assert e2s.value == 0
def test_mul():
x = tvm.var('x')
e = x * x - x * x
es = tvm.ir_pass.CanonicalSimplify(e)
assert es.value == 0
def test_modular():
rx = tvm.var("rx")
ry = tvm.var("ry")
y = tvm.var("y")
x = tvm.var("x")
i32_const = lambda x: tvm.const(x, "int32")
vmap = {rx: tvm.Range(i32_const(0), i32_const(3)),
ry: tvm.Range(i32_const(0), i32_const(3)),
y: tvm.Range(i32_const(0), i32_const(2)),
x: tvm.Range(i32_const(0), i32_const(14))}
idx = ry * 16 + rx + y * 16 + x
z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
def test_const_propagation():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.5
assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
if __name__ == "__main__":
test_modular()
test_simplify()
test_mul()
test_simplify_minmax()
test_const_propagation()
test_simplify_mod()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import numpy
from tvm import comm_reducer
from tvm.ir_pass import Simplify, CanonicalSimplify, Equal
def test_simplify():
"""Not yet working, mock design"""
dtype = 'int64'
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), dtype)
i = tvm.var('i')
j = tvm.var('j')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 2, n, 0, 0,
tvm.make.For(j, 0, n, 0, 0,
tvm.make.IfThenElse(
tvm.make.LT(i + 2, n),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i + 4) + 1,
(j + 1) * 4 - 4 * j + i),
None)))
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
def test_basic():
m = tvm.var('m')
ret = tvm.ir_pass.CanonicalSimplify(tvm.make.Evaluate(m-1))
assert str(ret.value) == "(m - 1)"
def test_bound():
m = tvm.var('m')
vrange = tvm.convert({m: tvm.Range(tvm.const(0, "int32"), tvm.const(10, "int32"))})
ret = tvm.ir_pass.Simplify(m % 10, vrange)
assert ret == m
if __name__ == "__main__":
test_bound()
test_basic()
test_simplify()
......@@ -83,7 +83,25 @@ def test_const_fold3():
assert tvm.any(x, true).same_as(true)
assert tvm.any(true, x).same_as(true)
def test_const_fold4():
x1 = tvm.const(4, "int32")
x2 = x1 + 5
assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9
x3 = x2 / 3
assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3
x4 = x3 + 0.55
assert isinstance(x4, tvm.expr.FloatImm) and abs(x4.value - 3.55) < 1e-6
x5 = tvm.ceil(x4)
assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4
x6 = x5.astype('int')
assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4, "x6={}".format(x6)
y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int')
assert isinstance(y, tvm.expr.IntImm) and y.value == 6
if __name__ == "__main__":
test_const_fold()
test_const_fold2()
test_const_fold3()
test_const_fold4()
......@@ -342,4 +342,4 @@ def cast(x, dtype):
if isinstance(x, tvm.tensor.Tensor):
return tvm.compute(
x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE)
return tvm.make.static_cast(dtype, x)
return tvm.make._cast(dtype, x)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment