/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/arith/solve_linear_equation.cc
 * \brief Solve linear equations.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/arith/pattern.h>

#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/data_type.h>

namespace tvm {
namespace arith {

using namespace tvm::runtime;

/*!
 * \brief Diagonalize an integer matrix in place using integer row and column operations.
 *
 * Every row operation applied to S is applied to the right-hand sides in y as well, and
 * every column operation applied to S is applied to the columns of V as well, while the
 * reverse transformation is applied to the expressions in x.
 * The caller in this file (SolveLinearEquations) passes V as the identity matrix, x as the
 * original variables and y as the right-hand sides of the equations.
 *
 * If S or V is empty the function returns without touching anything.
 *
 * \param S The m x n integer matrix to diagonalize (n is taken from the first row);
 *          modified in place.
 * \param V An n x n integer matrix (both dimensions are CHECKed against n);
 *          modified in place with the same column operations as S.
 * \param x Expressions indexed by the columns of S (indices below n are accessed);
 *          modified in place.
 * \param y Expressions indexed by the rows of S (indices below m are accessed);
 *          modified in place.
 */
void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
                         std::vector<std::vector<int64_t> >* V,
                         std::vector<PrimExpr>* x,
                         std::vector<PrimExpr>* y) {
  // Nothing to diagonalize (also protects the (*S)[0] / (*V)[0] accesses below)
  if (S->empty() || V->empty()) return;
  size_t m = S->size();
  size_t n = (*S)[0].size();  // n is # of variables
  CHECK_EQ(V->size(), n);
  CHECK_EQ((*V)[0].size(), n);

  for (size_t index = 0; index < std::min(m, n); ++index) {
    // Here S is partially diagonalized, that is S[i][j] is zero for all i, j
    // such that (i < index) or (j < index), unless (i == j).
    // That is, now we are diagonalizing the submatrix with i >= index and j >= index

    // Find a row with a nonzero element in the index-th column
    // (We also prefer rows where this element has minimal abs value)
    size_t best_i = index;
    for (size_t i = best_i; i < m; ++i) {
      int64_t s_old = (*S)[best_i][index];
      int64_t s_new = (*S)[i][index];
      if (s_new != 0) {
        if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
          best_i = i;
        }
      }
    }
    // Move the row we found to the index-th position
    // (swapping rows of S corresponds to swapping the corresponding right-hand sides in y)
    std::swap((*S)[index], (*S)[best_i]);
    std::swap((*y)[index], (*y)[best_i]);

    // If the index-th diagonal element is still zero, try to find a column with nonzero index-th
    // element and move it to the index-th position
    if ((*S)[index][index] == 0) {
      for (size_t j = index + 1; j < n; ++j) {
        if ((*S)[index][j] != 0) {
          // Only rows starting from index are swapped in S
          for (size_t i = index; i < m; ++i) {
            std::swap((*S)[i][index], (*S)[i][j]);
          }
          // swapping columns corresponds to swapping the corresponding x
          std::swap((*x)[index], (*x)[j]);
          // V gets the same column swap, over all of its n rows
          for (size_t i = 0; i < n; ++i) {
            std::swap((*V)[i][index], (*V)[i][j]);
          }
          break;
        }
      }
    }

    // If the index-th diagonal element is still zero, then both the index-th row and the index-th
    // column are completely zero, and we don't need to do anything; just go to the next index
    if ((*S)[index][index] == 0) {
      continue;
    }

    // Now the index-th diagonal element is non-zero and we can zero all the index-th column
    // below it by subtracting rows from each other
    for (auto i = index + 1; i < m; ++i) {
      if ((*S)[i][index] != 0) {
        int64_t g, a, b;
        // g = a*S[index][index] + b*S[i][index]
        if ((*S)[i][index] % (*S)[index][index] != 0) {
          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
        } else {
          // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
          g = (*S)[index][index];
          a = 1;
          b = 0;
        }

        // In this comment only, let m = S[index][index], n = S[i][index] (these are NOT the
        // matrix dimensions m and n declared above), then the following is true:
        //
        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
        //
        // Note that the two matrices are integer (since g = gcd(m, n)).
        // We will essentially multiply our matrix on the left by a dilated and transposed version
        // of the first of these two matrices. The second matrix is not needed here, however we will
        // use it while zeroing the index-th row.

        int64_t m_g = (*S)[index][index] / g;
        int64_t n_g = (*S)[i][index] / g;

        // Note that j is the index of the column, not the row
        for (size_t j = index; j < (*S)[i].size(); ++j) {
          // Multiply index-th row by a and add the i-th row multiplied by b
          // This will make the index-th diagonal element equal to the gcd
          int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j];
          // This transformation performs zeroing of S[i][index]
          int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j];
          // Both new values are computed from the old ones before anything is overwritten
          (*S)[index][j] = new_index_j;
          (*S)[i][j] = new_i_j;
        }
        // We have to do the same with rhs
        // (each constant is created with the dtype of the expression it is multiplied by)
        PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
        PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
        PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
        PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
        PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
        PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
        (*y)[index] = new_index_rhs;
        (*y)[i] = new_i_rhs;
      }
    }

    // Set when a column operation below may have put non-zeros back into the index-th column
    bool changed = false;

    // Now we have to zero the elements of the index-th row by manipulating columns.
    // This is more difficult because column manipulation corresponds to variable manipulation,
    // but the algorithm is essentially the same as before.
    for (size_t j = index + 1; j < n; ++j) {
      if ((*S)[index][j] != 0) {
        int64_t g, a, b;
        // g = a*S[index][index] + b*S[index][j]
        if ((*S)[index][j] % (*S)[index][index] != 0) {
          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
          // During this phase we may disrupt the zeroness of the index-th column, so we will
          // have to take some action if this might have happened.
          changed = true;
        } else {
          // Explicitly avoid changing the index-th column. This is important to avoid infinite
          // loop. Note that here we don't have to set `changed` to true since we don't change the
          // index-th column.
          g = (*S)[index][index];
          a = 1;
          b = 0;
        }

        // In this comment only, let m = S[index][index], n = S[index][j] (these are NOT the
        // matrix dimensions m and n declared above), then the following is true:
        //
        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
        //
        // Now we are going to multiply our matrix on the right (to manipulate columns instead of
        // rows), we will also transform the matrix V the same way, and we will use the
        // second matrix to transform the expressions in x.

        int64_t m_g = (*S)[index][index] / g;
        int64_t n_g = (*S)[index][j] / g;

        for (size_t i = index; i < m; ++i) {
          int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j];
          int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j];
          (*S)[i][index] = new_i_index;
          (*S)[i][j] = new_i_j;
        }
        // We do exactly the same transformations with V
        for (size_t i = 0; i < n; ++i) {
          int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j];
          int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j];
          (*V)[i][index] = new_i_index;
          (*V)[i][j] = new_i_j;
        }
        // And apply reverse transformations to x.
        // (each constant is created with the dtype of the expression it is multiplied by)
        PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
        PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
        PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
        PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
        PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
        PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
        (*x)[index] = new_index;
        (*x)[j] = new_j;
      }
    }

    if (changed) {
      // We might have changed the index-th column, so we have to zero it once more
      // (or at least check if it's zero), so just perform this iteration once more.
      // The decrement cancels the loop's ++index. index is unsigned, so for index == 0 it
      // wraps around and the following ++index brings it back to 0.
      index -= 1;
    }
  }
}

/*!
 * \brief Build the range map for the transformed system.
 *
 * The result contains the ranges of all variables in ori_ranges that are not listed in
 * ori_vars (copied unchanged), plus a range for every variable in vars_to_infer whose
 * expression evaluates to a defined range over the original ranges.
 *
 * \param vars_to_infer Map from each new variable to its expression in terms of the
 *                      original variables.
 * \param ori_vars The original variables that are being replaced.
 * \param ori_ranges The known ranges, of both the original and the outer variables.
 * \return The ranges of the outer variables and of the new variables that could be inferred.
 */
Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer,
                           const Array<Var>& ori_vars,
                           const Map<Var, Range>& ori_ranges) {
  // Node pointers of the variables being replaced, used for identity lookup
  std::unordered_set<const VarNode*> replaced_vars;
  for (const Var& ori_var : ori_vars) {
    replaced_vars.insert(ori_var.get());
  }

  Map<Var, Range> inferred_ranges;
  std::unordered_map<const VarNode*, IntSet> domains;
  for (const auto& kv : ori_ranges) {
    const VarNode* node = kv.first.get();
    // Every known range is available as an IntSet when evaluating the expressions below
    domains[node] = IntSet::range(kv.second);
    // Only the outer variables (the ones that are not replaced) keep their range in the result
    if (replaced_vars.count(node) == 0) {
      inferred_ranges.Set(kv.first, kv.second);
    }
  }

  // Evaluate the expression of each new variable over the known domains
  for (const auto& kv : vars_to_infer) {
    Range covering = EvalSet(kv.second, domains).cover_range(Range());
    // An undefined result means no range is recorded for this variable
    if (!covering.defined()) {
      continue;
    }
    inferred_ranges.Set(kv.first, covering);
  }
  return inferred_ranges;
}

// pretty print matrix equation
void DebugPrint(const std::vector<std::vector<int64_t>>& S,
                const std::vector<std::vector<int64_t>>& V,
                const std::vector<PrimExpr>& V_inv_x,
                const std::vector<PrimExpr>& rhs) {
  std::cout << "S:\n";
  for (size_t i = 0; i < S.size(); ++i) {
    for (auto e : S[i]) {
      std::cout << e << "\t";
    }
    std::cout << "\t->\t" << rhs[i];
    std::cout << "\n";
  }
  std::cout << "V:\n";
  for (const auto& r : V) {
    for (auto e : r) {
      std::cout << e << "\t";
    }
    std::cout << "\n";
  }
  std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x);
  std::cout << "\n" << std::endl;
}

/*!
 * \brief Solve the linear integer equations of a constraint system.
 *
 * Every relation that is an equality with integer-constant coefficients over the system's
 * variables becomes a row of an integer matrix. The matrix is diagonalized with
 * SmithNormalFormDiag, and the original variables are then expressed through newly created
 * free variables (one per column without a non-zero diagonal element).
 * Relations that cannot be handled this way are kept and rewritten in terms of the new
 * variables.
 *
 * \param system_to_solve The variables, their ranges and the relations to solve.
 * \return A transform from system_to_solve to the solved system, together with the
 *         old-to-new and new-to-old variable maps. If the equations are found to have no
 *         solution, the solved system has no variables, no ranges and a single constant
 *         false relation, and both maps are empty.
 */
IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) {
  // m: # of equations
  // n: # of variables
  // we first construct A_{mxn} x_{nx1} = y_{mx1}
  // then get Smith normal form of matrix A,
  // S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
  // => U^{-1} S V^{-1} x = y
  // S V^{-1} x = U y
  std::vector<PrimExpr> Uy;  // mx1
  std::vector<std::vector<int64_t>> S;  // mxn
  std::vector<std::vector<int64_t>> V;  // nxn
  std::vector<PrimExpr> V_inv_x;  // V^{-1} x, nx1
  // Conditions we don't know what to do with
  std::vector<PrimExpr> rest;

  Analyzer analyzer_problem;
  analyzer_problem.Bind(system_to_solve->ranges);

  size_t num_vars = system_to_solve->variables.size();

  // initialize V_{nxn} with identity matrix,
  // initialize V^{-1} x as x
  for (size_t i = 0; i < num_vars; ++i) {
    V.emplace_back(num_vars);
    V.back()[i] = 1;
    V_inv_x.push_back(system_to_solve->variables[i]);
  }

  // Transform formulas into rows of the matrix
  // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables
  // here we initialize S_{mxn} to be A, U to be identity matrix.
  for (const PrimExpr& equation : system_to_solve->relations) {
    if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
      // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
      Array<PrimExpr> coeffs = arith::DetectLinearEquation(
          analyzer_problem.Simplify(eq->a - eq->b),
          system_to_solve->variables);
      if (!coeffs.empty()) {
        std::vector<int64_t> row;
        for (size_t j = 0; j < coeffs.size() - 1; ++j) {
          PrimExpr c = coeffs[j];
          if (const IntImmNode* ic = c.as<IntImmNode>()) {
            row.push_back(ic->value);
          } else {
            // elements in matrix S V must be integers
            // ignore equations that we cannot deal with.
            LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation "
                         << equation;
            // An empty row marks the equation as unhandled, it goes to `rest` below
            row.clear();
            break;
          }
        }

        if (!row.empty()) {
          // a - b == 0 means sum_i variables[i] * coeff[i] = -coeff[n],
          // so the row holds the variable coefficients and the rhs is the negated constant.
          // V is identity for now
          S.push_back(row);
          Uy.push_back(-coeffs[coeffs.size() - 1]);
          continue;
        }
      }
    }

    // otherwise
    rest.push_back(equation);
  }

  // After diagonalizing, we have
  // S_{mxn} is the Smith normal form (diagonal matrix)
  // V_{nxn} is invertible
  // V_inv_x is V^{-1} \times x
  // Uy is U \times y
  SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy);

  Array<Var> new_vars;
  Array<PrimExpr> new_relations;
  Map<Var, PrimExpr> new_to_old_map;
  Map<Var, PrimExpr> old_to_new_map;

  // Simplify right hand sides.
  // The elements must be taken by reference, otherwise the simplified expression would be
  // assigned to a temporary copy and Uy would be left unchanged.
  for (PrimExpr& r : Uy) {
    r = analyzer_problem.Simplify(r);
  }

  // Create the relations of the existence of a solution
  for (size_t j = 0; j < S.size(); ++j) {
    PrimExpr new_relation;
    if (j >= num_vars || S[j][j] == 0) {
      // The row of matrix is zero. A solution exists only if the Uy[j] is also zero
      new_relation = (Uy[j] == 0);
    } else {
      // The diagonal element is non-zero. A solution exists only if the diagonal element
      // is a divisor of the Uy[j].
      // The divisor is built as an explicit constant of the dtype of Uy[j] (the same way as
      // in the solution loop below) instead of handing a raw int64_t to floormod.
      PrimExpr divisor = tir::make_const(Uy[j].dtype(), std::abs(S[j][j]));
      new_relation = (floormod(Uy[j], divisor) == 0);
    }
    new_relation = analyzer_problem.Simplify(new_relation);
    if (tir::is_const_int(new_relation, 0)) {
      // unable to solve the system.
      return IntConstraintsTransform(
          system_to_solve,
          IntConstraints(
              /*variables=*/{},
              /*ranges=*/{},
              /*relations=*/{tir::make_zero(DataType::Bool())}),
          {}, {});
    } else if (!tir::is_const_int(new_relation, 1)) {
      // Neither constant false nor constant true: keep it as a condition of the new system
      new_relations.push_back(new_relation);
    }
  }

  Array<PrimExpr> solution_for_V_inv_x;
  // Now create new variables or directly solve the equations
  // suppose the rank of A is r, aka r = # of non-zeros in S
  // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U y
  // is
  // x = (pseudo-inverse of A) y + K_{(n)x(n-r)} z_{n-r}
  //   = V_{nxn} S^{-1}_{nxm} (Uy)_{mx1} + K_{(n)x(n-r)} z_{n-r}
  // in which K is the right n-r columns of V, z is variable vector
  // thus,
  // V^{-1} x = S^{-1}_{nxm} (Uy)_{mx1} +
  //            [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r}
  for (size_t j = 0; j < num_vars; ++j) {
    if (j >= S.size() || S[j][j] == 0) {
      // The j-th variable can take any integer value, create a tvm variable for it
      PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]);
      std::string name_hint = "n" + std::to_string(new_vars.size());
      if (const VarNode* v_old = to_old.as<VarNode>()) {
        // The new variable is just an old variable, keep its name visible in the hint
        name_hint += "_" + v_old->name_hint;
      }
      Var v = Var(name_hint, V_inv_x[j].dtype());
      solution_for_V_inv_x.push_back(v);
      new_vars.push_back(v);
      new_to_old_map.Set(v, to_old);
    } else {
      // The j-th variable is just a single value, don't create a tvm variable
      // S^{-1}_{nxm} (Uy)_{mx1}
      if (S[j][j] >= 0) {
        PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
        solution_for_V_inv_x.push_back(
            analyzer_problem.Simplify(floordiv(Uy[j], a)));
      } else {
        // This is required because some simplifiers
        // have problems with dividing by negative numbers
        PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
        solution_for_V_inv_x.push_back(
            analyzer_problem.Simplify(floordiv(-Uy[j], a)));
      }
    }
  }

  // V V^{-1} x = x
  for (size_t i = 0; i < num_vars; ++i) {
    PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
    for (size_t j = 0; j < num_vars; ++j) {
      e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
    }
    e = analyzer_problem.Simplify(e);
    old_to_new_map.Set(system_to_solve->variables[i], e);
  }

  // The resulting ranges
  Map<Var, Range> new_ranges = InferRange(
      new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
  Analyzer analyzer_solution;
  analyzer_solution.Bind(new_ranges);

  // We have to transform ranges of the old variables into relations over new variables because
  // new ranges are not enough usually.
  for (const auto& p : system_to_solve->ranges) {
    const Var& old_var = p.first;
    const Range& old_range = p.second;
    if (old_to_new_map.count(old_var)) {
      PrimExpr express_by_new_vars = old_to_new_map[old_var];
      PrimExpr lower_cond = analyzer_solution.Simplify(
          old_range->min <= express_by_new_vars);
      PrimExpr upper_cond = analyzer_solution.Simplify(
          express_by_new_vars < old_range->min + old_range->extent);
      // Conditions that simplify to constant true carry no information and are dropped
      if (!tir::is_const_int(lower_cond, 1)) {
        new_relations.push_back(lower_cond);
      }
      if (!tir::is_const_int(upper_cond, 1)) {
        new_relations.push_back(upper_cond);
      }
    }
  }

  // Add the rest conditions
  for (const PrimExpr& cond : rest) {
    new_relations.push_back(Substitute(cond, old_to_new_map));
  }

  IntConstraints solution(new_vars, new_ranges, new_relations);
  IntConstraintsTransform transform(
      system_to_solve, solution, old_to_new_map, new_to_old_map);

  return transform;
}

// Packed-function entry point of SolveLinearEquations. It accepts either a single argument
// that is passed to the solver as is, or three arguments that are used to construct the
// IntConstraints to solve. Any other number of arguments is a fatal error.
TVM_REGISTER_GLOBAL("arith.SolveLinearEquations")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    switch (args.size()) {
      case 1:
        *ret = SolveLinearEquations(args[0]);
        break;
      case 3: {
        IntConstraints problem(args[0], args[1], args[2]);
        *ret = SolveLinearEquations(problem);
        break;
      }
      default:
        LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
    }
  });

}  // namespace arith
}  // namespace tvm