Unverified Commit e21f2682 by Yizhi Liu Committed by GitHub

[Arith] linear system and equation solver (#5171)

* [arith] linear system and equation solver

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>

* avoid constructing analyzer every time

* generate random test cases and address comments

Co-authored-by: Sergei Grechanik <sergei.grechanik@gmail.com>

* rename linear_system to int_constraints

* add comments and use random seed

* message for reporting failure with seed

* add SEqualReduce to IntConstraints; allow variables & ranges to be None

Co-authored-by: Sergei Grechanik <sergei.grechanik+h@gmail.com>
Co-authored-by: Sergei Grechanik <sergei.grechanik@gmail.com>
parent b236565e
...@@ -424,6 +424,12 @@ class Analyzer { ...@@ -424,6 +424,12 @@ class Analyzer {
*/ */
void Bind(const Var& var, const Range& range); void Bind(const Var& var, const Range& range);
/*! /*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
*/
void Bind(const Map<Var, Range>& variables);
/*!
* \brief Whether can we prove expr >= val. * \brief Whether can we prove expr >= val.
* Non-negative proof is very useful in integer analysis * Non-negative proof is very useful in integer analysis
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/int_solver.h
* \brief integer constraints data structures and solvers
*/
#ifndef TVM_ARITH_INT_SOLVER_H_
#define TVM_ARITH_INT_SOLVER_H_
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace arith {
using tir::Var;
using tir::VarNode;
using tir::IterVar;
/*!
* \brief Represent integer constrains including (integer) variables, their ranges and
* the relations between them (either equations or inequalities).
* \sa LinearSystem
*/
class IntConstraintsNode : public Object {
public:
// e.g., \alpha, \beta, must be integers
Array<Var> variables;
// e.g., 1 <= \alpha <= N, etc.
// it is absolutely ok to include ranges for parameters
// (variables that are not in this->variables) in this map
Map<Var, Range> ranges;
// linear equalities or inequalities
// e.g., A \alpha = \beta or A \alpha <= \beta
Array<PrimExpr> relations;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("variables", &variables);
v->Visit("ranges", &ranges);
v->Visit("relations", &relations);
}
bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
return
equal(variables, other->variables) &&
equal(ranges, other->ranges) &&
equal(relations, other->relations);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(variables);
hash_reduce(ranges);
hash_reduce(relations);
}
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraints";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
};
/*!
* \brief Managed reference to IntConstraintsNode.
* \sa IntConstraintsNode
*/
class IntConstraints : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param variables The variables in the constraints, must be integers.
* \param ranges The ranges of the variables.
* \param relations The linear relations between the variables
* (either equations or inequalities)
*/
TVM_DLL IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
};
/*!
* \brief We can have different set of variables to represent the same constraints.
* For example, the following two systems are equivalent,
* {a + b = 0 | a >= 0, b >= 0} and
* {m - n = 0 | m >= 0, n <= 0}
* This data structure represents the transformation
* between two equivalent linear systems.
* In the above example,
* src : {a + b = 0 | a >= 0, b >= 0}
* dst : {m - n = 0 | m >= 0, n <= 0}
* src_to_dst : {a -> m, b -> -n}
* dst_to_src : {m -> a, n -> -b}
* \sa IntConstraintsTransform
*/
class IntConstraintsTransformNode : public Object {
public:
IntConstraints src;
IntConstraints dst;
Map<Var, PrimExpr> src_to_dst;
Map<Var, PrimExpr> dst_to_src;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("src", &src);
v->Visit("dst", &dst);
v->Visit("src_to_dst", &src_to_dst);
v->Visit("dst_to_src", &dst_to_src);
}
bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
return
equal(src, other->src) &&
equal(dst, other->dst) &&
equal(src_to_dst, other->src_to_dst) &&
equal(dst_to_src, other->dst_to_src);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_to_dst);
hash_reduce(dst_to_src);
}
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "arith.IntConstraintsTransform";
TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
};
/*!
* \brief Managed reference to IntConstraintsTransformNode.
* \sa IntConstraintsTransformNode
*/
class IntConstraintsTransform : public ObjectRef {
public:
/*!
* \brief Constructor by fields
* \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
* \param dst integer constraints equivalent to the source,
* e.g., {m - n = 0 | m >= 0, n <= 0}
* \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
* e.g., {a -> m, b -> -n}
* \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
* e.g., {m -> a, n -> -b}
*/
TVM_DLL IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src);
TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};
/*!
* \brief Obtain Smith Normal Form of linear equation A x = y.
* Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
* in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
* NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
* s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
* TODO(yzhliu): From sergei-grechanik:
* computing the proper Smith normal form may improve stability of automatic differentiation
* (generating the same gradient code for slightly different but equivalent input code
* U_{mxm} and V_{nxn} are invertible matrices.
* This function modifies \p S to be S_{mxn}, \p V to be V_{nxn},
* \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x.
* \param S the original A_{mxn}, it will be modified to S_{mxn}
* \param V an identity matrix, it will be modified to V_{nxn}
* \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1}
* \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1}
*/
void SmithNormalFormDiag(std::vector<std::vector<int64_t>> *S,
std::vector<std::vector<int64_t>> *V,
std::vector<PrimExpr>* x,
std::vector<PrimExpr> *y);
/*!
* \brief Solve linear equations.
* \param system_to_solve the variables to solve, their ranges, and a list of equations.
* \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank),
* or no variable (if \p system_to_solve is of full rank),
* or an empty linear system (if \p system_to_solve is unsolvable).
* It also provides the ranges of the variables in the new system,
* as well as inequalities inferred from the \p system_to_solve.
* You can get the mapping from the original variables to the solution via ret->src_to_dst.
*/
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SOLVER_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/util.h
* \brief Utils for arithmetic analysis.
*/
#ifndef TVM_ARITH_UTIL_H_
#define TVM_ARITH_UTIL_H_
#include <cstdint>
#include <tuple>
namespace tvm {
/*! \brief namespace of arithmetic analysis. */
namespace arith {
/*!
* \brief Calculate the extended greatest common divisor for two values.
* See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
* \param a an integer number
* \param b an integer number
* \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div
*/
std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_UTIL_H_
...@@ -20,3 +20,4 @@ from .int_set import IntSet, IntervalSet ...@@ -20,3 +20,4 @@ from .int_set import IntSet, IntervalSet
from .analyzer import ModularSet, ConstIntBound, Analyzer from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""integer constraints data structures and solvers"""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
@tvm._ffi.register_object("arith.IntConstraints")
class IntConstraints(Object):
"""Represent a set of integer constraints including variables, their ranges and
the relations between them (either equations or inequalities)
Parameters
----------
variables : List[tvm.tir.Var]
The variables in the constraints. Must be integers
ranges : Map[tvm.tir.Var, tvm.ir.Range]
The ranges of the variables.
relations : List[tvm.ir.PrimExpr]
The relations between the variables (either equations or inequalities)
"""
def __init__(self, variables, ranges, relations):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraints, variables, ranges, relations)
@tvm._ffi.register_object("arith.IntConstraintsTransform")
class IntConstraintsTransform(Object):
"""We can have different set of variables to represent the same integer constraints.
For example, the following two constrains are equivalent,
{a + b = 0 | a >= 0, b >= 0} and
{m - n = 0 | m >= 0, n <= 0}
This data structure represents the transformation
between two equivalent integer constraints.
In the above example,
src : {a + b = 0 | a >= 0, b >= 0}
dst : {m - n = 0 | m >= 0, n <= 0}
src_to_dst : {a -> m, b -> -n}
dst_to_src : {m -> a, n -> -b}
Parameters
----------
src : arith.IntConstraints
source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
dst : arith.IntConstraints
integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0}
src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the src to the variables in the dst,
e.g., {a -> m, b -> -n}
dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr]
mapping from variables in the dst to the variables in the src,
e.g., {m -> a, n -> -b}
"""
def __init__(self, src, dst, src_to_dst, dst_to_src):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src)
def solve_linear_equations(equations, variables=None, ranges=None):
"""Solve linear equations.
Parameters
----------
equations: List[tvm.ir.PrimExpr] or IntConstraints
The equations of the variables
variables : Optional[List[tvm.tir.Var]]
The variables in the system.
ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]]
The ranges of the variables.
Returns
-------
int_constraints_transform : IntConstraintsTransform
New integer constraints, with less variables (if the problem is NOT of full rank),
or no variable (if the problem is of full rank),
or an empty integer constraints (if the problem is unsolvable).
It also provides the ranges of the variables in the new system,
as well as inequalities inferred from the problem.
You can get the mapping from the original variables to the solution via
int_constraints_transform.src_to_dst.
"""
if isinstance(equations, IntConstraints):
return _ffi_api.SolveLinearEquations(equations)
return _ffi_api.SolveLinearEquations(variables, ranges, equations)
...@@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) { ...@@ -58,6 +58,11 @@ void Analyzer::Bind(const Var& var, const Range& range) {
// skip rewrite simplify // skip rewrite simplify
} }
void Analyzer::Bind(const Map<Var, Range>& variables) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second);
}
}
void ConstraintContext::EnterWithScope() { void ConstraintContext::EnterWithScope() {
CHECK(exit_ == nullptr); CHECK(exit_ == nullptr);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file int_constraints.cc
* \brief The integer constraints data structures.
*/
#include <tvm/arith/int_solver.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/runtime/registry.h>
#include <utility>
#include <algorithm>
#include <unordered_map>
namespace tvm {
namespace arith {
IntConstraints::IntConstraints(Array<Var> variables,
Map<Var, Range> ranges,
Array<PrimExpr> relations) {
ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
if (!variables.defined()) {
variables = Array<Var>();
}
if (!ranges.defined()) {
ranges = Map<Var, Range>();
}
CHECK(relations.defined());
for (const auto& var : variables) {
CHECK(var.dtype().is_int() || var.dtype().is_uint())
<< "Variables in IntConstraints must be integers";
}
node->variables = std::move(variables);
node->ranges = std::move(ranges);
node->relations = std::move(relations);
data_ = std::move(node);
}
TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntConstraintsNode*>(node.get());
p->stream << "IntConstraints("
<< op->variables
<< ", " << op->ranges
<< ", " << op->relations
<< ")";
});
IntConstraintsTransform::IntConstraintsTransform(IntConstraints src,
IntConstraints dst,
Map<Var, PrimExpr> src_to_dst,
Map<Var, PrimExpr> dst_to_src) {
ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
node->src = std::move(src);
node->dst = std::move(dst);
node->src_to_dst = std::move(src_to_dst);
node->dst_to_src = std::move(dst_to_src);
data_ = std::move(node);
}
TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
p->stream << "IntConstraintsTransform("
<< "\n\t" << op->src
<< "\n\t" << op->dst
<< "\n\t" << op->src_to_dst
<< "\n\t" << op->dst_to_src
<< "\n)";
});
} // namespace arith
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/solve_linear_equation.cc
* \brief Solve linear equations.
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/tir/op.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/data_type.h>
namespace tvm {
namespace arith {
using namespace tvm::runtime;
void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
std::vector<std::vector<int64_t> >* V,
std::vector<PrimExpr>* x,
std::vector<PrimExpr>* y) {
if (S->empty() || V->empty()) return;
size_t m = S->size();
size_t n = (*S)[0].size(); // n is # of variables
CHECK_EQ(V->size(), n);
CHECK_EQ((*V)[0].size(), n);
for (size_t index = 0; index < std::min(m, n); ++index) {
// Here A is partially diagonalized, that is A[i, j] is zero for all i, j
// such that (i < index) or (j < index), unless (i == j).
// That is, now we are diagonalizing the submatrix with i >= index and j >= index
// Find a row with a nonzero element in the index-th column
// (We also prefer rows where this element has minimal abs value)
size_t best_i = index;
for (size_t i = best_i; i < m; ++i) {
int64_t s_old = (*S)[best_i][index];
int64_t s_new = (*S)[i][index];
if (s_new != 0) {
if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
best_i = i;
}
}
}
// Move the row we found to the index-th position
std::swap((*S)[index], (*S)[best_i]);
std::swap((*y)[index], (*y)[best_i]);
// If the index-th diagonal element is still zero, try to find a column with nonzero index-th
// element and move it to the index-th position
if ((*S)[index][index] == 0) {
for (size_t j = index + 1; j < n; ++j) {
if ((*S)[index][j] != 0) {
for (size_t i = index; i < m; ++i) {
std::swap((*S)[i][index], (*S)[i][j]);
}
// swapping columns corresponds to swapping the corresponding x
std::swap((*x)[index], (*x)[j]);
for (size_t i = 0; i < n; ++i) {
std::swap((*V)[i][index], (*V)[i][j]);
}
break;
}
}
}
// If the index-th diagonal element is still zero, then both the index-th row and the index-th
// column are completely zero, and we don't need to do anything; just go to the next index
if ((*S)[index][index] == 0) {
continue;
}
// Now the index-th diagonal element is non-zero and we can zero all the index-th column
// below it by subtracting rows from each other
for (auto i = index + 1; i < m; ++i) {
if ((*S)[i][index] != 0) {
int64_t g, a, b;
// g = a*matrix[index][index] + b*matrix[i][index]
if ((*S)[i][index] % (*S)[index][index] != 0) {
std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
} else {
// Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
g = (*S)[index][index];
a = 1;
b = 0;
}
// Let m = S[index][index], n = S[i][index], then the following is true:
//
// [ a n/g ][ m/g n/g ] = [ 1 0 ]
// [ b -m/g ][ b -a ] = [ 0 1 ]
//
// Note that the two matrices are integer (since g = gcd(m, n)).
// We will essentially multiply our matrix on the left by a dilated and transposed version
// of the first of these two matrices. The second matrix is not needed here, however we will
// use it while zeroing the index-th row.
int64_t m_g = (*S)[index][index] / g;
int64_t n_g = (*S)[i][index] / g;
// Note that j is the index of the column, not the row
for (size_t j = index; j < (*S)[i].size(); ++j) {
// Multiply index-th row by a and add the i-th row multiplied by b
// This will make the index-th diagonal element equal to the gcd
int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
// This transformation performs zeroing of matrix[i][index]
int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
(*S)[index][j] = new_index_j;
(*S)[i][j] = new_i_j;
}
// We have to do the same with rhs
PrimExpr ea = te::make_const((*y)[index].dtype(), a);
PrimExpr eb = te::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
(*y)[index] = new_index_rhs;
(*y)[i] = new_i_rhs;
}
}
bool changed = false;
// Now we have to zero the elements of the index-th row by manipulating columns.
// This is more difficult because column manipulation corresponds to variable manipulation,
// but the algorithm is essentially the same as before.
for (size_t j = index + 1; j < n; ++j) {
if ((*S)[index][j] != 0) {
int64_t g, a, b;
// g = a*matrix[index][index] + b*matrix[index][j]
if ((*S)[index][j] % (*S)[index][index] != 0) {
std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
// During this phase we may disrupt the zeroness of the index-th column, so we will
// have to take some action if this might have happened.
changed = true;
} else {
// Explicitly avoid changing the index-th column. This is important to avoid infinite
// loop. Note that here we don't have to set `changed` to true since we don't change the
// index-th column.
g = (*S)[index][index];
a = 1;
b = 0;
}
// Let m = S[index][index], n = S[index][j], then the following is true:
//
// [ a n/g ][ m/g n/g ] = [ 1 0 ]
// [ b -m/g ][ b -a ] = [ 0 1 ]
//
// Now we are going to multiply our matrix on the right (to manipulate columns instead of
// rows), we will also transform the old_to_new matrix the same way, and we will use the
// second matrix to transform new_to_old.
int64_t m_g = (*S)[index][index] / g;
int64_t n_g = (*S)[index][j] / g;
for (size_t i = index; i < m; ++i) {
int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
(*S)[i][index] = new_i_index;
(*S)[i][j] = new_i_j;
}
// We do exactly the same transformations with V
for (size_t i = 0; i < n; ++i) {
int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
(*V)[i][index] = new_i_index;
(*V)[i][j] = new_i_j;
}
// And apply reverse transformations to new_to_old.
PrimExpr ea = te::make_const((*x)[j].dtype(), a);
PrimExpr eb = te::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
(*x)[index] = new_index;
(*x)[j] = new_j;
}
}
if (changed) {
// We might have changed the first column, so we have to zero it once more
// (or at least check if it's zero), so just perform this iteration once more.
index -= 1;
}
}
}
Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer,
const Array<Var>& ori_vars,
const Map<Var, Range>& ori_ranges) {
// The resulting ranges
Map<Var, Range> new_ranges;
std::unordered_set<const VarNode*> ori_vset;
for (const Var& v : ori_vars) {
ori_vset.insert(v.get());
}
std::unordered_map<const VarNode*, IntSet> var_intsets;
for (const auto& p : ori_ranges) {
if (!ori_vset.count(p.first.get())) {
// First of all, fill the new ranges with outer variable ranges
new_ranges.Set(p.first, p.second);
}
// Convert original ranges to IntSets
var_intsets[p.first.get()] = IntSet::range(p.second);
}
// Infer ranges for the new variables and add them to the resulting ranges
for (const auto& p : vars_to_infer) {
const auto& var = p.first;
const auto& expr = p.second;
Range range = EvalSet(expr, var_intsets).cover_range(Range());
if (range.defined()) {
new_ranges.Set(var, range);
}
}
return new_ranges;
}
// pretty print matrix equation
void DebugPrint(const std::vector<std::vector<int64_t>>& S,
const std::vector<std::vector<int64_t>>& V,
const std::vector<PrimExpr>& V_inv_x,
const std::vector<PrimExpr>& rhs) {
std::cout << "S:\n";
for (size_t i = 0; i < S.size(); ++i) {
for (auto e : S[i]) {
std::cout << e << "\t";
}
std::cout << "\t->\t" << rhs[i];
std::cout << "\n";
}
std::cout << "V:\n";
for (const auto& r : V) {
for (auto e : r) {
std::cout << e << "\t";
}
std::cout << "\n";
}
std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x);
std::cout << "\n" << std::endl;
}
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) {
// m: # of equations
// n: # of variables
// we first construct A_{mxn} x_{nx1} = y_{mx1}
// then get Smith normal form of matrix A,
// S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
// => U^{-1} S V^{-1} x = y
// S V^{-1} x = U y
std::vector<PrimExpr> Uy; // mx1
std::vector<std::vector<int64_t>> S; // mxn
std::vector<std::vector<int64_t>> V; // nxn
std::vector<PrimExpr> V_inv_x; // V^{-1} x, nx1
// Conditions we don't know what to do with
std::vector<PrimExpr> rest;
Analyzer analyzer_problem;
analyzer_problem.Bind(system_to_solve->ranges);
size_t num_vars = system_to_solve->variables.size();
// initialize V_{nxn} with identity matrix,
// initialize V^{-1} x as x
for (size_t i = 0; i < num_vars; ++i) {
V.emplace_back(num_vars);
V.back()[i] = 1;
V_inv_x.push_back(system_to_solve->variables[i]);
}
// Transform formulas into rows of the matrix
// S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables
// here we initialize S_{mxn} to be A, U to be identity matrix.
for (const PrimExpr& equation : system_to_solve->relations) {
if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
// a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
Array<PrimExpr> coeffs = arith::DetectLinearEquation(
analyzer_problem.Simplify(eq->a - eq->b),
system_to_solve->variables);
if (!coeffs.empty()) {
std::vector<int64_t> row;
for (size_t j = 0; j < coeffs.size() - 1; ++j) {
PrimExpr c = coeffs[j];
if (const IntImmNode* ic = c.as<IntImmNode>()) {
row.push_back(ic->value);
} else {
// elements in matrix S V must be integers
// ignore equations that we cannot deal with.
LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation "
<< equation;
row.clear();
break;
}
}
if (!row.empty()) {
// S V^{-1} (a-b) = Uy
// V is identity for now
S.push_back(row);
Uy.push_back(-coeffs[coeffs.size() - 1]);
continue;
}
}
}
// otherwise
rest.push_back(equation);
}
// After diagonalizing, we have
// S_{mxn} is the Smith normal form (diagonal matrix)
// V_{nxn} is invertible
// V_inv_x is V^{-1} \times x
// Uy is U \times y
SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy);
Array<Var> new_vars;
Array<PrimExpr> new_relations;
Map<Var, PrimExpr> new_to_old_map;
Map<Var, PrimExpr> old_to_new_map;
// Simplify right hand sides
for (PrimExpr r : Uy) {
r = analyzer_problem.Simplify(r);
}
// Create the relations of the existence of a solution
for (size_t j = 0; j < S.size(); ++j) {
PrimExpr new_relation;
if (j >= num_vars || S[j][j] == 0) {
// The row of matrix is zero. A solution exists only if the Ub[j] is also zero
new_relation = (Uy[j] == 0);
} else {
// The diagonal element is non-zero. A solution exists only if the diagonal element
// is a divisor of the Ub[j]
new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0);
}
new_relation = analyzer_problem.Simplify(new_relation);
if (tir::is_const_int(new_relation, 0)) {
// unable to solve the system.
return IntConstraintsTransform(
system_to_solve,
IntConstraints(
/*variables=*/{},
/*ranges=*/{},
/*relations=*/{te::make_zero(DataType::Bool())}),
{}, {});
} else if (!tir::is_const_int(new_relation, 1)) {
new_relations.push_back(new_relation);
}
}
Array<PrimExpr> solution_for_V_inv_x;
// Now create new variables or directly solve the equations
// suppose the rank of A is r, aka r = # of non-zeros in S
// the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b
// is
// x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r}
// = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r}
// in which K is the right n-r columns of V, z is variable vector
// thus,
// V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} +
// [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r}
for (size_t j = 0; j < num_vars; ++j) {
if (j >= S.size() || S[j][j] == 0) {
// The j-th variable can take any integer value, create a tvm variable for it
PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]);
std::string name_hint = "n" + std::to_string(new_vars.size());
if (const VarNode* v_old = to_old.as<VarNode>()) {
name_hint += "_" + v_old->name_hint;
}
Var v = Var(name_hint, V_inv_x[j].dtype());
solution_for_V_inv_x.push_back(v);
new_vars.push_back(v);
new_to_old_map.Set(v, to_old);
} else {
// The j-th variable is just a single value, don't create a tvm variable
// S^{-1}_{nxm} Uy_{mxn}
if (S[j][j] >= 0) {
PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(Uy[j], a)));
} else {
// This is required because some simplifiers
// have problems with dividing by negative numbers
PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(-Uy[j], a)));
}
}
}
// V V^{-1} x = x
for (size_t i = 0; i < num_vars; ++i) {
PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
for (size_t j = 0; j < num_vars; ++j) {
e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
}
e = analyzer_problem.Simplify(e);
old_to_new_map.Set(system_to_solve->variables[i], e);
}
// The resulting ranges
Map<Var, Range> new_ranges = InferRange(
new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
Analyzer analyzer_solution;
analyzer_solution.Bind(new_ranges);
// We have to transform ranges of the old variables into relations over new variables because
// new ranges are not enough usually.
for (const auto& p : system_to_solve->ranges) {
const Var& old_var = p.first;
const Range& old_range = p.second;
if (old_to_new_map.count(old_var)) {
PrimExpr express_by_new_vars = old_to_new_map[old_var];
PrimExpr lower_cond = analyzer_solution.Simplify(
old_range->min <= express_by_new_vars);
PrimExpr upper_cond = analyzer_solution.Simplify(
express_by_new_vars < old_range->min + old_range->extent);
if (!tir::is_const_int(lower_cond, 1)) {
new_relations.push_back(lower_cond);
}
if (!tir::is_const_int(upper_cond, 1)) {
new_relations.push_back(upper_cond);
}
}
}
// Add the rest conditions
for (const PrimExpr& cond : rest) {
new_relations.push_back(Substitute(cond, old_to_new_map));
}
IntConstraints solution(new_vars, new_ranges, new_relations);
IntConstraintsTransform transform(
system_to_solve, solution, old_to_new_map, new_to_old_map);
return transform;
}
TVM_REGISTER_GLOBAL("arith.SolveLinearEquations")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args.size() == 1) {
*ret = SolveLinearEquations(args[0]);
} else if (args.size() == 3) {
IntConstraints problem(args[0], args[1], args[2]);
*ret = SolveLinearEquations(problem);
} else {
LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
}
});
} // namespace arith
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file util.cc
* \brief The utils for arithmetic analysis.
*/
#include <tvm/arith/util.h>
#include <dmlc/logging.h>
namespace tvm {
namespace arith {
std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b) {
int64_t s = 0, old_s = 1;
int64_t t = 1, old_t = 0;
int64_t r = b, old_r = a;
while (r != 0) {
int64_t q = old_r / r;
std::swap(r, old_r);
r -= q * old_r;
std::swap(s, old_s);
s -= q * old_s;
std::swap(t, old_t);
t -= q * old_t;
}
CHECK_EQ(a % old_r, 0);
CHECK_EQ(b % old_r, 0);
CHECK(old_r == old_s*a + old_t*b);
return std::make_tuple(old_r, old_s, old_t);
}
} // namespace arith
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import random
import numpy as np
import sys
import pytest
import tvm
from tvm import te, arith, ir, tir
def run_expr(expr, vranges):
""" Evaluate expr for every value of free variables
given by vranges and return the tensor of results.
TODO(yzhliu): move to utils
"""
def _compute_body(*us):
vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
return tir.ir_pass.Substitute(expr, vmap)
A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
args = [tvm.nd.empty(A.shape, A.dtype)]
sch = te.create_schedule(A.op)
mod = tvm.build(sch, [A])
mod(*args)
return args[0].asnumpy()
def check_bruteforce(bool_expr, vranges, cond=None):
""" Check that bool_expr holds given the condition cond
for every value of free variables from vranges.
TODO(yzhliu): move to utils
"""
if cond is not None:
bool_expr = te.any(tir.Not(cond), bool_expr)
res = run_expr(bool_expr, vranges)
if not np.all(res):
indices = list(np.argwhere(res == 0)[0])
counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
counterex = sorted(counterex, key=lambda x: x[0])
counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
raise AssertionError("Expression {}\nis not true on {}\n"
"Counterexample: {}"
.format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex))
def check_solution(solution, vranges={}):
"""Check that solution is a bijective transformation"""
def _check_forward(constraints1, constraints2, varmap, backvarmap):
all_vranges = vranges.copy()
all_vranges.update({v: r for v, r in constraints1.ranges.items()})
# Check that the transformation is injective
cond_on_vars = tir.const(1, 'bool')
for v in constraints1.variables:
# variable mapping is consistent
v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap))
cond_on_vars = te.all(cond_on_vars, v == v_back)
# Also we have to check that the new relations are true when old relations are true
cond_subst = tir.ir_pass.Substitute(
te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
# We have to include relations from vranges too
for v in constraints2.variables:
if v in constraints2.ranges:
r = constraints2.ranges[v]
range_cond = te.all(v >= r.min, v < r.min + r.extent)
range_cond = tir.ir_pass.Substitute(range_cond, backvarmap)
cond_subst = te.all(cond_subst, range_cond)
cond_subst = tir.ir_pass.Simplify(cond_subst)
check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
cond=te.all(tir.const(1, 'bool'), *constraints1.relations))
rels = solution.dst.relations
if len(rels) == 1 and ir.structural_equal(rels[0], False):
# not solvable, skip
return
_check_forward(solution.src, solution.dst,
solution.src_to_dst, solution.dst_to_src)
_check_forward(solution.dst, solution.src,
solution.dst_to_src, solution.src_to_dst)
def test_solution_consistency():
seed = random.randrange(sys.maxsize)
print("\nThis test is intentionally non-deterministic, "
"if it fails please report it in github issue together with this seed {}\n".format(seed))
random.seed(seed)
def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
variables = [te.var("x" + str(i)) for i in range(num_vars)]
relations = []
for i in range(num_formulas):
s1 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
s1 += random.randint(coef[0], coef[1])
s2 = sum([v*random.randint(coef[0], coef[1]) for v in variables])
s2 += random.randint(coef[0], coef[1])
if random.random() < 0.7:
op = tvm.tir.EQ
else:
# we also make sure it can correctly handle inequalities
op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT])
relations.append(op(s1, s2))
vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables}
solution = arith.solve_linear_equations(relations, variables, vranges)
check_solution(solution)
# leaving some variables as parameters should also be ok
for k in [1, 2]:
if len(variables) > k:
solution = arith.solve_linear_equations(relations, variables[:-k], vranges)
param_ranges = {v: vranges[v] for v in variables[-k:]}
check_solution(solution, param_ranges)
for i in range(2):
_check(num_vars=1, num_formulas=1)
for i in range(2):
_check(num_vars=1, num_formulas=2)
for i in range(2):
_check(num_vars=2, num_formulas=1)
for i in range(2):
_check(num_vars=2, num_formulas=2)
for i in range(2):
_check(num_vars=2, num_formulas=3)
for i in range(3):
_check(num_vars=3, num_formulas=3, coef=(-2, 2))
for i in range(3):
_check(num_vars=3, num_formulas=4, coef=(-2, 2))
for i in range(3):
_check(num_vars=4, num_formulas=3, coef=(-1, 1))
for i in range(3):
_check(num_vars=10, num_formulas=2, coef=(-1, 1), bounds=(0, 4))
for i in range(3):
_check(num_vars=10, num_formulas=3, coef=(0, 1), bounds=(0, 4))
def test_empty_var_to_solve():
x, y = te.var("x"), te.var("y")
equations = [
tvm.tir.EQ(x + y, 20),
tvm.tir.EQ(x - y, 10),
]
solution = arith.solve_linear_equations(equations)
assert len(solution.src_to_dst) == 0
assert len(solution.dst_to_src) == 0
assert len(solution.src.variables) == 0
assert len(solution.src.ranges) == 0
assert ir.structural_equal(solution.src.relations, equations)
assert ir.structural_equal(solution.src, solution.dst)
def test_unique_solution():
x, y = te.var("x"), te.var("y")
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y, 20),
tvm.tir.EQ(x - y, 10),
], [x, y])
assert list(solution.dst.variables) == []
assert ir.structural_equal(solution.src_to_dst[x], 15)
assert ir.structural_equal(solution.src_to_dst[y], 5)
def test_low_rank():
x, y, z = te.var("x"), te.var("y"), te.var("z")
ranges = {}
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y + z, 15),
tvm.tir.EQ(x + y, 10),
], [x, y, z], ranges)
[n0] = solution.dst.variables
assert ir.structural_equal(solution.src_to_dst[x], n0 + 10)
assert ir.structural_equal(solution.src_to_dst[y], -n0)
assert ir.structural_equal(solution.src_to_dst[z], 5)
def test_infer_range():
x, y = te.var("x"), te.var("y")
ranges = {
x: tvm.ir.Range.make_by_min_extent(-5, 10),
y: tvm.ir.Range.make_by_min_extent(0, 10),
}
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y, 0),
], [x, y], ranges)
[n0] = solution.dst.variables
assert ir.structural_equal(solution.src_to_dst[x], n0)
assert ir.structural_equal(solution.src_to_dst[y], -n0)
# inferred from y's range
assert ir.structural_equal(solution.dst.ranges[n0].min, -9)
assert ir.structural_equal(solution.dst.ranges[n0].extent, 10)
# additional inequality is added into the system for x
[ineq] = solution.dst.relations
assert isinstance(ineq, tvm.tir.LE)
assert ir.structural_equal(ineq.a, -5)
assert ir.structural_equal(ineq.b, n0)
def test_ill_formed():
x, y = te.var("x"), te.var("y")
solution = arith.solve_linear_equations([
tvm.tir.EQ(x + y, 0),
tvm.tir.EQ(x - y, 0),
tvm.tir.EQ(x, 5),
], [x, y], {})
assert list(solution.dst.variables) == []
[rel] = solution.dst.relations
assert ir.structural_equal(rel, False)
assert len(solution.src_to_dst) == 0
assert len(solution.dst_to_src) == 0
if __name__ == "__main__":
pytest.main([__file__])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment