Unverified Commit e21f2682 by Yizhi Liu Committed by GitHub

[Arith] linear system and equation solver (#5171)

* [arith] linear system and equation solver

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>

* avoid constructing analyzer every time

* generate random test cases and address comments

Co-authored-by: Sergei Grechanik <sergei.grechanik@gmail.com>

* rename linear_system to int_constraints

* add comments and use random seed

* message for reporting failure with seed

* add SEqualReduce to IntConstraints; allow variables & ranges to be None

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
Co-authored-by: Sergei Grechanik <sergei.grechanik@gmail.com>
parent b236565e
......@@ -424,6 +424,12 @@ class Analyzer {
*/
void Bind(const Var& var, const Range& range);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
*/
void Bind(const Map<Var, Range>& variables);
/*!
* \brief Whether can we prove expr >= val.
* Non-negative proof is very useful in integer analysis
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/int_solver.h
* \brief integer constraints data structures and solvers
*/
#ifndef TVM_ARITH_INT_SOLVER_H_
#define TVM_ARITH_INT_SOLVER_H_
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace arith {
using tir::Var;
using tir::VarNode;
using tir::IterVar;
/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
* \sa LinearSystem
*/
class IntConstraintsNode : public Object {
public:
// e.g., \alpha, \beta, must be integers
Array<Var> variables;
// e.g., 1 <= \alpha <= N, etc.
// it is absolutely ok to include ranges for parameters
// (variables that are not in this->variables) in this map
Map<Var, Range> ranges;
// linear equalities or inequalities
// e.g., A \alpha = \beta or A \alpha <= \beta
Array<PrimExpr> relations;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("variables", &variables);
v->Visit("ranges", &ranges);
v->Visit("relations", &relations);
}
bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
return
equal(variables, other->variables) &&
equal(ranges, other->ranges) &&
equal(relations, other->relations);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(variables);
hash_reduce(ranges);
hash_reduce(relations);
}
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
};
/*!
* \brief Managed reference to IntConstraintsNode.
* \sa IntConstraintsNode
*/
class IntConstraints : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param variables The variables in the constraints, must be integers.
* \param ranges The ranges of the variables.
* \param relations The linear relations between the variables
* (either equations or inequalities)
*/
TVM_DLL IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
};
/*!
* \brief We can have different set of variables to represent the same constraints.
* For example, the following two systems are equivalent,
* {a + b = 0 | a >= 0, b >= 0} and
* {m - n = 0 | m >= 0, n <= 0}
* This data structure represents the transformation
* between two equivalent linear systems.
* In the above example,
* src : {a + b = 0 | a >= 0, b >= 0}
* dst : {m - n = 0 | m >= 0, n <= 0}
* src_to_dst : {a -> m, b -> -n}
* dst_to_src : {m -> a, n -> -b}
* \sa IntConstraintsTransform
*/
class IntConstraintsTransformNode : public Object {
public:
IntConstraints src;
IntConstraints dst;
Map<Var, PrimExpr> src_to_dst;
Map<Var, PrimExpr> dst_to_src;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("src", &src);
v->Visit("dst", &dst);
v->Visit("src_to_dst", &src_to_dst);
v->Visit("dst_to_src", &dst_to_src);
}
bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
return
equal(src, other->src) &&
equal(dst, other->dst) &&
equal(src_to_dst, other->src_to_dst) &&
equal(dst_to_src, other->dst_to_src);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_to_dst);
hash_reduce(dst_to_src);
}
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
};
/*!
* \brief Managed reference to IntConstraintsTransformNode.
* \sa IntConstraintsTransformNode
*/
class IntConstraintsTransform : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
* \param dst integer constraints equivalent to the source,
* e.g., {m - n = 0 | m >= 0, n <= 0}
* \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
* e.g., {a -> m, b -> -n}
* \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
* e.g., {m -> a, n -> -b}
*/
TVM_DLL IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};
/*!
* \brief Obtain Smith Normal Form of linear equation A x = y.
* Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
* in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
* NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
* s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
* TODO(yzhliu): From sergei-grechanik:
* computing the proper Smith normal form may improve stability of automatic differentiation
* (generating the same gradient code for slightly different but equivalent input code
* U_{mxm} and V_{nxn} are invertible matrices.
* This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
* \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
* \param S the original A_{mxn}, it will be modified to S_{mxn}
* \param V an identity matrix, it will be modified to V_{nxn}
* \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
* \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1}
*/
void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
std::vector<std::vector<int64_t>> *V,
std::vector<PrimExpr>* x,
std::vector<PrimExpr> *y);
/*!
* \brief Solve linear equations.
* \param system_to_solve the variables to solve, their ranges, and a list of equations.
* \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank),
* or no variable (if \p system_to_solve is of full rank),
* or an empty linear system (if \p system_to_solve is unsolvable).
* It also provides the ranges of the variables in the new system,
* as well as inequalities inferred from the \p system_to_solve.
* You can get the mapping from the original variables to the solution via ret->src_to_dst.
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/util.h
* \brief Utils for arithmetic analysis.
*/
#ifndef TVM_ARITH_UTIL_H_
#define TVM_ARITH_UTIL_H_
#include <cstdint>
#include <tuple>
namespace tvm {
/*! \brief namespace of arithmetic analysis. */
namespace arith {
/*!
* \brief Calculate the extended greatest common divisor for two values.
* See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
* \param a an integer number
* \param b an integer number
* \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div
*/
std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_UTIL_H_
......@@ -20,3 +20,4 @@ from .int_set import IntSet, IntervalSet
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""integer constraints data structures and solvers"""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
the relations between them (either equations or inequalities)
Parameters
----------
variables : List[tvm.tir.Var]
The variables in the constraints. Must be integers
ranges : Map[tvm.tir.Var, tvm.ir.Range]
The ranges of the variables.
relations : List[tvm.ir.PrimExpr]
The relations between the variables (either equations or inequalities)
"""
def __init__(self, variables, ranges, relations):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraints, variables, ranges, relations)
@tvm._ffi.register_object("arith.IntConstraintsTransform")
class IntConstraintsTransform(Object):
"""We can have different set of variables to represent the same integer constraints.
For example, the following two constrains are equivalent,
{a + b = 0 | a >= 0, b >= 0} and
{m - n = 0 | m >= 0, n <= 0}
This data structure represents the transformation
between two equivalent integer constraints.
In the above example,
src : {a + b = 0 | a >= 0, b >= 0}
dst : {m - n = 0 | m >= 0, n <= 0}
src_to_dst : {a -> m, b -> -n}
dst_to_src : {m -> a, n -> -b}
Parameters
----------
src : arith.IntConstraints
source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
dst : arith.IntConstraints
integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the src to the variables in the dst,
e.g., {a -> m, b -> -n}
dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the dst to the variables in the src,
e.g., {m -> a, n -> -b}
"""
def __init__(self, src, dst, src_to_dst, dst_to_src):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)
def solve_linear_equations(equations, variables=None, ranges=None):
"""Solve linear equations.
Parameters
----------
equations: List[tvm.ir.PrimExpr] or IntConstraints
The equations of the variables
variables : Optional[List[tvm.tir.Var]]
The variables in the system.
ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
The ranges of the variables.
Returns
-------
int_constraints_transform : IntConstraintsTransform
New integer constraints, with less variables (if the problem is NOT of full rank),
or no variable (if the problem is of full rank),
or an empty integer constraints (if the problem is unsolvable).
It also provides the ranges of the variables in the new system,
as well as inequalities inferred from the problem.
You can get the mapping from the original variables to the solution via
int_constraints_transform.src_to_dst.
"""
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)
......@@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) {
// skip rewrite simplify
}
void Analyzer::Bind(const Map<Var, Range>& variables) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second);
}
}
void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_constraints.cc
* \brief The integer constraints data structures.
*/
#include <tvm/arith/int_solver.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/runtime/registry.h>
#include <utility>
#include <algorithm>
#include <unordered_map>
namespace tvm {
namespace arith {
IntConstraints::IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations) {
ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
if (!variables.defined()) {
variables = Array<Var>();
}
if (!ranges.defined()) {
ranges = Map<Var, Range>();
}
CHECK(relations.defined());
for (const auto& var : variables) {
CHECK(var.dtype().is_int() || var.dtype().is_uint())
<< "Variables in IntConstraints must be integers";
}
node->variables = std::move(variables);
node->ranges = std::move(ranges);
node->relations = std::move(relations);
data_ = std::move(node);
}
TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntConstraintsNode*>(node.get());
p->stream << "IntConstraints("
<< op->variables
<< ", " << op->ranges
<< ", " << op->relations
<< ")";
});
IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src) {
ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
node->src = std::move(src);
node->dst = std::move(dst);
node->src_to_dst = std::move(src_to_dst);
node->dst_to_src = std::move(dst_to_src);
data_ = std::move(node);
}
TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
p->stream << "IntConstraintsTransform("
<< "\n\t" << op->src
<< "\n\t" << op->dst
<< "\n\t" << op->src_to_dst
<< "\n\t" << op->dst_to_src
<< "\n)";
});
} // namespace arith
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file util.cc
* \brief The utils for arithmetic analysis.
*/
#include <tvm/arith/util.h>
#include <dmlc/logging.h>
namespace tvm {
namespace arith {
std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b) {
int64_t s = 0, old_s = 1;
int64_t t = 1, old_t = 0;
int64_t r = b, old_r = a;
while (r != 0) {
int64_t q = old_r / r;
std::swap(r, old_r);
r -= q * old_r;
std::swap(s, old_s);
s -= q * old_s;
std::swap(t, old_t);
t -= q * old_t;
}
CHECK_EQ(a % old_r, 0);
CHECK_EQ(b % old_r, 0);
CHECK(old_r == old_s*a + old_t*b);
return std::make_tuple(old_r, old_s, old_t);
}
} // namespace arith
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import random
import numpy as np
import sys
import pytest
import tvm
from tvm import te, arith, ir, tir
def run_expr(expr, vranges):
""" Evaluate expr for every value of free variables
given by vranges and return the tensor of results.
TODO(yzhliu): move to utils
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tir.ir_pass.Substitute(expr, vmap)
A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
sch = te.create_schedule(A.op)
mod = tvm.build(sch, [A])
mod(*args)
return args[0].asnumpy()
def check_bruteforce(bool_expr, vranges, cond=None):
""" Check that bool_expr holds given the condition cond
for every value of free variables from vranges.
TODO(yzhliu): move to utils
"""
if cond is not None:
bool_expr = te.any(tir.Not(cond), bool_expr)
res = run_expr(bool_expr, vranges)
if not np.all(res):
indices = list(np.argwhere(res == 0)[0])
counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
counterex = sorted(counterex, key=lambda x: x[0])
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
def check_solution(solution, vranges={}):
"""Check that solution is a bijective transformation"""
def _check_forward(constraints1, constraints2, varmap, backvarmap):
all_vranges = vranges.copy()
all_vranges.update({v: r for v, r in constraints1.ranges.items()})
# Check that the transformation is injective
cond_on_vars = tir.const(1, 'bool')
for v in constraints1.variables:
# variable mapping is consistent
v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
cond_on_vars = te.all(cond_on_vars, v == v_back)
# Also we have to check that the new relations are true when old relations are true
cond_subst = tir.ir_pass.Substitute(
te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
# We have to include relations from vranges too
for v in constraints2.variables:
if v in constraints2.ranges:
r = constraints2.ranges[v]
range_cond = te.all(v >= r.min, v < r.min + r.extent)
range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
cond_subst = te.all(cond_subst, range_cond)
cond_subst = tir.ir_pass.Simplify(cond_subst)
check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
cond=te.all(tir.const(1, 'bool'), *constraints1.relations))
rels = solution.dst.relations
if len(rels) == 1 and ir.structural_equal(rels[0], False):
# not solvable, skip
return
_check_forward(solution.src, solution.dst,
solution.src_to_dst, solution.dst_to_src)
_check_forward(solution.dst, solution.src,
solution.dst_to_src, solution.src_to_dst)
def test_solution_consistency():
seed = random.randrange(sys.maxsize)
print("\nThis test is intentionally non-deterministic, "
"if it fails please report it in github issue together with this seed {}\n".format(seed))
random.seed(seed)
def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
variables = [te.var("x" + str(i)) for i in range(num_vars)]
relations = []
for i in range(num_formulas):
s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
s1 += random.randint(coef[0], coef[1])
s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
s2 += random.randint(coef[0], coef[1])
if random.random() < 0.7:
op = tvm.tir.EQ
else:
# we also make sure it can correctly handle inequalities
op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT])
relations.append(op(s1, s2))
vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables}
solution = arith.solve_linear_equations(relations, variables, vranges)
check_solution(solution)
# leaving some variables as parameters should also be ok
for k in [1, 2]:
if len(variables) > k:
solution = arith.solve_linear_equations(relations, variables[:-k], vranges)
param_ranges = {v: vranges[v] for v in variables[-k:]}
check_solution(solution, param_ranges)
for i in range(2):
_check(num_vars=1, num_formulas=1)
for i in range(2):
_check(num_vars=1, num_formulas=2)
for i in range(2):
_check(num_vars=2, num_formulas=1)
for i in range(2):
_check(num_vars=2, num_formulas=2)
for i in range(2):
_check(num_vars=2, num_formulas=3)
for i in range(3):
_check(num_vars=3, num_formulas=3, coef=(-2, 2))
for i in range(3):
_check(num_vars=3, num_formulas=4, coef=(-2, 2))
for i in range(3):
_check(num_vars=4, num_formulas=3, coef=(-1, 1))
for i in range(3):
_check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4))
for i in range(3):
_check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4))
def test_empty_var_to_solve():
x, y = te.var("x"), te.var("y")
equations = [
tvm.tir.EQ(x + y, 20),
tvm.tir.EQ(x - y, 10),
]
solution = arith.solve_linear_equations(equations)
assert len(solution.src_to_dst) == 0
assert len(solution.dst_to_src) == 0
assert len(solution.src.variables) == 0
assert len(solution.src.ranges) == 0
assert ir.structural_equal(solution.src.relations, equations)
assert ir.structural_equal(solution.src, solution.dst)
def test_unique_solution():
x, y = te.var("x"), te.var("y")
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y, 20),
tvm.tir.EQ(x - y, 10),
], [x, y])
assert list(solution.dst.variables) == []
assert ir.structural_equal(solution.src_to_dst[x], 15)
assert ir.structural_equal(solution.src_to_dst[y], 5)
def test_low_rank():
x, y, z = te.var("x"), te.var("y"), te.var("z")
ranges = {}
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y + z, 15),
tvm.tir.EQ(x + y, 10),
], [x, y, z], ranges)
[n0] = solution.dst.variables
assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
assert ir.structural_equal(solution.src_to_dst[y], -n0)
assert ir.structural_equal(solution.src_to_dst[z], 5)
def test_infer_range():
x, y = te.var("x"), te.var("y")
ranges = {
x: tvm.ir.Range.make_by_min_extent(-5, 10),
y: tvm.ir.Range.make_by_min_extent(0, 10),
}
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y, 0),
], [x, y], ranges)
[n0] = solution.dst.variables
assert ir.structural_equal(solution.src_to_dst[x], n0)
assert ir.structural_equal(solution.src_to_dst[y], -n0)
# inferred from y's range
assert ir.structural_equal(solution.dst.ranges[n0].min, -9)
assert ir.structural_equal(solution.dst.ranges[n0].extent, 10)
# additional inequality is added into the system for x
[ineq] = solution.dst.relations
assert isinstance(ineq, tvm.tir.LE)
assert ir.structural_equal(ineq.a, -5)
assert ir.structural_equal(ineq.b, n0)
def test_ill_formed():
x, y = te.var("x"), te.var("y")
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y, 0),
tvm.tir.EQ(x - y, 0),
tvm.tir.EQ(x, 5),
], [x, y], {})
assert list(solution.dst.variables) == []
[rel] = solution.dst.relations
assert ir.structural_equal(rel, False)
assert len(solution.src_to_dst) == 0
assert len(solution.dst_to_src) == 0
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment