/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file place_device.cc
 * \brief Infer the device of each operator from the known placement hints.
 *  Automatically insert a copy node wherever data crosses devices.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>

namespace nnvm {
namespace pass {
namespace {

/*!
 * \brief Assign a device id to every node of the graph and insert copy
 *  nodes on edges whose producer and consumer live on different devices.
 *
 * Required graph attributes (checked on entry):
 *  - "device_group_attr_key": key looked up in each node's attribute dict
 *    to find the node's device group name.
 *  - "device_assign_map": map from device group name to device id.
 *  - "device_copy_op": name of the operator used for inserted copy nodes.
 * Optional graph attribute:
 *  - "device": an existing per-node device vector (-1 means unassigned),
 *    used as the starting point when present.
 *
 * \param src The graph to place.
 * \return If every node ends up on the same device, \p src itself with the
 *  three required attributes erased and "device" set. Otherwise a new graph
 *  built from the (possibly rewritten) outputs, carrying only the "device"
 *  attribute for its own node indexing.
 */
Graph PlaceDevice(Graph src) {
  CHECK(src.attrs.count("device_group_attr_key"))
      << "Need graph attribute \"device_group_attr_key\" in PlaceDevice";
  CHECK(src.attrs.count("device_assign_map"))
      << "Need graph attribute \"device_assign_map\" in PlaceDevice";
  CHECK(src.attrs.count("device_copy_op"))
      << "Need graph attribute \"device_copy_op\" in PlaceDevice";
  const std::string device_group_attr_key =
      src.GetAttr<std::string>("device_group_attr_key");
  const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
  const auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
  const IndexedGraph& idx = src.indexed_graph();
  static auto& is_backward =
      Op::GetAttr<TIsBackward>("TIsBackward");

  // Start from the caller-provided assignment if there is one, otherwise
  // from "everything unassigned" (-1).
  DeviceVector device;
  if (src.attrs.count("device") != 0) {
    device = src.MoveCopyAttr<DeviceVector>("device");
    CHECK_EQ(device.size(), idx.num_nodes());
  } else {
    device.resize(idx.num_nodes(), -1);
  }

  // Forward sweep: nodes are visited in index order, so the devices of a
  // node's inputs / control dependencies have already been considered.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];

    // 1) An explicit device group on the node always wins.
    auto group_it = inode.source->attrs.dict.find(device_group_attr_key);
    if (group_it != inode.source->attrs.dict.end()) {
      const std::string& device_group = group_it->second;
      auto dit = device_assign_map.find(device_group);
      CHECK(dit != device_assign_map.end())
          << "The device assignment not found for group " << device_group;
      device[nid] = dit->second;
      continue;
    }

    // 2) A backward op follows its first control dependency (presumably the
    //    matching forward node -- TODO confirm against the gradient pass).
    //    Only do so when such a dependency exists: indexing control_deps[0]
    //    on a node without control dependencies reads out of bounds.
    const bool follows_control_dep =
        !inode.source->is_variable() &&
        is_backward.get(inode.source->op(), false) &&
        inode.control_deps.size() != 0;
    if (follows_control_dep) {
      const int dep_device = device[inode.control_deps[0]];
      if (dep_device != -1) device[nid] = dep_device;
      continue;
    }

    // 3) Otherwise inherit the device of the first input that has one.
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      if (device[e.node_id] != -1) {
        device[nid] = device[e.node_id];
        break;
      }
    }
  }

  // Backward sweep: push known devices from consumers to still-unassigned
  // producers, visiting nodes in reverse index order.
  for (uint32_t i = idx.num_nodes(); i != 0; --i) {
    const uint32_t nid = i - 1;
    if (device[nid] == -1) continue;
    for (const IndexedGraph::NodeEntry& e : idx[nid].inputs) {
      if (device[e.node_id] == -1) device[e.node_id] = device[nid];
    }
  }

  // Anything still unassigned defaults to device 0. While normalising, find
  // out whether more than one distinct device is in use. device.front() is
  // already normalised by the time any other element is compared against it.
  bool single_device = true;
  for (int& dev : device) {
    if (dev == -1) dev = 0;
    if (dev != device.front()) single_device = false;
  }

  if (single_device) {
    // No cross-device edge is possible: keep the graph, just publish the
    // assignment and drop the attributes consumed by this pass.
    src.attrs.erase("device_group_attr_key");
    src.attrs.erase("device_assign_map");
    src.attrs.erase("device_copy_op");
    src.attrs["device"] = std::make_shared<any>(std::move(device));
    return src;
  }

  // (producer node id, producer output index, consumer device) -> copy node,
  // so that one copy is shared by all consumers on the same device.
  std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
  // Replacement node for each original node id; nullptr means "unchanged".
  std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
  // Device of every node that may appear in the resulting graph.
  std::unordered_map<const Node*, int> new_device_map;
  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");

  // Rebuild nodes whose inputs changed or cross devices, inserting copies.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const int dev_id = device[nid];
    const auto& inode = idx[nid];

    // Inputs that the op mutates in place must stay untouched and on the
    // same device as the op; otherwise the mutation would hit a copy.
    if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) {
      for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) {
        CHECK_LT(index, inode.inputs.size());
        const IndexedGraph::NodeEntry& e = inode.inputs[index];
        if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
          LOG(FATAL) << " mutable state cannot go across device"
                     << " op=" << inode.source->op()->name
                     << " input_state_index=" << index;
        }
      }
    }

    // A node must be recreated if any input was replaced or lives on another
    // device, or if any control dependency was replaced.
    bool need_mutate = false;
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
        need_mutate = true;
        break;
      }
    }
    if (!need_mutate) {
      for (const uint32_t cid : inode.control_deps) {
        if (new_node_map[cid] != nullptr) {
          need_mutate = true;
          break;
        }
      }
    }
    if (inode.source->is_variable()) {
      CHECK(!need_mutate) << "consistency check";
    }

    if (!need_mutate) {
      new_device_map[inode.source] = dev_id;
      continue;
    }

    NodePtr new_node = Node::Create();
    new_node->attrs = inode.source->attrs;
    new_node->inputs.reserve(inode.inputs.size());
    for (size_t i = 0; i < inode.inputs.size(); ++i) {
      const IndexedGraph::NodeEntry& e = inode.inputs[i];
      if (dev_id == device[e.node_id]) {
        // Same device: wire directly to the (possibly replaced) producer.
        if (new_node_map[e.node_id] != nullptr) {
          new_node->inputs.emplace_back(
              NodeEntry{new_node_map[e.node_id], e.index, 0});
        } else {
          new_node->inputs.push_back(inode.source->inputs[i]);
        }
        continue;
      }

      // Cross-device edge: reuse an existing copy to this device if any.
      auto copy_key = std::make_tuple(e.node_id, e.index, dev_id);
      auto it = copy_map.find(copy_key);
      if (it != copy_map.end()) {
        new_node->inputs.emplace_back(
            NodeEntry{it->second, 0, 0});
        continue;
      }

      NodePtr copy_node = Node::Create();
      std::ostringstream os;
      os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
      copy_node->attrs.op = copy_op;
      copy_node->attrs.name = os.str();
      if (new_node_map[e.node_id] != nullptr) {
        copy_node->inputs.emplace_back(
            NodeEntry{new_node_map[e.node_id], e.index, 0});
      } else {
        copy_node->inputs.push_back(inode.source->inputs[i]);
      }
      if (copy_node->attrs.op->attr_parser != nullptr) {
        copy_node->attrs.op->attr_parser(&(copy_node->attrs));
      }
      copy_map.emplace(copy_key, copy_node);
      // The copy node is accounted to the consumer's device.
      new_device_map[copy_node.get()] = dev_id;
      new_node->inputs.emplace_back(
          NodeEntry{std::move(copy_node), 0, 0});
    }

    new_node->control_deps.reserve(inode.control_deps.size());
    for (size_t i = 0; i < inode.control_deps.size(); ++i) {
      const uint32_t cid = inode.control_deps[i];
      if (new_node_map[cid] != nullptr) {
        new_node->control_deps.push_back(new_node_map[cid]);
      } else {
        new_node->control_deps.push_back(inode.source->control_deps[i]);
      }
    }
    new_device_map[new_node.get()] = dev_id;
    new_node_map[nid] = std::move(new_node);
  }

  // Assemble the new graph from the original outputs, redirected to the
  // replacement nodes where those exist.
  Graph ret;
  for (const NodeEntry& e : src.outputs) {
    const NodePtr& replacement = new_node_map[idx.node_id(e.node.get())];
    if (replacement != nullptr) {
      ret.outputs.emplace_back(
          NodeEntry{replacement, e.index, e.version});
    } else {
      ret.outputs.emplace_back(e);
    }
  }

  // Re-express the assignment in terms of the new graph's node indices.
  const IndexedGraph& new_idx = ret.indexed_graph();
  DeviceVector new_device_vec(new_idx.num_nodes());
  for (uint32_t nid = 0; nid < new_idx.num_nodes(); ++nid) {
    const Node* source = new_idx[nid].source;
    auto dev_it = new_device_map.find(source);
    if (dev_it == new_device_map.end()) {
      LOG(FATAL) << "canot find " << source;
    }
    new_device_vec[nid] = dev_it->second;
  }
  ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec));
  return ret;
}

// Register PlaceDevice as an NNVM pass. It may change the graph, produces
// the "device" graph attribute, and depends on the three graph attributes
// listed below being present on the input graph.
// The description is built from two adjacent string literals; the trailing
// space in the first one keeps the two sentences from running together.
NNVM_REGISTER_PASS(PlaceDevice)
.describe("Infer the device type of each operator. "
          "Insert a copy node when there is cross device copy")
.set_body(PlaceDevice)
.set_change_graph(true)
.provide_graph_attr("device")
.depend_graph_attr("device_group_attr_key")
.depend_graph_attr("device_assign_map")
.depend_graph_attr("device_copy_op");

// Registers DeviceAssignMap (the type of the "device_assign_map" graph
// attribute read by PlaceDevice) with dmlc's JSON support for `any`-typed
// values, under the type key "dict_str_int".
// NOTE(review): the exact save/load behaviour is defined by the
// DMLC_JSON_ENABLE_ANY macro, which is not visible in this file -- confirm there.
DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int);

}  // namespace
}  // namespace pass
}  // namespace nnvm