/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arith/solve_linear_equation.cc
 * \brief Solve linear equations.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/data_type.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
namespace arith {

using namespace tvm::runtime;

/*!
 * \brief Diagonalize the integer matrix S in place using integer row and column operations.
 *
 *  Every row operation applied to S is also applied to the right hand side y.
 *  Every column operation applied to S is also applied to V, and the reverse
 *  transformation is applied to x, so that x keeps describing the new unknowns
 *  in terms of the original ones.
 *
 * \param S The m x n coefficient matrix; overwritten with a diagonal matrix.
 * \param V An n x n matrix (the caller passes identity); accumulates the column operations.
 * \param x Vector of n expressions; receives the reverse of the column operations.
 * \param y Vector of m right hand sides; receives the row operations.
 *
 * \note Returns immediately, leaving all arguments untouched, when S or V is empty.
 */
void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
                         std::vector<std::vector<int64_t> >* V,
                         std::vector<PrimExpr>* x,
                         std::vector<PrimExpr>* y) {
  if (S->empty() || V->empty()) return;
  size_t m = S->size();
  size_t n = (*S)[0].size();  // n is # of variables
  CHECK_EQ(V->size(), n);
  CHECK_EQ((*V)[0].size(), n);

  const size_t num_diag = std::min(m, n);
  // `index` is advanced explicitly: an iteration is repeated (index left unchanged)
  // whenever the column phase may have disturbed the already-zeroed index-th column.
  size_t index = 0;
  while (index < num_diag) {
    // Here A is partially diagonalized, that is A[i, j] is zero for all i, j
    // such that (i < index) or (j < index), unless (i == j).
    // That is, now we are diagonalizing the submatrix with i >= index and j >= index

    // Find a row with a nonzero element in the index-th column
    // (We also prefer rows where this element has minimal abs value)
    size_t best_i = index;
    for (size_t i = best_i; i < m; ++i) {
      int64_t s_old = (*S)[best_i][index];
      int64_t s_new = (*S)[i][index];
      if (s_new != 0) {
        if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
          best_i = i;
        }
      }
    }
    // Move the row we found to the index-th position
    std::swap((*S)[index], (*S)[best_i]);
    std::swap((*y)[index], (*y)[best_i]);

    // If the index-th diagonal element is still zero, try to find a column with nonzero index-th
    // element and move it to the index-th position
    if ((*S)[index][index] == 0) {
      for (size_t j = index + 1; j < n; ++j) {
        if ((*S)[index][j] != 0) {
          for (size_t i = index; i < m; ++i) {
            std::swap((*S)[i][index], (*S)[i][j]);
          }
          // swapping columns corresponds to swapping the corresponding x
          std::swap((*x)[index], (*x)[j]);
          for (size_t i = 0; i < n; ++i) {
            std::swap((*V)[i][index], (*V)[i][j]);
          }
          break;
        }
      }
    }

    // If the index-th diagonal element is still zero, then both the index-th row and the index-th
    // column are completely zero, and we don't need to do anything; just go to the next index
    if ((*S)[index][index] == 0) {
      ++index;
      continue;
    }

    // Now the index-th diagonal element is non-zero and we can zero all the index-th column
    // below it by subtracting rows from each other
    for (auto i = index + 1; i < m; ++i) {
      if ((*S)[i][index] != 0) {
        int64_t g, a, b;
        // g = a*matrix[index][index] + b*matrix[i][index]
        if ((*S)[i][index] % (*S)[index][index] != 0) {
          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
        } else {
          // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
          g = (*S)[index][index];
          a = 1;
          b = 0;
        }

        // Let m = S[index][index], n = S[i][index], then the following is true:
        //
        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
        //
        // Note that the two matrices are integer (since g = gcd(m, n)).
        // We will essentially multiply our matrix on the left by a dilated and transposed version
        // of the first of these two matrices. The second matrix is not needed here, however we will
        // use it while zeroing the index-th row.

        int64_t m_g = (*S)[index][index] / g;
        int64_t n_g = (*S)[i][index] / g;

        // Note that j is the index of the column, not the row
        for (size_t j = index; j < (*S)[i].size(); ++j) {
          // Multiply index-th row by a and add the i-th row multiplied by b
          // This will make the index-th diagonal element equal to the gcd
          int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
          // This transformation performs zeroing of matrix[i][index]
          int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
          (*S)[index][j] = new_index_j;
          (*S)[i][j] = new_i_j;
        }
        // We have to do the same with rhs
        PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
        PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
        PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
        PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
        PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
        PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
        (*y)[index] = new_index_rhs;
        (*y)[i] = new_i_rhs;
      }
    }

    bool changed = false;

    // Now we have to zero the elements of the index-th row by manipulating columns.
    // This is more difficult because column manipulation corresponds to variable manipulation,
    // but the algorithm is essentially the same as before.
    for (size_t j = index + 1; j < n; ++j) {
      if ((*S)[index][j] != 0) {
        int64_t g, a, b;
        // g = a*matrix[index][index] + b*matrix[index][j]
        if ((*S)[index][j] % (*S)[index][index] != 0) {
          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
          // During this phase we may disrupt the zeroness of the index-th column, so we will
          // have to take some action if this might have happened.
          changed = true;
        } else {
          // Explicitly avoid changing the index-th column. This is important to avoid infinite
          // loop. Note that here we don't have to set `changed` to true since we don't change the
          // index-th column.
          g = (*S)[index][index];
          a = 1;
          b = 0;
        }

        // Let m = S[index][index], n = S[index][j], then the following is true:
        //
        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
        //
        // Now we are going to multiply our matrix on the right (to manipulate columns instead of
        // rows), we will also transform the old_to_new matrix the same way, and we will use the
        // second matrix to transform new_to_old.

        int64_t m_g = (*S)[index][index] / g;
        int64_t n_g = (*S)[index][j] / g;

        for (size_t i = index; i < m; ++i) {
          int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
          int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
          (*S)[i][index] = new_i_index;
          (*S)[i][j] = new_i_j;
        }
        // We do exactly the same transformations with V
        for (size_t i = 0; i < n; ++i) {
          int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
          int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
          (*V)[i][index] = new_i_index;
          (*V)[i][j] = new_i_j;
        }
        // And apply reverse transformations to new_to_old.
        PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
        PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
        PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
        PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
        PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
        PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
        (*x)[index] = new_index;
        (*x)[j] = new_j;
      }
    }

    // If the column phase might have changed the index-th column, we have to zero it once more
    // (or at least check that it is zero), so the same index is processed again.
    // Otherwise this diagonal position is done and we move on.
    if (!changed) {
      ++index;
    }
  }
}

/*!
 * \brief Compute ranges for a set of variables defined by expressions over ranged variables.
 *
 * \param vars_to_infer Map from each variable whose range is wanted to the expression
 *                      that defines it.
 * \param ori_vars Variables whose ranges are used for evaluation only and are not
 *                 copied into the result.
 * \param ori_ranges Known ranges; entries for variables not in ori_vars are copied
 *                   into the result unchanged.
 * \return The copied outer ranges plus one range per variable in vars_to_infer for
 *         which the evaluated set yields a defined range.
 */
Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer,
                           const Array<Var>& ori_vars,
                           const Map<Var, Range>& ori_ranges) {
  // The resulting ranges
  Map<Var, Range> new_ranges;

  std::unordered_set<const VarNode*> ori_vset;
  for (const Var& v : ori_vars) {
    ori_vset.insert(v.get());
  }

  std::unordered_map<const VarNode*, IntSet> var_intsets;
  for (const auto& p : ori_ranges) {
    if (!ori_vset.count(p.first.get())) {
      // First of all, fill the new ranges with outer variable ranges
      new_ranges.Set(p.first, p.second);
    }
    // Convert original ranges to IntSets
    var_intsets[p.first.get()] = IntSet::range(p.second);
  }

  // Infer ranges for the new variables and add them to the resulting ranges
  for (const auto& p : vars_to_infer) {
    const auto& var = p.first;
    const auto& expr = p.second;
    Range range = EvalSet(expr, var_intsets).cover_range(Range());
    // Variables without a defined range are simply left out of the result
    if (range.defined()) {
      new_ranges.Set(var, range);
    }
  }
  return new_ranges;
}

/*!
 * \brief Pretty print the matrix equation to std::cout (debugging aid).
 *
 * \param S Matrix printed row by row; row i is followed by rhs[i].
 * \param V Matrix printed row by row.
 * \param V_inv_x Expressions printed as a single Array.
 * \param rhs Right hand sides; indexed by the row number of S, so it must have
 *            at least as many entries as S has rows.
 */
void DebugPrint(const std::vector<std::vector<int64_t>>& S,
                const std::vector<std::vector<int64_t>>& V,
                const std::vector<PrimExpr>& V_inv_x,
                const std::vector<PrimExpr>& rhs) {
  std::cout << "S:\n";
  for (size_t i = 0; i < S.size(); ++i) {
    for (auto e : S[i]) {
      std::cout << e << "\t";
    }
    std::cout << "\t->\t" << rhs[i];
    std::cout << "\n";
  }
  std::cout << "V:\n";
  for (const auto& r : V) {
    for (auto e : r) {
      std::cout << e << "\t";
    }
    std::cout << "\n";
  }
  std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x);
  std::cout << "\n" << std::endl;
}

/*!
 * \brief Solve the linear equations contained in an integer constraint system.
 *
 *  Equality relations that are linear in the system's variables with integer
 *  coefficients are collected into a matrix equation and diagonalized.
 *  All other relations are carried over to the result with the old variables
 *  substituted by their expressions over the new variables.
 *
 * \param system_to_solve The variables, their ranges and the relations to solve.
 * \return A transform holding the source system, the solved system and the maps
 *         between old and new variables. If some equation is found to have no
 *         integer solution, the solved system has no variables and a single
 *         constant-false relation, and both maps are empty.
 */
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) {
  // m: # of equations
  // n: # of variables
  // we first construct A_{mxn} x_{nx1} = y_{mx1}
  // then get Smith normal form of matrix A,
  // S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
  // => U^{-1} S V^{-1} x = y
  // S V^{-1} x = U y
  std::vector<PrimExpr> Uy;  // mx1
  std::vector<std::vector<int64_t>> S;  // mxn
  std::vector<std::vector<int64_t>> V;  // nxn
  std::vector<PrimExpr> V_inv_x;  // V^{-1} x, nx1
  // Conditions we don't know what to do with
  std::vector<PrimExpr> rest;

  Analyzer analyzer_problem;
  analyzer_problem.Bind(system_to_solve->ranges);

  size_t num_vars = system_to_solve->variables.size();

  // initialize V_{nxn} with identity matrix,
  // initialize V^{-1} x as x
  for (size_t i = 0; i < num_vars; ++i) {
    V.emplace_back(num_vars);
    V.back()[i] = 1;
    V_inv_x.push_back(system_to_solve->variables[i]);
  }

  // Transform formulas into rows of the matrix
  // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables
  // here we initialize S_{mxn} to be A, U to be identity matrix.
  for (const PrimExpr& equation : system_to_solve->relations) {
    if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
      // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
      Array<PrimExpr> coeffs = arith::DetectLinearEquation(
          analyzer_problem.Simplify(eq->a - eq->b),
          system_to_solve->variables);
      if (!coeffs.empty()) {
        std::vector<int64_t> row;
        for (size_t j = 0; j < coeffs.size() - 1; ++j) {
          PrimExpr c = coeffs[j];
          if (const IntImmNode* ic = c.as<IntImmNode>()) {
            row.push_back(ic->value);
          } else {
            // elements in matrix S V must be integers
            // ignore equations that we cannot deal with.
            LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation "
                         << equation;
            row.clear();
            break;
          }
        }

        if (!row.empty()) {
          // S V^{-1} (a-b) = Uy
          // V is identity for now
          S.push_back(row);
          Uy.push_back(-coeffs[coeffs.size() - 1]);
          continue;
        }
      }
    }

    // otherwise
    rest.push_back(equation);
  }

  // After diagonalizing, we have
  // S_{mxn} is the Smith normal form (diagonal matrix)
  // V_{nxn} is invertible
  // V_inv_x is V^{-1} \times x
  // Uy is U \times y
  SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy);

  Array<Var> new_vars;
  Array<PrimExpr> new_relations;
  Map<Var, PrimExpr> new_to_old_map;
  Map<Var, PrimExpr> old_to_new_map;

  // Simplify right hand sides.
  // Iterate by reference so that the simplified expression is stored back into Uy;
  // iterating by value would only modify a temporary copy.
  for (PrimExpr& r : Uy) {
    r = analyzer_problem.Simplify(r);
  }

  // Create the relations of the existence of a solution
  for (size_t j = 0; j < S.size(); ++j) {
    PrimExpr new_relation;
    if (j >= num_vars || S[j][j] == 0) {
      // The row of matrix is zero. A solution exists only if the Uy[j] is also zero
      new_relation = (Uy[j] == 0);
    } else {
      // The diagonal element is non-zero. A solution exists only if the diagonal element
      // is a divisor of the Uy[j]
      new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0);
    }
    new_relation = analyzer_problem.Simplify(new_relation);
    if (tir::is_const_int(new_relation, 0)) {
      // unable to solve the system.
      return IntConstraintsTransform(
          system_to_solve,
          IntConstraints(
              /*variables=*/{},
              /*ranges=*/{},
              /*relations=*/{tir::make_zero(DataType::Bool())}),
          {}, {});
    } else if (!tir::is_const_int(new_relation, 1)) {
      new_relations.push_back(new_relation);
    }
  }

  Array<PrimExpr> solution_for_V_inv_x;
  // Now create new variables or directly solve the equations
  // suppose the rank of A is r, aka r = # of non-zeros in S
  // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U y
  // is
  // x = (pseudo-inverse of A) y + K_{(n)x(n-r)} z_{n-r}
  //   = V_{nxn} S^{-1}_{nxm} (Uy)_{mx1} + K_{(n)x(n-r)} z_{n-r}
  // in which K is the right n-r columns of V, z is variable vector
  // thus,
  // V^{-1} x = S^{-1}_{nxm} (Uy)_{mx1} +
  //            [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r}
  for (size_t j = 0; j < num_vars; ++j) {
    if (j >= S.size() || S[j][j] == 0) {
      // The j-th variable can take any integer value, create a tvm variable for it
      PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]);
      std::string name_hint = "n" + std::to_string(new_vars.size());
      if (const VarNode* v_old = to_old.as<VarNode>()) {
        name_hint += "_" + v_old->name_hint;
      }
      Var v = Var(name_hint, V_inv_x[j].dtype());
      solution_for_V_inv_x.push_back(v);
      new_vars.push_back(v);
      new_to_old_map.Set(v, to_old);
    } else {
      // The j-th variable is just a single value, don't create a tvm variable
      // S^{-1}_{nxm} Uy_{mx1}
      if (S[j][j] >= 0) {
        PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
        solution_for_V_inv_x.push_back(
            analyzer_problem.Simplify(floordiv(Uy[j], a)));
      } else {
        // This is required because some simplifiers
        // have problems with dividing by negative numbers
        PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
        solution_for_V_inv_x.push_back(
            analyzer_problem.Simplify(floordiv(-Uy[j], a)));
      }
    }
  }

  // V V^{-1} x = x
  for (size_t i = 0; i < num_vars; ++i) {
    PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
    for (size_t j = 0; j < num_vars; ++j) {
      e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
    }
    e = analyzer_problem.Simplify(e);
    old_to_new_map.Set(system_to_solve->variables[i], e);
  }

  // The resulting ranges
  Map<Var, Range> new_ranges = InferRange(
      new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
  Analyzer analyzer_solution;
  analyzer_solution.Bind(new_ranges);

  // We have to transform ranges of the old variables into relations over new variables because
  // new ranges are not enough usually.
  for (const auto& p : system_to_solve->ranges) {
    const Var& old_var = p.first;
    const Range& old_range = p.second;
    if (old_to_new_map.count(old_var)) {
      PrimExpr express_by_new_vars = old_to_new_map[old_var];
      PrimExpr lower_cond = analyzer_solution.Simplify(
          old_range->min <= express_by_new_vars);
      PrimExpr upper_cond = analyzer_solution.Simplify(
          express_by_new_vars < old_range->min + old_range->extent);
      if (!tir::is_const_int(lower_cond, 1)) {
        new_relations.push_back(lower_cond);
      }
      if (!tir::is_const_int(upper_cond, 1)) {
        new_relations.push_back(upper_cond);
      }
    }
  }

  // Add the rest conditions
  for (const PrimExpr& cond : rest) {
    new_relations.push_back(Substitute(cond, old_to_new_map));
  }

  IntConstraints solution(new_vars, new_ranges, new_relations);
  IntConstraintsTransform transform(
      system_to_solve, solution, old_to_new_map, new_to_old_map);

  return transform;
}

// Packed-function entry point: accepts either one argument (passed straight to
// SolveLinearEquations) or three arguments (used to construct an IntConstraints first).
TVM_REGISTER_GLOBAL("arith.SolveLinearEquations")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    if (args.size() == 1) {
      *ret = SolveLinearEquations(args[0]);
    } else if (args.size() == 3) {
      IntConstraints problem(args[0], args[1], args[2]);
      *ret = SolveLinearEquations(problem);
    } else {
      LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets "
                 << args.size();
    }
  });

}  // namespace arith
}  // namespace tvm