/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file device_annotation.cc
 * \brief Passes to rewrite annotated program and retrieve the device allocation
 * of expression.
 *
 * The following passes are performed:
 *  1. Validate the annotations, rejecting an expression annotated to more than one device.
 *  2. Rewrite the annotated program and insert data copy operators.
 *  3. Collect the device allocation of each expression.
 */

#include <tvm/tir/expr.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <memory>
#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relay {

namespace {

bool IsOnDeviceNode(const ExprNode* node) {
  if (!node->IsInstance<CallNode>()) return false;
  const auto* call_node = static_cast<const CallNode*>(node);
  return call_node->attrs.as<OnDeviceAttrs>();
}

bool IsDeviceCopyNode(const ExprNode* node) {
  if (!node->IsInstance<CallNode>()) return false;
  const auto* call_node = static_cast<const CallNode*>(node);
  return call_node->attrs.as<DeviceCopyAttrs>();
}

}  // namespace

class ValidateAnnotation : private ExprVisitor {
 public:
  static std::unordered_map<const ExprNode*, int> Validate(const Expr& expr) {
    ValidateAnnotation valid;
    valid(expr);
    return valid.annotation_map_;
  }

 private:
  void VisitExpr_(const CallNode* call_node) final {
    ExprVisitor::VisitExpr_(call_node);
    if (IsOnDeviceNode(call_node)) {
      int device_type = GetDeviceId(call_node);
      if (annotation_map_.count(call_node)) {
        CHECK_EQ(annotation_map_.at(call_node), device_type)
            << "An expression node can only be annotated to one device.";
      } else {
        annotation_map_.insert({call_node, GetDeviceId(call_node)});
      }

      CHECK_EQ(call_node->args.size(), 1U);
      const auto* node = call_node->args[0].operator->();
      if (annotation_map_.count(node)) {
        CHECK_EQ(annotation_map_.at(node), device_type)
            << "An expression node can only be annotated to one device.";
      } else {
        annotation_map_.insert({node, GetDeviceId(call_node)});
      }
    }
  }

  void VisitExpr_(const TupleGetItemNode* get_elem) final {
    ExprVisitor::VisitExpr_(get_elem);
    const auto* tn = get_elem->tuple.operator->();
    if (annotation_map_.count(tn)) {
      annotation_map_.insert({get_elem, annotation_map_.at(tn)});
    }
  }

  /*
   * \brief Get the device type of the annotation node.
   * \param call_node The on_device annotation call node.
   * \return The device type.
   */
  int GetDeviceId(const CallNode* call_node) {
    CHECK(IsOnDeviceNode(call_node))
        << "The input call node must be on_device node.";
    const OnDeviceAttrs* on_device_attr = call_node->attrs.as<OnDeviceAttrs>();
    return on_device_attr->device_type;
  }

  std::unordered_map<const ExprNode*, int> annotation_map_;
};

// Replace the use of an expression with the output of a `device_copy` operator
// if the `on_device` operator takes the annotated expr as an input.
//
// This actually replaces annotation ops with device copy ops and connects any
// two dependent expressions with a `device_copy` op when needed. Note that the
// device type of a `device_copy` op is identical to that of the destination op
// since it is where the data should be copied to.
class RewriteAnnotation : public ExprMutator {
 public:
  /*!
   * \brief Rewrite an annotated expression.
   *
   * Validates the `on_device` annotations of `expr`, then rebuilds it with the
   * annotations dropped and `device_copy` operators inserted between
   * dependent expressions that live on different devices.
   *
   * \param expr The annotated expression.
   * \param fallback_device The device type assumed for unannotated expressions.
   * \return The rewritten expression.
   */
  Expr Rewrite(const Expr& expr, int fallback_device) {
    fallback_device_ = fallback_device;
    annotation_map_ = ValidateAnnotation::Validate(expr);
    return this->VisitExpr(expr);
  }

  Expr VisitExpr_(const LetNode* op) final {
    Expr value = GetDeviceCopyExpr(op->value, op);
    Expr body = GetDeviceCopyExpr(op->body, op);

    if (value.same_as(op->value) && body.same_as(op->body)) {
      return ExprMutator::VisitExpr_(op);
    } else {
      Expr new_let = Let(op->var, value, body);
      UpdateAnnotationMap(op, new_let.operator->());
      return this->VisitExpr(new_let);
    }
  }

  // Insert device copies for tuple fields that live on a different device
  // than the tuple. This must be spelled `VisitExpr_` to override
  // ExprMutator, and it must walk `op->fields` (not the output array).
  Expr VisitExpr_(const TupleNode* op) final {
    Array<Expr> fields;
    bool annotated = false;
    for (const auto& field : op->fields) {
      annotated |= NeedDeviceCopy(field.operator->(), op);
      fields.push_back(GetDeviceCopyExpr(field, op));
    }

    if (annotated) {
      Expr new_tuple = Tuple(fields);
      UpdateAnnotationMap(op, new_tuple.operator->());
      return this->VisitExpr(new_tuple);
    } else {
      return ExprMutator::VisitExpr_(op);
    }
  }

  Expr VisitExpr_(const TupleGetItemNode* op) final {
    Expr tuple = op->tuple;
    if (NeedDeviceCopy(tuple.operator->(), op)) {
      Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index);
      UpdateAnnotationMap(op, new_expr.operator->());
      return this->VisitExpr(new_expr);
    } else {
      return ExprMutator::VisitExpr_(op);
    }
  }

  Expr VisitExpr_(const IfNode* if_node) final {
    Expr cond = GetDeviceCopyExpr(if_node->cond, if_node);
    Expr true_br = GetDeviceCopyExpr(if_node->true_branch, if_node);
    Expr false_br = GetDeviceCopyExpr(if_node->false_branch, if_node);

    if (if_node->cond.same_as(cond) && if_node->true_branch.same_as(true_br) &&
        if_node->false_branch.same_as(false_br)) {
      return ExprMutator::VisitExpr_(if_node);
    } else {
      Expr new_if = If(cond, true_br, false_br);
      UpdateAnnotationMap(if_node, new_if.operator->());
      return this->VisitExpr(new_if);
    }
  }

  Expr VisitExpr_(const CallNode* call_node) final {
    // Drop the annotation itself; only the annotated expression survives.
    if (IsOnDeviceNode(call_node)) {
      return this->VisitExpr(call_node->args[0]);
    }

    if (IsDeviceCopyNode(call_node)) {
      return ExprMutator::VisitExpr_(call_node);
    }

    Array<Expr> new_args;
    bool annotated = false;
    for (const auto& arg : call_node->args) {
      annotated |= NeedDeviceCopy(arg.operator->(), call_node);
      new_args.push_back(GetDeviceCopyExpr(arg, call_node));
    }

    if (annotated) {
      Call new_call = Call(call_node->op, new_args, call_node->attrs,
                           call_node->type_args);

      UpdateAnnotationMap(call_node, new_call.operator->());
      return this->VisitExpr(new_call);
    } else {
      return ExprMutator::VisitExpr_(call_node);
    }
  }

 private:
  /*!
   * \brief Give `new_node` the device of `old_node` (or the fallback device
   * when `old_node` is unannotated) and memoize `old_node` as rewritten to
   * `new_node`.
   */
  void UpdateAnnotationMap(const ExprNode* old_node, const ExprNode* new_node) {
    const auto it = annotation_map_.find(old_node);
    if (it == annotation_map_.end()) {
      annotation_map_.insert({new_node, fallback_device_});
    } else {
      annotation_map_.insert({new_node, it->second});
    }
    this->memo_[GetRef<Expr>(old_node)] = GetRef<Expr>(new_node);
  }

  /*!
   * \brief Return `src` unchanged when no copy is needed between `src` and
   * `dst`; otherwise return `src` wrapped in a `device_copy` from the device
   * of `src` to the device of `dst` (the fallback device stands in for any
   * unannotated side).
   */
  Expr GetDeviceCopyExpr(const Expr& src, const ExprNode* dst) {
    const auto* src_node = src.operator->();
    if (!NeedDeviceCopy(src_node, dst)) return src;

    const auto sit = annotation_map_.find(src_node);
    if (sit == annotation_map_.end()) {
      const auto dit = annotation_map_.find(dst);
      CHECK(dit != annotation_map_.end())
          << "Device copy op is not required when both src and dst ops are not "
             "annotated.";
      return CreateDeviceCopy(src, fallback_device_, dit->second);
    } else {
      const auto dit = annotation_map_.find(dst);
      int dst_dev_type =
          dit == annotation_map_.end() ? fallback_device_ : dit->second;
      return CreateDeviceCopy(src, sit->second, dst_dev_type);
    }
  }

  // Check if a device copy op is needed between two ops.
  bool NeedDeviceCopy(const ExprNode* src, const ExprNode* dst) {
    if (annotation_map_.count(src)) {
      int src_dev_type = annotation_map_.at(src);
      if (annotation_map_.count(dst)) {
        return src_dev_type != annotation_map_.at(dst);
      } else {
        return src_dev_type != fallback_device_;
      }
    } else {
      if (annotation_map_.count(dst)) {
        // Though data copy op could be inserted whenever the `src` and `dst`
        // ops are annotated to different devices, it leads to high overhead.
        //
        // Here we need across device data transferring only when `src` is a
        // CallNode or FunctionNode and the `dst` is annotated with any device
        // id other than fallback_device_.
        if (src->IsInstance<CallNode>() || src->IsInstance<FunctionNode>()) {
          return annotation_map_.at(dst) != fallback_device_;
        } else {
          // There shouldn't be any copy nodes between var/constant and another
          // expression.
          return !(src->IsInstance<VarNode>() || src->IsInstance<ConstantNode>());
        }
      } else {
        return false;
      }
    }
  }

  /*
   * \brief Create an operator to copy data from the source device to the
   * destination device.
   * \param src The source expression that produces data to be copied.
   * \param src_dev_type The device type where the data is copied from.
   * \param dst_dev_type The device type where the data is copied to.
   * \return The created call node, recorded on the destination device.
   */
  Call CreateDeviceCopy(const Expr& src, int src_dev_type, int dst_dev_type) {
    auto attrs = make_object<DeviceCopyAttrs>();
    attrs->src_dev_type = src_dev_type;
    attrs->dst_dev_type = dst_dev_type;
    static const Op& op = Op::Get("device_copy");
    Call device_copy = Call(op, {src}, Attrs(attrs), {});
    annotation_map_.insert({device_copy.operator->(), dst_dev_type});
    return device_copy;
  }

  /*! \brief Device type of each annotated (or rewritten) expression node. */
  std::unordered_map<const ExprNode*, int> annotation_map_;
  /*! \brief Device type used for expressions without an annotation. */
  int fallback_device_;
};

// Collect all `on_device` annotation expressions together with their device types.
class AnnotatationVisitor : private ExprVisitor {
 public:
  static Map<Expr, Integer> GetAnnotations(const Expr& expr) {
    AnnotatationVisitor visitor;
    visitor(expr);
    return visitor.annotations_;
  }
 private:
  void VisitExpr_(const CallNode* call_node) {
    if (IsOnDeviceNode(call_node)) {
      const auto* attr = call_node->attrs.as<OnDeviceAttrs>();
      annotations_.Set(GetRef<Expr>(call_node), attr->device_type);
    }
    ExprVisitor::VisitExpr_(call_node);
  }
  Map<Expr, Integer> annotations_;
};

/*
 * \brief Return device allocation map based on the post order traversed graph.
 * For the following program:
 * .. code-block:: python
 *     x = relay.var("x")
 *     y = relay.var("y")
 *     add = relay.add(x, y)
 *     sqrt = relay.sqrt(add)
 *     log = relay.log(add)
 *     subtract = relay.subtract(sqrt, log)
 *     exp = relay.exp(subtract)
 *
 * Suppose we have annotated add, sqrt, and log with device 1, 2, and 3,
 * respectively. The fallback/default device is 4. After rewriting the
 * program, we can have the following graph, where each copy op has both
 * source and destination device type denoting which device the data should be
 * copied from and to.
 *
 *         x     y
 *          \   /
 *          add/1
 *          /   \
 *       copy1  copy2
 *         |     |
 *      sqrt/2 log/3
 *         |     |
 *       copy3 copy4
 *          \   /
 *        subtract
 *            |
 *           exp
 *
 * To get the device mapping of each expression, we need to propagate the
 * device information from the copy ops. This can be done in two passes.
 *  -Pass 1: Propagating the source device type to ops in a bottom-up way to the
 *           ancestors until encountering another copy op. For example, this way
 *           provides add, x, and y device types from the copy operator, `copy1`.
 *  -Pass 2: Propagating the destination device type of "the last" copy op to the
 *           remaining nodes. For instance, this offers `subtract` and `exp` the
 *           same device type as `copy3`.
 */

class DeviceInfo {
 public:
  static Map<Expr, Integer> GetDeviceMap(const Expr& expr) {
    DeviceInfo device_info;
    device_info.post_visitor_ = PostDfsOrderVisitor();
    device_info.post_visitor_.Visit(expr);
    if (device_info.post_visitor_.num_device_copy_ops_ > 0) {
      device_info.PropagateDeviceId();
      return device_info.device_map_;
    } else {
      return Map<Expr, Integer>();
    }
  }

 private:
  class PostDfsOrderVisitor : private ExprVisitor {
   public:
    void Visit(const Expr& expr) {
      if (const auto* fn = expr.as<FunctionNode>()) {
        for (const auto& param : fn->params) {
          this->VisitExpr(param);
        }
        this->VisitExpr(fn->body);
      } else {
        this->VisitExpr(expr);
      }
    }

   private:
    // Post order traversal.
    void VisitExpr_(const FunctionNode* fn) final {
      // TODO(zhiics) Skip annotation of function node for now.
    }

    void VisitExpr_(const ConstantNode* cn) final {
      post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
    }

    void VisitExpr_(const CallNode* call) final {
      // Skip annotation nodes.
      if (!IsOnDeviceNode(call)) {
        if (GetDeviceCopyNode(call)) {
          num_device_copy_ops_++;
          bool has_copy_prev = has_copy_;
          has_copy_ = true;
          ExprVisitor::VisitExpr_(call);
          post_dfs_order_.push_back(std::make_pair(call, has_copy_));
          has_copy_ = has_copy_prev;
        } else {
          ExprVisitor::VisitExpr_(call);
          post_dfs_order_.push_back(std::make_pair(call, has_copy_));
        }
      }
    }

    void VisitExpr_(const TupleNode* tn) final {
      ExprVisitor::VisitExpr_(tn);
      // TODO(zhiics) Skip annotation of tuple node for now.
    }

    void VisitExpr_(const TupleGetItemNode* op) final {
      ExprVisitor::VisitExpr_(op);
    }

    void VisitExpr_(const VarNode* vn) final {
      post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
    }

    void VisitExpr_(const LetNode* ln) final {
      ExprVisitor::VisitExpr_(ln);
      post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
    }

    void VisitExpr_(const IfNode* in) final {
      ExprVisitor::VisitExpr_(in);
      post_dfs_order_.push_back(std::make_pair(in, has_copy_));
    }


    int num_device_copy_ops_{0};
    bool has_copy_ = false;
    std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
    friend DeviceInfo;
  };

  /*
   * \brief Returns a device copy node based on the current expr node. It
   * returns a device copy node either the current expr node is a device copy
   * node or the current expr node is a function node whose body is a device
   * copy node (i.e. the fused function of a device copy call node).
   */
  static const ExprNode* GetDeviceCopyNode(const ExprNode* node) {
    if (IsDeviceCopyNode(node)) {
      return node;
    } else if (node->IsInstance<CallNode>()) {
      const auto* call_node = static_cast<const CallNode*>(node);
      if (const auto* fn = call_node->op.as<FunctionNode>()) {
        const ExprNode* body = fn->body.operator->();
        if (IsDeviceCopyNode(body)) {
          return body;
        }
      }
    }
    return nullptr;
  }

  void PropagateDeviceId() {
    // Bottom-up propagation.
    int out_dev_type = BottomUpPropagation();
    // propagation for remained nodes.
    FillPropagation(out_dev_type);
  }

  int BottomUpPropagation() {
    const CallNode* last_copy_node = nullptr;
    int cur_dev_type = -1;
    int out_dev_type = -1;
    for (auto it = post_visitor_.post_dfs_order_.crbegin();
         it != post_visitor_.post_dfs_order_.crend(); ++it) {
      if (const auto* node = GetDeviceCopyNode(it->first)) {
        CHECK(node->IsInstance<CallNode>());
        last_copy_node = static_cast<const CallNode*>(node);
        const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
        cur_dev_type = attrs->src_dev_type;
        if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
        if (it->second) device_map_.Set(GetRef<Expr>(it->first),
                                        attrs->dst_dev_type);
      } else if (last_copy_node) {
        Expr expr = GetRef<Expr>(it->first);
        CHECK_EQ(device_map_.count(expr), 0U);
        if (it->second) device_map_.Set(expr, cur_dev_type);
      }
    }
      return out_dev_type;
  }

  void FillPropagation(int out_dev_type) {
    for (const auto& it : post_visitor_.post_dfs_order_) {
        Expr expr = GetRef<Expr>(it.first);
        if (!it.second) device_map_.Set(expr, out_dev_type);
    }
  }


  PostDfsOrderVisitor post_visitor_;
  Map<Expr, Integer> device_map_;
};

/*!
 * \brief Rewrite `expr` so that `on_device` annotations are replaced by
 * `device_copy` operators, then drop `on_device` fields from a top-level tuple
 * (either the expression itself or a function's body).
 * \param expr The annotated expression.
 * \param fallback_device Device type used for unannotated expressions.
 * \return The rewritten expression. A resulting tuple with a single
 *         remaining field is replaced by that field.
 */
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
  RewriteAnnotation rewriter;
  Expr new_expr = rewriter.Rewrite(expr, fallback_device);

  // Keep only the tuple fields that are not `on_device` annotations.
  auto drop_on_device_fields = [](const TupleNode* tuple) {
    std::vector<Expr> kept;
    for (const auto& field : tuple->fields) {
      if (!IsOnDeviceNode(field.operator->())) kept.push_back(field);
    }
    return kept;
  };

  // Remove OnDevice operators. Note that these operators are only present at the
  // leaves after annotation. Therefore, we can simply reconstruct the
  // Function/Expr by removing them directly.
  if (const FunctionNode* fn = new_expr.as<FunctionNode>()) {
    const TupleNode* tuple = fn->body.as<TupleNode>();
    if (tuple == nullptr) return new_expr;

    std::vector<Expr> new_body = drop_on_device_fields(tuple);
    CHECK_GT(new_body.size(), 0U);
    if (new_body.size() == 1) {
      return Function(fn->params, new_body[0], Type(nullptr), fn->type_params,
                      fn->attrs);
    }
    if (new_body.size() == tuple->fields.size()) return new_expr;

    Tuple tuple_body = Tuple(new_body);
    return Function(fn->params, tuple_body, Type(nullptr), fn->type_params,
                    fn->attrs);
  }

  if (const TupleNode* tuple = new_expr.as<TupleNode>()) {
    std::vector<Expr> new_fields = drop_on_device_fields(tuple);
    CHECK_GT(new_fields.size(), 0U);
    if (new_fields.size() == 1) return new_fields[0];
    if (new_fields.size() == tuple->fields.size()) return new_expr;
    return Tuple(new_fields);
  }

  return new_expr;
}

/*!
 * \brief Collect the device type of each expression in `expr`.
 * \return An empty map when `expr` contains no device copy operator.
 */
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
  Map<Expr, Integer> device_map = DeviceInfo::GetDeviceMap(expr);
  return device_map;
}

/*! \brief Collect every `on_device` annotation in `expr` with its device type. */
Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
  Map<Expr, Integer> annotations = AnnotatationVisitor::GetAnnotations(expr);
  return annotations;
}

TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo")
    .set_body_typed(CollectDeviceInfo);

TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps")
    .set_body_typed(CollectDeviceAnnotationOps);

namespace transform {

/*!
 * \brief Create a function pass (opt level 1, requiring InferType) that
 * applies relay::RewriteAnnotatedOps to each function.
 * \param fallback_device Device type used for unannotated expressions.
 * \return The created pass.
 */
Pass RewriteAnnotatedOps(int fallback_device) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function func, IRModule, PassContext) {
        auto rewritten = relay::RewriteAnnotatedOps(func, fallback_device);
        return Downcast<Function>(rewritten);
      };
  const int opt_level = 1;
  return CreateFunctionPass(pass_func, opt_level, "RewriteAnnotatedOps",
                            {tir::StringImmNode::make("InferType")});
}

TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
    .set_body_typed(RewriteAnnotatedOps);

}  // namespace transform

}  // namespace relay
}  // namespace tvm