/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file print_graph_ir.cc
 * \brief Print the graph IR in LLVM style human readable format.
 */
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>
#include <iostream>

namespace nnvm {
namespace pass {

/*!
 * \brief Callback that writes the attribute value stored at position
 *        `index` to the stream `os`.
 */
using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>;  // NOLINT(*)

/*!
 * \brief Wrap an indexable container into an AttrPrinter.
 *
 * The returned callback streams `vec[index]` to the given stream. The
 * container is not copied: the callback only refers to it, so `vec` must
 * stay alive for as long as the callback may be invoked.
 *
 * \param vec The container whose elements are printed.
 * \return A callback printing the element at the requested index.
 * \tparam T Container type supporting operator[] with streamable elements.
 */
template<typename T>
AttrPrinter GetVectorPrinter_(const T& vec) {
  // Keep only the address of the container; no element is copied here.
  const T* container = &vec;
  AttrPrinter printer = [container](uint32_t pos, std::ostream& out) {  // NOLINT(*)
    out << (*container)[pos];
  };
  return printer;
}

/*!
 * \brief Create a printer for the graph attribute stored under `key`.
 *
 * The attribute must exist and hold a std::vector of TShape, int or
 * std::string; a missing key fails the CHECK and any other stored type is
 * reported through LOG(FATAL).
 *
 * \param graph The graph whose attribute is looked up.
 * \param key The attribute name.
 * \return A callback printing one element of the attribute vector. It refers
 *         to the vector held by the graph attribute rather than copying it.
 */
AttrPrinter GetVectorPrinter(const Graph& graph,
                             const std::string& key) {
  const auto found = graph.attrs.find(key);
  CHECK(found != graph.attrs.end())
      << "Cannot find " << key << " in graph attr";

  const any& attr = *(found->second);
  const auto& held_type = attr.type();

  // Dispatch on the dynamic type stored inside the `any`.
  if (held_type == typeid(std::vector<TShape>)) {
    return GetVectorPrinter_(nnvm::get<std::vector<TShape> >(attr));
  }
  if (held_type == typeid(std::vector<int>)) {
    return GetVectorPrinter_(nnvm::get<std::vector<int> >(attr));
  }
  if (held_type == typeid(std::vector<std::string>)) {
    return GetVectorPrinter_(nnvm::get<std::vector<std::string> >(attr));
  }

  LOG(FATAL) << "Cannot handle type " << attr.type().name();
  return nullptr;
}


/*!
 * \brief Write the graph IR to a stream in an LLVM-style readable format.
 *
 * The text is laid out as follows:
 *  - a "Graph(%in0, %in1, ...) {" header listing the input nodes (one per
 *    line when there are four or more of them);
 *  - when at least one attribute key was requested, one line per input node
 *    with its joined attributes;
 *  - one line per non-variable node, in index order:
 *    "%nid = op(inputs..., key='value'...)" followed by the control
 *    dependencies (if any) and the joined attributes;
 *  - a "ret" line listing the graph outputs and the closing brace;
 *  - when the graph has attributes, a "graph_attr_keys = [...]" line.
 *
 * \param src The graph to print.
 * \param join_entry_attrs Keys of graph attributes printed for every output
 *        entry of a node; each key must be accepted by GetVectorPrinter.
 * \param join_node_attrs Keys of graph attributes printed once per node;
 *        each key must be accepted by GetVectorPrinter.
 * \param os The stream receiving the text.
 */
void PrintGraphIR_(Graph src,
                   const std::vector<std::string>& join_entry_attrs,
                   const std::vector<std::string>& join_node_attrs,
                   std::ostream& os) { // NOLINT(*)
  const IndexedGraph& idx = src.indexed_graph();
  // Callbacks appending ", key=value" for a node; all of them run for every
  // printed node line. They reuse the AttrPrinter signature (id, stream).
  std::vector<AttrPrinter> trigger;

  for (const std::string& key : join_entry_attrs) {
    AttrPrinter fp = GetVectorPrinter(src, key);
    auto fprint = [&idx, key, fp](
        uint32_t nid, std::ostream& os) {  // NOLINT(*)
      const IndexedGraph::Node& inode = idx[nid];
      os << ", " << key << "=";
      if (inode.source->num_outputs() != 1) {
        // Not exactly one output: print the per-entry values as a list.
        os << '[';
        for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
          if (i != 0) os << ", ";
          fp(idx.entry_id(nid, i), os);
        }
        os << ']';
      } else {
        fp(idx.entry_id(nid, 0), os);
      }
    };
    trigger.push_back(fprint);
  }
  for (const std::string& key : join_node_attrs) {
    AttrPrinter fp = GetVectorPrinter(src, key);
    // A node attribute is looked up by node id, not by entry id: an entry
    // index would select the wrong element (or run past the end of the
    // vector) as soon as the two numberings diverge.
    // NOTE(review): assumes node attribute vectors follow the NNVM
    // convention of holding one element per node - confirm with callers.
    auto fprint = [key, fp](
        uint32_t nid, std::ostream& os) {  // NOLINT(*)
      os << ", " << key << "=";
      fp(nid, os);
    };
    trigger.push_back(fprint);
  }

  // Header: short input lists stay on one line, longer ones are wrapped
  // with one input per line.
  const auto& input_nodes = idx.input_nodes();
  const char* input_sep = input_nodes.size() < 4 ? ", " : ",\n      ";
  os << "Graph(";
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    const uint32_t nid = input_nodes[i];
    if (i != 0) {
      os << input_sep;
    }
    os << '%' << idx[nid].source->attrs.name;
  }
  os << ") {\n";

  // Variables are referred to by name, single-output nodes by node id and
  // other nodes by "node id.output index".
  auto print_entry = [&](const IndexedGraph::NodeEntry& e) {
    if (idx[e.node_id].source->is_variable()) {
      os << '%' << idx[e.node_id].source->attrs.name;
    } else if (idx[e.node_id].source->num_outputs() == 1) {
      os << '%' << e.node_id;
    } else {
      os << '%' << e.node_id << "." << e.index;
    }
  };

  // Input nodes get no operator line below, so their joined attributes are
  // printed in a dedicated section (only when attributes were requested).
  if (!trigger.empty()) {
    for (size_t i = 0; i < input_nodes.size(); ++i) {
      const uint32_t nid = input_nodes[i];
      os << "  %" << idx[nid].source->attrs.name;
      for (const auto& fp : trigger) {
        fp(nid, os);
      }
      os << '\n';
    }
  }

  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    os << "  " << "%" << nid << " = "
       << inode.source->op()->name << "(";
    // `first` is shared by inputs and dictionary attributes so that the
    // whole argument list is separated by ", " consistently.
    bool first = true;
    for (const IndexedGraph::NodeEntry& e : inode.inputs) {
      if (first) {
        first = false;
      } else {
        os << ", ";
      }
      print_entry(e);
    }
    for (const auto& kv : inode.source->attrs.dict) {
      if (first) {
        first = false;
      } else {
        os << ", ";
      }
      os << kv.first << "=\'" << kv.second << "\'";
    }
    os << ")";
    if (inode.control_deps.size() != 0) {
      os << ", control_deps=[";
      for (size_t i = 0; i < inode.control_deps.size(); ++i) {
        if (i != 0) os << ", ";
        uint32_t cid = inode.control_deps[i];
        if (idx[cid].source->is_variable()) {
          os << '%' << idx[cid].source->attrs.name;
        } else {
          os << '%' << cid;
        }
      }
      os << "]";
    }
    // additional attribute trigger
    for (const auto& fp : trigger) {
      fp(nid, os);
    }
    os << "\n";
  }

  os << "  ret ";
  {
    bool first = true;
    for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
      if (first) {
        first = false;
      } else {
        os << ", ";
      }
      print_entry(e);
    }
  }
  os << "\n}";

  // Trailer: names of the attributes still attached to the graph.
  if (src.attrs.size() != 0) {
    os << "\ngraph_attr_keys = [";
    bool first = true;
    for (const auto& kv : src.attrs) {
      if (first) {
        first = false;
      } else {
        os << ", ";
      }
      os << kv.first;
    }
    os << "]\n";
  }
}

/*!
 * \brief Pass body: render the graph IR as text.
 *
 * The optional graph attributes "join_entry_attrs" and "join_node_attrs"
 * (lists of attribute keys) are taken from the graph with MoveCopyAttr and
 * forwarded to PrintGraphIR_.
 *
 * \param src The graph to print.
 * \return A new graph whose attribute "graphir" holds the generated text.
 */
Graph PrintGraphIRPass(Graph src) {
  using KeyList = std::vector<std::string>;

  // Fetch an optional list of attribute keys; empty when the key is absent.
  auto take_keys = [&src](const std::string& name) -> KeyList {
    if (src.attrs.count(name) == 0) {
      return KeyList();
    }
    return src.MoveCopyAttr<KeyList>(name);
  };
  const KeyList entry_keys = take_keys("join_entry_attrs");
  const KeyList node_keys = take_keys("join_node_attrs");

  std::ostringstream text;
  PrintGraphIR_(src, entry_keys, node_keys, text);

  Graph result;
  result.attrs["graphir"] = std::make_shared<any>(text.str());
  return result;
}

// Register the pass "PrintGraphIR" with PrintGraphIRPass as its body; the
// pass stores the generated IR text in attrs["graphir"] of the returned graph.
NNVM_REGISTER_PASS(PrintGraphIR)
.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
.set_body(PrintGraphIRPass);

}  // namespace pass
}  // namespace nnvm