Unverified Commit d1e1ac49 by Tianqi Chen Committed by GitHub

[REFACTOR][PY] Establish tvm.arith (#4904)

parent 38d1dd24
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integer bound analysis, simplification and pattern detection."""
from .int_set import IntSet, IntervalSet
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.arith"""
import tvm._ffi
tvm._ffi._init_api("arith", __name__)
...@@ -17,34 +17,7 @@ ...@@ -17,34 +17,7 @@
"""Arithmetic data structure and utility""" """Arithmetic data structure and utility"""
import tvm._ffi import tvm._ffi
from tvm.runtime import Object from tvm.runtime import Object
from . import _ffi_api
class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _IntSetIsNothing(self)
def is_everything(self):
"""Whether the set represent everything"""
return _IntSetIsEverything(self)
@tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : Expr
The minimum value in the interval.
max_value : Expr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_IntervalSet, min_value, max_value)
@tvm._ffi.register_object("arith.ModularSet") @tvm._ffi.register_object("arith.ModularSet")
...@@ -52,7 +25,7 @@ class ModularSet(Object): ...@@ -52,7 +25,7 @@ class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """ """Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base): def __init__(self, coeff, base):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make_ModularSet, coeff, base) _ffi_api.ModularSet, coeff, base)
@tvm._ffi.register_object("arith.ConstIntBound") @tvm._ffi.register_object("arith.ConstIntBound")
...@@ -72,7 +45,7 @@ class ConstIntBound(Object): ...@@ -72,7 +45,7 @@ class ConstIntBound(Object):
def __init__(self, min_value, max_value): def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make_ConstIntBound, min_value, max_value) _ffi_api.ConstIntBound, min_value, max_value)
class ConstraintScope: class ConstraintScope:
...@@ -105,11 +78,12 @@ class Analyzer: ...@@ -105,11 +78,12 @@ class Analyzer:
be used to perform various symbolic integer analysis. be used to perform various symbolic integer analysis.
""" """
def __init__(self): def __init__(self):
_mod = _CreateAnalyzer() _mod = _ffi_api.CreateAnalyzer()
self._const_int_bound = _mod("const_int_bound") self._const_int_bound = _mod("const_int_bound")
self._const_int_bound_update = _mod("const_int_bound_update") self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind") self._bind = _mod("bind")
self._modular_set = _mod("modular_set") self._modular_set = _mod("modular_set")
self._simplify = _mod("Simplify")
self._rewrite_simplify = _mod("rewrite_simplify") self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify") self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set") self._int_set = _mod("int_set")
...@@ -120,7 +94,7 @@ class Analyzer: ...@@ -120,7 +94,7 @@ class Analyzer:
Parameters Parameters
---------- ----------
expr : tvm.Expr expr : PrimExpr
The expression. The expression.
Returns Returns
...@@ -135,7 +109,7 @@ class Analyzer: ...@@ -135,7 +109,7 @@ class Analyzer:
Parameters Parameters
---------- ----------
expr : tvm.Expr expr : PrimExpr
The expression. The expression.
Returns Returns
...@@ -145,12 +119,27 @@ class Analyzer: ...@@ -145,12 +119,27 @@ class Analyzer:
""" """
return self._modular_set(expr) return self._modular_set(expr)
def simplify(self, expr):
"""Simplify expression via both rewrite and canonicalization.
Parameters
----------
expr : PrimExpr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._simplify(expr)
def rewrite_simplify(self, expr): def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules. """Simplify expression via rewriting rules.
Parameters Parameters
---------- ----------
expr : tvm.Expr expr : PrimExpr
The expression. The expression.
Returns Returns
...@@ -165,7 +154,7 @@ class Analyzer: ...@@ -165,7 +154,7 @@ class Analyzer:
Parameters Parameters
---------- ----------
expr : tvm.Expr expr : PrimExpr
The expression. The expression.
Returns Returns
...@@ -180,7 +169,7 @@ class Analyzer: ...@@ -180,7 +169,7 @@ class Analyzer:
Parameters Parameters
---------- ----------
expr : tvm.Expr expr : PrimExpr
The expression. The expression.
dom_map : Dict[Var, tvm.arith.IntSet] dom_map : Dict[Var, tvm.arith.IntSet]
...@@ -198,10 +187,10 @@ class Analyzer: ...@@ -198,10 +187,10 @@ class Analyzer:
Parameters Parameters
---------- ----------
var : tvm.Var var : tvm.tir.Var
The variable. The variable.
expr : tvm.Expr expr : PrimExpr
The expression. The expression.
""" """
return self._bind(var, expr) return self._bind(var, expr)
...@@ -211,7 +200,7 @@ class Analyzer: ...@@ -211,7 +200,7 @@ class Analyzer:
Parameters Parameters
---------- ----------
constraint : tvm.Expr constraint : PrimExpr
The constraint expression. The constraint expression.
returns returns
...@@ -240,7 +229,7 @@ class Analyzer: ...@@ -240,7 +229,7 @@ class Analyzer:
Parameters Parameters
---------- ----------
var : tvm.Var var : tvm.tir.Var
The variable. The variable.
info : tvm.Object info : tvm.Object
...@@ -254,6 +243,3 @@ class Analyzer: ...@@ -254,6 +243,3 @@ class Analyzer:
else: else:
raise TypeError( raise TypeError(
"Do not know how to handle type {}".format(type(info))) "Do not know how to handle type {}".format(type(info)))
tvm._ffi._init_api("tvm.arith")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Bound deduction."""
from . import _ffi_api
def deduce_bound(var, cond, hint_map, relax_map):
"""Deduce the bound of the target variable in the cond.
Parameters
----------
var : Var
The target variable to be deduced.
cond : PrimExpr
The condition
hint_map : Map[Var, IntSet]
Domain of variables used to help deduction.
relax_map : Map[Var, IntSet]
The fomain of the variables to be relaxed
using the provided domain.
"""
return _ffi_api.DeduceBound(var, cond, hint_map, relax_map)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integer set."""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _ffi_api.IntSetIsNothing(self)
def is_everything(self):
"""Whether the set represent everything"""
return _ffi_api.IntSetIsEverything(self)
@staticmethod
def vector(vec):
"""Construct an integer set that covers the vector expr
Parameters
----------
vec : PrimExpr
The vector expression.
Returns
-------
rset : IntSet
The result set.
"""
return _ffi_api.intset_vector(vec)
@staticmethod
def single_point(point):
"""Construct a point set.
Parameters
----------
point : PrimExpr
The vector expression.
Returns
-------
rset : IntSet
The result set.
"""
return _ffi_api.intset_single_point(point)
@tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : PrimExpr
The minimum value in the interval.
max_value : PrimExpr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_ffi_api.IntervalSet, min_value, max_value)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Detect common patterns."""
from . import _ffi_api
def detect_linear_equation(expr, var_list):
"""Match `expr = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]`
Where coeff[i] and base are invariant of var[j] for all i and j.
Parameters
----------
expr : PrimExpr
The expression to be matched.
var_list : List[tvm.tir.Var]
A list of variables.
Returns
-------
coeff : List[PrimExpr]
A list of co-efficients if the match is successful.
An empty list if the match failed.
"""
return _ffi_api.DetectLinearEquation(expr, var_list)
def detect_clip_bound(expr, var_list):
""" Detect if expression corresponds to clip bound of the vars
Parameters
----------
expr : PrimExpr
The expression to be matched.
var_list : List[tvm.tir.Var]
A list of variables.
Returns
-------
coeff : List[PrimExpr]
`concat([min_value[i], max_value[i]] for i, v in enumerate(var_list))`
An empty list if the match failed.
"""
return _ffi_api.DetectClipBound(expr, var_list)
...@@ -64,33 +64,33 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound") ...@@ -64,33 +64,33 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DomainTouched") TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched); .set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("arith._IntervalSetGetMin") TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
.set_body_method(&IntSet::min); .set_body_method(&IntSet::min);
TVM_REGISTER_GLOBAL("arith._IntervalSetGetMax") TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
.set_body_method(&IntSet::max); .set_body_method(&IntSet::max);
TVM_REGISTER_GLOBAL("arith._IntSetIsNothing") TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
.set_body_method(&IntSet::is_nothing); .set_body_method(&IntSet::is_nothing);
TVM_REGISTER_GLOBAL("arith._IntSetIsEverything") TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
.set_body_method(&IntSet::is_everything); .set_body_method(&IntSet::is_everything);
ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value); return ConstIntBound(min_value, max_value);
} }
TVM_REGISTER_GLOBAL("arith._make_ConstIntBound") TVM_REGISTER_GLOBAL("arith.ConstIntBound")
.set_body_typed(MakeConstIntBound); .set_body_typed(MakeConstIntBound);
ModularSet MakeModularSet(int64_t coeff, int64_t base) { ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base); return ModularSet(coeff, base);
} }
TVM_REGISTER_GLOBAL("arith._make_ModularSet") TVM_REGISTER_GLOBAL("arith.ModularSet")
.set_body_typed(MakeModularSet); .set_body_typed(MakeModularSet);
TVM_REGISTER_GLOBAL("arith._CreateAnalyzer") TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::TypedPackedFunc; using runtime::TypedPackedFunc;
...@@ -108,6 +108,10 @@ TVM_REGISTER_GLOBAL("arith._CreateAnalyzer") ...@@ -108,6 +108,10 @@ TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->const_int_bound.Update(args[0], args[1], args[2]); self->const_int_bound.Update(args[0], args[1], args[2]);
}); });
} else if (name == "Simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->Simplify(args[0]);
});
} else if (name == "rewrite_simplify") { } else if (name == "rewrite_simplify") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->rewrite_simplify(args[0]); *ret = self->rewrite_simplify(args[0]);
......
...@@ -54,7 +54,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { ...@@ -54,7 +54,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) {
return IntervalSet(min_value, max_value); return IntervalSet(min_value, max_value);
} }
TVM_REGISTER_GLOBAL("arith._make_IntervalSet") TVM_REGISTER_GLOBAL("arith.IntervalSet")
.set_body_typed(MakeIntervalSet); .set_body_typed(MakeIntervalSet);
......
...@@ -38,90 +38,90 @@ def test_deduce(): ...@@ -38,90 +38,90 @@ def test_deduce():
fdiv = tvm.floordiv fdiv = tvm.floordiv
e0 = (-b)*a+c-d e0 = (-b)*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = fdiv(d - c, b*-1) ans0 = fdiv(d - c, b*-1)
assert_expr_equal(res0.max_value, ans0) assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs # expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res0.max_value, ans0) assert_expr_equal(res0.max_value, ans0)
e0 = d*a+c-d e0 = d*a+c-d
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
ans0 = fdiv(d-c, d) ans0 = fdiv(d-c, d)
assert_expr_equal(res0.max_value, ans0) assert_expr_equal(res0.max_value, ans0)
# expression containing variable a is on rhs # expression containing variable a is on rhs
res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res0.max_value, ans0) assert_expr_equal(res0.max_value, ans0)
e1 = (a*4+b < c) e1 = (a*4+b < c)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
ans1 = fdiv(c-1-b, 4) ans1 = fdiv(c-1-b, 4)
assert_expr_equal(res1.max_value, ans1) assert_expr_equal(res1.max_value, ans1)
# expression containing variable a is on rhs # expression containing variable a is on rhs
e1 = (c > a*4+b) e1 = (c > a*4+b)
res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res1.max_value, ans1) assert_expr_equal(res1.max_value, ans1)
e2 = (tvm.max(5, a * 4) < 0) e2 = (tvm.max(5, a * 4) < 0)
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf" assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf" assert str(res2.min_value) == "pos_inf"
# expression containing variable a is on rhs # expression containing variable a is on rhs
e2 = (zero < tvm.max(5, a * 4)) e2 = (zero < tvm.max(5, a * 4))
res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
assert str(res2.max_value) == "neg_inf" assert str(res2.max_value) == "neg_inf"
assert str(res2.min_value) == "pos_inf" assert str(res2.min_value) == "pos_inf"
e3 = (-b)+a*c-d e3 = (-b)+a*c-d
res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
ans3 = fdiv(2,c)+1 ans3 = fdiv(2,c)+1
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
# tests for `EQ` op # tests for `EQ` op
res4 = tvm.arith.DeduceBound(a, a == b, {}, {}) res4 = tvm.arith.deduce_bound(a, a == b, {}, {})
assert_expr_equal(res4.max_value, b) assert_expr_equal(res4.max_value, b)
assert_expr_equal(res4.min_value, b) assert_expr_equal(res4.min_value, b)
# Unsatisfiable `EQ`, variable as one of the Operand # Unsatisfiable `EQ`, variable as one of the Operand
res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s}) res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s})
assert str(res5.max_value) == "neg_inf" assert str(res5.max_value) == "neg_inf"
assert str(res5.min_value) == "pos_inf" assert str(res5.min_value) == "pos_inf"
# variable `a` on the RHS side # variable `a` on the RHS side
res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {}) res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {})
assert_expr_equal(res6.max_value, 10) assert_expr_equal(res6.max_value, 10)
assert_expr_equal(res6.min_value, 10) assert_expr_equal(res6.min_value, 10)
# Add, Sub in `EQ` # Add, Sub in `EQ`
e4 = ((a - c) == (b + d)) e4 = ((a - c) == (b + d))
ans4 = (b + d + c) ans4 = (b + d + c)
res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {}) res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
assert_expr_equal(res7.max_value, ans4) assert_expr_equal(res7.max_value, ans4)
assert_expr_equal(res7.min_value, ans4) assert_expr_equal(res7.min_value, ans4)
# Satisfiable Mul in `EQ` with negative sign # Satisfiable Mul in `EQ` with negative sign
res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {}) res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {})
assert_expr_equal(res8.max_value, -2) assert_expr_equal(res8.max_value, -2)
assert_expr_equal(res8.min_value, -2) assert_expr_equal(res8.min_value, -2)
# Unsatisfiable Mul in `EQ` # Unsatisfiable Mul in `EQ`
e5 = (4 * a == b) e5 = (4 * a == b)
res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {}) res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {})
assert str(res9.max_value) == "neg_inf" assert str(res9.max_value) == "neg_inf"
assert str(res9.min_value) == "pos_inf" assert str(res9.min_value) == "pos_inf"
# Unsatisfiable Mul in `EQ` # Unsatisfiable Mul in `EQ`
res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0) res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0)
assert str(res10.max_value) == "neg_inf" assert str(res10.max_value) == "neg_inf"
assert str(res10.min_value) == "pos_inf" assert str(res10.min_value) == "pos_inf"
...@@ -137,15 +137,15 @@ def test_check(): ...@@ -137,15 +137,15 @@ def test_check():
d_s = tvm.arith.IntervalSet(-3, -1) d_s = tvm.arith.IntervalSet(-3, -1)
# no compare operator # no compare operator
res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) res1 = tvm.arith.deduce_bound(a, a+b, {b: b_s}, {})
assert res1.is_nothing() assert res1.is_nothing()
# multiple compare operators # multiple compare operators
res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) res2 = tvm.arith.deduce_bound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
assert res2.is_nothing() assert res2.is_nothing()
# multiple target variable # multiple target variable
res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) res2 = tvm.arith.deduce_bound(a, a*2-a>b, {b: b_s}, {})
assert res2.is_nothing() assert res2.is_nothing()
def test_deduce_basic(): def test_deduce_basic():
...@@ -155,21 +155,21 @@ def test_deduce_basic(): ...@@ -155,21 +155,21 @@ def test_deduce_basic():
b_s = tvm.arith.IntervalSet(a1, a2) b_s = tvm.arith.IntervalSet(a1, a2)
e0 = b + a*coff + 3 e0 = b + a*coff + 3
res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
# expression containing variable a is on rhs # expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
# expression containing variable a is on rhs # expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s})
[x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
...@@ -187,21 +187,21 @@ def test_deduce_complex(): ...@@ -187,21 +187,21 @@ def test_deduce_complex():
b_s = tvm.arith.IntervalSet(a1, a2) b_s = tvm.arith.IntervalSet(a1, a2)
e0 = (b*3 + a* coff) * 4 e0 = (b*3 + a* coff) * 4
res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
# expression containing variable a is on rhs # expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
# expression containing variable a is on rhs # expression containing variable a is on rhs
res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) res1 = tvm.arith.deduce_bound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
[t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
......
...@@ -20,14 +20,14 @@ def test_basic(): ...@@ -20,14 +20,14 @@ def test_basic():
a = tvm.var("a") a = tvm.var("a")
b = tvm.var("b") b = tvm.var("b")
c = tvm.var("c") c = tvm.var("c")
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6, m = tvm.arith.detect_clip_bound(tvm.all(a * 1 < b * 6,
a - 1 > 0), [a]) a - 1 > 0), [a])
assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0 assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
assert m[0].value == 2 assert m[0].value == 2
m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6, m = tvm.arith.detect_clip_bound(tvm.all(a * 1 < b * 6,
a - 1 > 0), [a, b]) a - 1 > 0), [a, b])
assert len(m) == 0 assert len(m) == 0
m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20, m = tvm.arith.detect_clip_bound(tvm.all(a + 10 * c <= 20,
b - 1 > 0), [a, b]) b - 1 > 0), [a, b])
assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0 assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
assert tvm.ir_pass.Simplify(m[2] - 2).value == 0 assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
......
...@@ -19,50 +19,50 @@ import tvm ...@@ -19,50 +19,50 @@ import tvm
def test_basic(): def test_basic():
a = tvm.var("a") a = tvm.var("a")
b = tvm.var("b") b = tvm.var("b")
m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a]) m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a])
assert m[0].value == 4 assert m[0].value == 4
assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0 assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0
m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a]) m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a])
assert len(m) == 0 assert len(m) == 0
m = tvm.arith.DetectLinearEquation(a * 4 + (a+1) + b * 6 + 7, [a]) m = tvm.arith.detect_linear_equation(a * 4 + (a+1) + b * 6 + 7, [a])
assert m[0].value == 5 assert m[0].value == 5
assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0 assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0
m = tvm.arith.DetectLinearEquation(a * b + 7, [a]) m = tvm.arith.detect_linear_equation(a * b + 7, [a])
assert m[0] == b assert m[0] == b
m = tvm.arith.DetectLinearEquation(b * 7, [a]) m = tvm.arith.detect_linear_equation(b * 7, [a])
assert m[0].value == 0 assert m[0].value == 0
m = tvm.arith.DetectLinearEquation(b * 7, []) m = tvm.arith.detect_linear_equation(b * 7, [])
assert len(m) == 1 assert len(m) == 1
assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0 assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
def test_multivariate(): def test_multivariate():
v = [tvm.var("v%d" % i) for i in range(4)] v = [tvm.var("v%d" % i) for i in range(4)]
b = tvm.var("b") b = tvm.var("b")
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5)) assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5))
assert(m[1].value == 8) assert(m[1].value == 8)
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
assert(len(m) == 0) assert(len(m) == 0)
m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
assert(len(m) == 0) assert(len(m) == 0)
m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v) m = tvm.arith.detect_linear_equation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
assert(m[1].value == 16) assert(m[1].value == 16)
assert(m[2].value == 2) assert(m[2].value == 2)
assert(m[len(m)-1].value == 2) assert(m[len(m)-1].value == 2)
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]]) m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]])
assert(m[0].value == 0) assert(m[0].value == 0)
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), []) m = tvm.arith.detect_linear_equation((v[0] - v[1]), [])
assert(len(m) == 1) assert(len(m) == 1)
assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0) assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
......
...@@ -35,19 +35,19 @@ def test_domain_touched(): ...@@ -35,19 +35,19 @@ def test_domain_touched():
) )
) )
) )
a_domain_r = tvm.arith.DomainTouched(ir, a, True, False) a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
assert a_domain_r[0].min.value == -1 assert a_domain_r[0].min.value == -1
assert a_domain_r[0].extent.value == 100 assert a_domain_r[0].extent.value == 100
assert a_domain_r[1].min.value == -1 assert a_domain_r[1].min.value == -1
assert a_domain_r[1].extent.name == 'm' assert a_domain_r[1].extent.name == 'm'
a_domain_w = tvm.arith.DomainTouched(ir, a, False, True) a_domain_w = tvm.arith._ffi_api.DomainTouched(ir, a, False, True)
assert a_domain_w[0].min.value == 0 assert a_domain_w[0].min.value == 0
assert a_domain_w[0].extent.value == 100 assert a_domain_w[0].extent.value == 100
assert a_domain_w[1].min.value == 0 assert a_domain_w[1].min.value == 0
assert a_domain_w[1].extent.name == 'm' assert a_domain_w[1].extent.name == 'm'
a_domain_rw= tvm.arith.DomainTouched(ir, a, True, True) a_domain_rw= tvm.arith._ffi_api.DomainTouched(ir, a, True, True)
assert a_domain_rw[0].min.value == -1 assert a_domain_rw[0].min.value == -1
assert a_domain_rw[0].extent.value == 101 assert a_domain_rw[0].extent.value == 101
assert a_domain_rw[1].min.value == -1 assert a_domain_rw[1].min.value == -1
...@@ -55,17 +55,16 @@ def test_domain_touched(): ...@@ -55,17 +55,16 @@ def test_domain_touched():
assert a_domain_rw[1].extent.a.name == 'm' assert a_domain_rw[1].extent.a.name == 'm'
assert a_domain_rw[1].extent.b.value == 1 assert a_domain_rw[1].extent.b.value == 1
b_domain_r = tvm.arith.DomainTouched(ir, b, True, False) b_domain_r = tvm.arith._ffi_api.DomainTouched(ir, b, True, False)
assert b_domain_r assert b_domain_r
assert b_domain_r[0].min.value == -1 assert b_domain_r[0].min.value == -1
assert b_domain_r[0].extent.value == 100 assert b_domain_r[0].extent.value == 100
assert b_domain_r[1].min.value == 1 assert b_domain_r[1].min.value == 1
assert b_domain_r[1].extent.name == 'm' assert b_domain_r[1].extent.name == 'm'
b_domain_w = tvm.arith.DomainTouched(ir, b, False, True) b_domain_w = tvm.arith._ffi_api.DomainTouched(ir, b, False, True)
assert isinstance(b_domain_w, tvm.container.Array) assert isinstance(b_domain_w, tvm.container.Array)
assert len(b_domain_w) == 0 assert len(b_domain_w) == 0
if __name__ == "__main__": if __name__ == "__main__":
test_domain_touched() test_domain_touched()
...@@ -36,12 +36,16 @@ def test_basic(): ...@@ -36,12 +36,16 @@ def test_basic():
assert s.min_value.value == 2 assert s.min_value.value == 2
assert s.max_value.value == 3 assert s.max_value.value == 3
s = tvm.arith.IntSet.single_point(2)
assert s.min_value.value == 2
assert s.max_value.value == 2
def test_vector(): def test_vector():
base = 10 base = 10
stride = 3 stride = 3
lanes = 2 lanes = 2
s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes)) s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes))
assert s.min_value.value == base assert s.min_value.value == base
assert s.max_value.value == base + stride * lanes - 1 assert s.max_value.value == base + stride * lanes - 1
......
...@@ -76,7 +76,7 @@ def fold_uop_loop(stmt_in): ...@@ -76,7 +76,7 @@ def fold_uop_loop(stmt_in):
args = [] args = []
args += op.args[:base_args] args += op.args[:base_args]
for i in range(3): for i in range(3):
m = tvm.arith.DetectLinearEquation( m = tvm.arith.detect_linear_equation(
op.args[i + base_args], [loop_var]) op.args[i + base_args], [loop_var])
if not m: if not m:
fail[0] = True fail[0] = True
...@@ -867,25 +867,25 @@ def inject_alu_intrin(stmt_in): ...@@ -867,25 +867,25 @@ def inject_alu_intrin(stmt_in):
type(loop_body.value), str(loop_body.value), str(stmt))) type(loop_body.value), str(loop_body.value), str(stmt)))
# Derive array index coefficients # Derive array index coefficients
dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices) dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
# Check if lhs/rhs is immediate # Check if lhs/rhs is immediate
use_imm = False use_imm = False
imm_val = None imm_val = None
if isinstance(rhs, tvm.tir.IntImm): if isinstance(rhs, tvm.tir.IntImm):
assert lhs.buffer_var.same_as(dst_var) assert lhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
use_imm = True use_imm = True
imm_val = rhs imm_val = rhs
if isinstance(lhs, tvm.tir.IntImm): if isinstance(lhs, tvm.tir.IntImm):
assert rhs.buffer_var.same_as(dst_var) assert rhs.buffer_var.same_as(dst_var)
src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
use_imm = True use_imm = True
imm_val = lhs imm_val = lhs
if imm_val is None: if imm_val is None:
imm_val = 0 imm_val = 0
assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices) src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices) src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
# Determine which side has the same coefficients # Determine which side has the same coefficients
lhs_equal = True lhs_equal = True
rhs_equal = True rhs_equal = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment