/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file bound_deducer.cc
 * \brief Utility to deduce bound of expression
 */
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h>

#include <unordered_set>
#include <unordered_map>

namespace tvm {
namespace arith {

using namespace ir;
using HalideIR::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
//
// After Visit(root), path_ holds the nodes from root down to (and including)
// the first node, in depth-first order, that is the same object as target_.
// It is empty when no such node exists.
class VariablePathFinder: public IRVisitor {
 public:
  explicit VariablePathFinder(Expr target) : target_(target) {}

  void Visit(const NodeRef& node) final {
    // Once the target has been found the path is complete, so nothing else
    // needs to be traversed. A node that was already visited was fully
    // explored without finding the target, so revisiting it cannot help.
    if (found_ || visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());

    path_.push_back(node.get());
    if (node.same_as(target_)) {
      // The path ends at the target itself; its children are irrelevant.
      found_ = true;
      return;
    }
    IRVisitor::Visit(node);
    // This node is not on the path to the target: drop it again.
    if (!found_) path_.pop_back();
  }

  std::vector<const Node*> path_;

 private:
  bool found_{false};
  Expr target_;
  std::unordered_set<const Node*> visited_;
};

// Returns the chain of nodes leading from the root of `expr` down to (and
// including) `target`. An empty vector signals that `target` was not found.
std::vector<const Node*> GetPath(Expr target, Expr expr) {
  VariablePathFinder finder(target);
  finder.Visit(expr);
  const std::vector<const Node*>& found_path = finder.path_;
  return found_path;
}

class BoundDeduceIntputChecker;

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
 public:
  friend class BoundDeduceInputChecker;
  friend class Converter;
  BoundDeducer(Expr target, Expr expr,
               const std::unordered_map<const Variable*, IntSet>& hint_map,
               const std::unordered_map<const Variable*, IntSet>& relax_map)
  : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}

  void Deduce();

  void Visit(const NodeRef& e) final {
    if (!success) return;
    if (e.get() == path_[iter_++]) {
      IRVisitor::Visit(e);
    } else {
      success = false;
      return;
    }
  }

  void Visit_(const LT* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const LE* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const GT* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const GE* op) final {
    LOG(FATAL) << "unable to deduce due to multiple comparison operator";
  }

  void Visit_(const Add* op) final {
    bool left = op->a.get() == path_[iter_];
    result -= left ? op->b : op->a;
    Visit(left ? op->a : op->b);
  }

  void Visit_(const Sub* op) final {
    bool left = op->a.get() == path_[iter_];
    if (left) {
      result += op->b;
    } else {
      result -= op->a;
      result = - result;
      is_greater = !is_greater;
    }
    Visit(left ? op->a : op->b);
  }

  void Visit_(const Mul* op) final {
    bool left = op->a.get() == path_[iter_];
    Expr operand = left ? op->b : op->a;

    SignType sign;
    if (operand.type().is_uint()) {
      sign = kPositive;
    } else {
      sign = expr_map_[operand].sign_type();
    }

    if (sign == SignType::kNegative) {
      is_greater = !is_greater;
    } else if (sign == SignType::kUnknown) {
      // unable to get the sign of operand
      success = false;
      return;
    }

    // always use relax bound
    bool divided = can_prove(result % operand == 0);
    result = result / operand;
    // since system will round down when not divided
    // eg. 2/4 -> 0; -2/4 -> -1
    // no need fix for !is_greater:
    // eg. a <= 2/4 -> a <= 0
    // eg. a <= 0/4 -> a <= 0
    // so just fix for not divided and is_greater
    // eg. a >= 2/4 -> a >= 0 + 1
    // eg. a >= 0/4 -> a >= 0
    if (is_greater && !divided) {
       result += 1;
    }

    Visit(left ? op->a : op->b);
  }

  Expr result;
  bool is_greater{true};
  bool success{true};

 private:
  void Init();
  void Transform();
  void Relax();

  Expr target_;
  Expr expr_;
  const std::unordered_map<const Variable*, IntSet>& hint_map_;
  const std::unordered_map<const Variable*, IntSet>& relax_map_;
  ExprIntSetMap expr_map_;
  std::vector<const Node*> path_;
  size_t iter_{0};
};

// Counts how often the deducer's target is reached while traversing the
// deducer's input expression. Deduction is only attempted when the count is
// exactly one.
class BoundDeduceInputChecker: public IRVisitor {
 public:
  // Returns true iff deducer->target_ is reached exactly once while visiting
  // deducer->expr_. Safe to call repeatedly: the count is reset each time.
  bool Check(BoundDeducer* deducer) {
    deducer_ = deducer;
    target_count = 0;
    Visit(deducer_->expr_);
    return target_count == 1;
  }

  void Visit(const NodeRef& e) final {
    if (e.same_as(deducer_->target_)) ++target_count;
    IRVisitor::Visit(e);
  }

 private:
  BoundDeducer* deducer_{nullptr};
  size_t target_count{0};
};

void BoundDeducer::Init() {
  BoundDeduceInputChecker checker;
  if (!checker.Check(this)) success = false;
  Transform();
}

void BoundDeducer::Transform() {
  // We will ensure to set expr_ such that it contains target_
  if (const LT* op = expr_.as<LT>()) {
    if (GetPath(target_, op->a).empty()) {
      // a < b -> b >= a + 1
      is_greater = true;
      expr_ = op->b;
      result = op->a + 1;
    } else {
      // a < b -> a <= b - 1
      is_greater = false;
      expr_ = op->a;
      result = op->b - 1;
    }
  } else if (const LE* op = expr_.as<LE>()) {
    if (GetPath(target_, op->a).empty()) {
      // a <= b -> b >= a
      is_greater = true;
      expr_ = op->b;
      result = op->a;
    } else {
      is_greater = false;
      expr_ = op->a;
      result = op->b;
    }
  } else if (const GT* op = expr_.as<GT>()) {
    if (GetPath(target_, op->a).empty()) {
      // a > b -> b <= a - 1
      is_greater = false;
      expr_ = op->b;
      result = op->a - 1;
    } else {
      // a > b -> a >= b + 1
      is_greater = true;
      expr_ = op->a;
      result = op->b + 1;
    }
  } else if (const GE* op = expr_.as<GE>()) {
    if (GetPath(target_, op->a).empty()) {
      // a >= b -> b <= a
      is_greater = false;
      expr_ = op->b;
      result = op->a;
    } else {
      is_greater = true;
      expr_ = op->a;
      result = op->b;
    }
  } else {
    success = false;
  }
}

// Runs the whole deduction: validate + normalize, relax, locate target_, then
// walk the path peeling operators off expr_ and applying them to result.
void BoundDeducer::Deduce() {
  Init();
  if (success) {
    Relax();
  }
  if (!success) return;

  // Chain of nodes from the (relaxed) expr_ root to target_.
  path_ = GetPath(target_, expr_);
  if (path_.empty()) {
    success = false;
    return;
  }

  // Integer sets for every sub-expression; Visit_(Mul) reads operand signs.
  expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

  Visit(expr_);
}

void BoundDeducer::Relax() {
  IntSet a = EvalSet(expr_, relax_map_);
  IntSet b = EvalSet(result, relax_map_);
  if (a.is_everything() || b.is_everything()) {
    success = false;
    return;
  }
  expr_  = is_greater ? a.min() : a.max();
  result = is_greater ? b.max() : b.min();
}

// Deduces a bound for target `v` from comparison `e`, using raw-pointer keyed
// hint/relax maps. Returns IntSet::nothing() on failure. Otherwise it returns
// a half-open interval: [result, +inf) for a lower bound, (-inf, result] for
// an upper bound.
IntSet DeduceBound(Expr v, Expr e,
  const std::unordered_map<const Variable*, IntSet>& hint_map,
  const std::unordered_map<const Variable*, IntSet>& relax_map) {
  BoundDeducer deducer(v, e, hint_map, relax_map);
  deducer.Deduce();
  if (!deducer.success) {
    return IntSet::nothing();
  }
  if (deducer.is_greater) {
    return IntSet::interval(deducer.result, Interval::pos_inf);
  }
  return IntSet::interval(Interval::neg_inf, deducer.result);
}

// Deduce the bound of variable `v` from the comparison `e`. `e` must be one
// of <, <=, >, >= and contain `v` exactly once; otherwise deduction fails.
// The Var-keyed maps are converted to the Variable*-keyed form used by the
// overload above.
// Returns IntSet::nothing() (the empty set) to represent deduce failure.
IntSet DeduceBound(Expr v, Expr e,
                   const Map<Var, IntSet>& hint_map,
                   const Map<Var, IntSet>& relax_map) {
  std::unordered_map<const Variable*, IntSet> hmap;
  for (const auto& kv : hint_map) {
    hmap[kv.first.get()] = kv.second;
  }
  std::unordered_map<const Variable*, IntSet> rmap;
  for (const auto& kv : relax_map) {
    rmap[kv.first.get()] = kv.second;
  }
  return DeduceBound(v, e, hmap, rmap);
}

}  // namespace arith
}  // namespace tvm