/*!
 *  Copyright (c) 2016 by Contributors
 * \file graph.cc
 * \brief Utilities to get information about schedule graph.
 */
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>
#include <unordered_set>
#include <unordered_map>
#include "./graph.h"

namespace tvm {
namespace schedule {
// key to specific tensor dimension.
struct TensorDimKey {
  FunctionRef f;
  int value_index;
  int dim;
  TensorDimKey() {}
  TensorDimKey(const ir::Call* op, int dim)
      : f(op->func), value_index(op->value_index), dim(dim) {
  }
  TensorDimKey(const Tensor& t, int dim)
      : f(t->op), value_index(t->value_index), dim(dim) {
  }
  TensorDimKey(const Tensor& t, size_t dim)
      : f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
  }
  inline bool operator==(const TensorDimKey& other) const {
    return f == other.f &&
        value_index == other.value_index &&
        dim == other.dim;
  }
  inline bool operator!=(const TensorDimKey& other) const {
    return !operator==(other);
  }
};
}  // namespace schedule
}  // namespace tvm

namespace std {
template <>
struct hash<::tvm::schedule::TensorDimKey> {
  std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
    size_t lhs = k.f.hash();
    size_t rhs = static_cast<size_t>(k.value_index) << 16UL |
        static_cast<size_t>(k.dim);
    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
    return lhs;
  }
};
}  // namespace std


namespace tvm {
namespace schedule {

// construct a read graph that gives readers of each operation
// that the root depend on
ReadGraph CreateReadGraph(const Array<Operation>& roots) {
  ReadGraph rmap;
  std::vector<Operation> stack;
  std::unordered_set<const Node*> visited;
  // initialize the roots
  for (Operation op : roots) {
    stack.push_back(op);
    visited.insert(op.get());
  }

  while (!stack.empty()) {
    Operation op = stack.back();
    stack.pop_back();
    Array<Tensor> deps = op->InputTensors();
    rmap.Set(op, deps);
    for (Tensor t : deps) {
      if (t->op.defined() && visited.count(t->op.get()) == 0) {
        visited.insert(t->op.get());
        stack.push_back(t->op);
      }
    }
  }
  return rmap;
}

// Do DFS visit to get the subgraph.
// Return if op is inside the subgraph.
bool GetSubGraphByPostDFS_(
    const Operation& op,
    const std::unordered_set<const Node*>& boundary,
    bool include_bounary,
    std::unordered_map<const Node*, bool>* visited,
    Array<Operation>* result) {
  if (visited->count(op.get())) {
    return visited->at(op.get());
  }
  if (boundary.count(op.get())) {
    (*visited)[op.get()] = true;
    if (include_bounary) {
      result->push_back(op);
    }
    return true;
  }
  // mark to avoid loop
  // Not necessary for DAG.
  (*visited)[op.get()] = false;
  // check if we can reach boundary.
  bool reach_boundary = false;
  for (Tensor t : op->InputTensors()) {
    if (GetSubGraphByPostDFS_(t->op, boundary,
                              include_bounary,
                              visited, result)) {
      reach_boundary = true;
    }
  }
  (*visited)[op.get()] = reach_boundary;
  if (reach_boundary) {
    result->push_back(op);
  }
  return reach_boundary;
}

Array<Operation> GetSubGraph(const Array<Tensor>& outputs,
                             const Array<Tensor>& inputs,
                             bool include_inputs) {
  Array<Operation> result;
  std::unordered_set<const Node*> boundary;
  for (Tensor t : inputs) {
    boundary.insert(t->op.get());
  }
  std::unordered_map<const Node*, bool> visited;
  for (Tensor t : outputs) {
    GetSubGraphByPostDFS_(t->op, boundary, include_inputs,
                          &visited, &result);
  }
  return result;
}


void PostDFSOrder(const Operation& op,
                  const ReadGraph& g,
                  std::unordered_set<Operation>* visited,
                  Array<Operation>* post_order) {
  if (visited->count(op)) return;
  visited->insert(op);
  for (const auto& t : g.at(op)) {
    PostDFSOrder(t->op, g, visited, post_order);
  }
  post_order->push_back(op);
}

Array<Operation> PostDFSOrder(
    const Array<Operation>& roots,
    const ReadGraph& g) {
  std::unordered_set<Operation> visited;
  Array<Operation> post_order;
  for (Operation op : roots) {
    PostDFSOrder(op, g, &visited, &post_order);
  }
  return post_order;
}

FeedGraph CreateFeedGraph(const ReadGraph& g) {
  FeedGraph fg;
  for (auto kv : g) {
    for (Tensor t : kv.second) {
      fg[t].push_back(kv.first);
    }
  }
  return fg;
}

AttachPath CreateAttachPath(Schedule sch) {
  AttachPath ret;
  for (Stage stage : sch->stages) {
    std::unordered_set<const Node*> visited;
    Array<IterVar> path;
    for (Stage s = stage; s.defined();) {
      CHECK(!visited.count(s.get()))
          << "Find loop in compute_at attach group";
      visited.insert(s.get());
      Stage spec = s.GetAttachSpec();
      bool start_attach;
      IterVar attach_ivar;
      if (spec->attach_type == kScope) {
        attach_ivar = spec->attach_ivar;
        s = spec->attach_stage;
        start_attach = false;
        CHECK(attach_ivar.defined());
      } else if (spec->attach_type == kScanUpdate) {
        s = spec->attach_stage;
        start_attach = true;
      } else {
        break;
      }
      CHECK(s.defined());
      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
        }
        if (start_attach) path.push_back(iv);
      }
      CHECK(start_attach)
          << "Invalid Schedule: cannot find attach point " << attach_ivar
          << " in the schedule of " << s->op;
    }
    if (!ret.count(stage->op)) {
      ret.Set(stage->op, path);
    }
  }
  return ret;
}

// graph of push reach relation of tensor dimensions
using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;

ReachGraph GetReachGraph(const Array<Operation>& ops) {
  ReachGraph reach;
  std::unordered_set<const Node*> bset;
  for (size_t i = 0; i < ops.size(); ++i) {
    bset.insert(ops[i].get());
  }

  for (Operation op : ops) {
    if (op.as<ScanOpNode>()) {
      const auto& update = op.as<ScanOpNode>()->update;
      const auto& init = op.as<ScanOpNode>()->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(update[i], k));
          reach[TensorDimKey(t, k)].emplace_back(
              TensorDimKey(init[i], k));
        }
      }
    } else if (op.as<ComputeOpNode>()) {
      std::unordered_map<const Node*, TensorDimKey> vmap;
      const auto& axis = op.as<ComputeOpNode>()->axis;
      Tensor t = op.output(0);
      for (size_t i = 0; i < axis.size(); ++i) {
        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
        reach[TensorDimKey(t, i)] = {};
      }
      auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          if (!bset.count(call->func.get())) return;
          for (size_t i = 0; i < call->args.size(); ++i) {
            TensorDimKey dkey(call, static_cast<int>(i));
            auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) {
              const Variable *v = node.as<Variable>();
              auto it = vmap.find(v);
              if (it != vmap.end()) {
                reach[it->second].push_back(dkey);
              }
            };
            ir::PostOrderVisit(call->args[i], fpush);
          }
        }
      };
      for (auto& e : op.as<ComputeOpNode>()->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  return reach;
}

Array<Operation> ScanGetBody(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  // Get the body.
  Array<Tensor> inputs;
  for (Tensor t : scan->state_placeholder) {
    inputs.push_back(t);
  }
  for (Tensor t : scan->inputs) {
    inputs.push_back(t);
  }
  return GetSubGraph(scan->update, inputs, false);
}

Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) {
  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
  Array<Operation> body = ScanGetBody(scan_op);

  std::unordered_map<TensorDimKey, const Node*> exact_reach;
  std::unordered_set<const Node*> fail_set;

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->state_placeholder[i], k);
      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
    }
  }
  // merge exact reach
  auto f_merge_key = [&exact_reach, &fail_set](
      const TensorDimKey& dst, const TensorDimKey& src) {
    auto sit = exact_reach.find(src);
    if (sit == exact_reach.end()) return;
    auto dit = exact_reach.find(dst);
    if (dit == exact_reach.end()) {
      exact_reach[dst] = sit->second;
    } else {
      if (dit->second != sit->second) {
        fail_set.insert(dit->second);
        fail_set.insert(sit->second);
      }
    }
  };
  // prop exact reach back.
  for (size_t i = 0; i < body.size(); ++i) {
    const Operation& op = body[i];
    if (op.as<ScanOpNode>()) {
      const auto& update = op.as<ScanOpNode>()->update;
      const auto& init = op.as<ScanOpNode>()->init;
      for (size_t i = 0; i < update.size(); ++i) {
        Tensor t = op.output(i);
        for (size_t k = 1; i < update[i]->shape.size(); ++k) {
          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
        }
      }
    } else if (op.as<ComputeOpNode>()) {
      std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap;
      const auto& axis = op.as<ComputeOpNode>()->axis;
      for (size_t i = 0; i < axis.size(); ++i) {
        std::vector<TensorDimKey> keys;
        for (int j = 0; j < op->num_outputs(); ++j) {
          keys.emplace_back(op.output(j), i);
        }
        vmap[axis[i]->var.get()] = std::move(keys);
      }
      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
          const NodeRef& n) {
        const ir::Call *call = n.as<ir::Call>();
        if (call != nullptr && call->func.defined()) {
          for (size_t i = 0; i < call->args.size(); ++i) {
            auto it = vmap.find(call->args[i].get());
            TensorDimKey src(call, static_cast<int>(i));
            if (it != vmap.end()) {
              const std::vector<TensorDimKey>& keys = it->second;
              for (const auto& key : keys) {
                f_merge_key(key, src);
              }
            } else {
              if (exact_reach.count(src)) {
                fail_set.insert(exact_reach.at(src));
              }
            }
          }
        }
      };
      for (auto& e : op.as<ComputeOpNode>()->body) {
        ir::PostOrderVisit(e, fvisit);
      }
    }
  }
  ReachGraph reach;
  Map<IterVar, Expr> ret;
  std::unordered_set<TensorDimKey> place_holder_ref;
  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
    }
  }

  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
      TensorDimKey key(scan->update[i], k);
      TensorDimKey target(scan->state_placeholder[i], k);
      IterVar sp_iv = scan->spatial_axis_[sp_idx];
      if (fail_set.count(sp_iv.get()) ||
          !exact_reach.count(key) ||
          exact_reach.at(key) != sp_iv.get()) {
        ret.Set(sp_iv, make_const(Int(32), 0));
      } else {
        // now we proved exact match, need to prove no interference with other graph.
        if (reach.size() == 0) reach = GetReachGraph(body);
        // do a DFS
        std::unordered_set<TensorDimKey> visited;
        std::vector<TensorDimKey> stack{key};
        visited.insert(key);
        while (!stack.empty()) {
          TensorDimKey k = stack.back();
          if (k != target && place_holder_ref.count(k)) break;
          stack.pop_back();
          if (!reach.count(k)) {
            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
          }

          for (TensorDimKey kk : reach.at(k)) {
            if (visited.count(kk)) {
              continue;
            }
            visited.insert(kk);
            stack.push_back(kk);
          }
        }
        if (!stack.empty()) {
          // failed the prove.
          ret.Set(sp_iv, make_const(Int(32), 0));
        } else {
          ret.Set(sp_iv, make_const(Int(32), 1));
        }
      }
    }
  }
  return ret;
}

}  // namespace schedule
}  // namespace tvm