import os
import time
import shutil
import sys
sys.path.append('./utils')
sys.path.append('../../FormatTranslators/src')

from math import sqrt
from FormatTranslators import Port
from FormatTranslators import Macro
from FormatTranslators import MacroPin
import collections
import sortedcontainers

class Vertex:
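    """A netlist object (stdcell, macro, macro_pin, or port) from the extracted
    hypergraph, together with its location, size, and cluster (group) assignment."""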
    def __init__(self, vertex_id, name, type, x, y, width, height):
        self.vertex_id = vertex_id
        self.name = name
        self.type = type
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.orient = "N"
        self.group_id = -1
        self.fanouts = set()
        self.macro_pins = set()  # only for macros
        self.group_fanouts = sortedcontainers.SortedSet()

class Clustering:
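    """Cluster standard cells into soft macros for macro placement.

    The constructor drives the whole flow: extract a hypergraph from the
    netlist with OpenROAD, partition it with hMetis, break up clusters that
    spread across the canvas, merge small clusters into their most adjacent
    neighbors, and finally write the clustered netlist in protocol buffer
    format together with an initial placement (.plc) file.
    """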
    def __init__(self, design, src_dir, fixed_file, n_cols = 27, n_rows = 23,
                 global_net_threshold = 500,
                 Nparts = 500, setup_file = "setup.tcl",
                 openroad_exe = "openroad"):
        """
        parameter: design,  help="design_name: ariane, MegaBoom_x2 ", type = str
        parameter: src_dir, help="directory for source codes", type = str
        parameter: fixed_file, help="fixed file generated by grouping", type = str
        parameter: global_net_threshold, help="large net threshold", type = int
        parameter: n_cols, n_rows from Gridding codes
        parameter: Nparts,  help = "number of clusters (only for hmetis, default  = 500)", type = int
        parameter: setup_file, help = "setup file for openroad (default = setup.tcl)", type = str
        """
        # initialize parameters
        self.design = design
        self.src_dir = src_dir
        self.fixed_file = fixed_file
        self.merge_threshold = 0.0
        self.breakup_threshold = 0.0
        self.closeness = 0.0
        self.global_net_threshold = global_net_threshold
        self.Nparts = Nparts
        self.setup_file = setup_file
        self.n_cols = n_cols
        self.n_rows = n_rows
        self.soft_macro_width = 0.0
        self.floorplan_width = 0.0
        self.floorplan_height = 0.0

        # Specify the locations of the hmetis executable, the openroad executable, and other utilities
        self.hmetis_exe = src_dir + "/utils/hmetis"
        self.openroad_exe = openroad_exe
        self.extract_hypergraph_file  = src_dir + "/utils/extract_hypergraph.tcl"

        # set up temp report directory
        rpt_dir = os.getcwd() + "/rtl_mp"
        self.hypergraph_file = rpt_dir + "/" + design + ".hgr"
        self.instance_name_file = self.hypergraph_file + ".vertex"
        self.outline_file = self.hypergraph_file + ".outline"
        self.pb_netlist_file = os.getcwd() + "/" + design + ".pb.txt"
        self.plc_file = os.getcwd() + "/" + design + ".plc"

        self.vertices = []
        self.hyperedges = []

        self.fixed_group_id = 0
        self.group_id_list = []
        self.group_id = 0

        # read the netlist information
        self.GenerateHypergraph() # Extract netlist information from lef/def/v
        self.hMetisPartitioner()  # Partition the hypergraph
        self.BreakClusters()  # Break clusters spreading apart
        while not self.MergeClusters():  # Merge clusters
            pass
        print("[INFO] After merge :  num_clusters = ", len(self.group_id_list))
        self.GenerateProtocolBufferNetlist()
        self.WritePlacementFile()

    def AddMacroPinToMacro(self, macro_pin_vertex_id):
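        """Add the macro-pin vertex to the fanout set of its parent macro.

        The parent macro name is the pin name with its last hierarchy level
        (the pin itself) stripped off.
        """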
        if (self.vertices[macro_pin_vertex_id].type != "macro_pin"):
            return

        macro_name = ""
        items = self.vertices[macro_pin_vertex_id].name.split('/')
        for i in range(len(items) - 1):
            macro_name += items[i] + '/'
        macro_name = macro_name[:-1]

        for vertex in self.vertices:
            if (vertex.name == macro_name):
                vertex.fanouts.add(macro_pin_vertex_id)
                return

    def GenerateHypergraph(self):
        # Extract hypergraph from netlist
        temp_file = os.getcwd() + "/extract_hypergraph.tcl"
        cmd = "cp " + self.setup_file + " " + temp_file
        os.system(cmd)

        with open(self.extract_hypergraph_file) as f:
            content = f.read().splitlines()

        f = open(temp_file, "a")
        f.write("\n")
        for line in content:
            f.write(line + "\n")
        f.close()

        cmd = self.openroad_exe + " " + temp_file
        os.system(cmd)

        cmd = "rm " + temp_file
        os.system(cmd)

        # read vertices
        with open(self.instance_name_file) as f:
            content = f.read().splitlines()

        for i in range(len(content)):
            items = content[i].split()
            self.vertices.append(Vertex(i, items[0], items[1], float(items[2]), float(items[3]), float(items[4]), float(items[5])))
            if (self.vertices[-1].type == "macro"):
                self.vertices[-1].orient = items[6]

        # read hyperedges
        with open(self.hypergraph_file) as f:
            content = f.read().splitlines()

        items = content[0].split()
        num_hyperedges = int(items[0])
        num_vertices   = int(items[1])

        for i in range(num_hyperedges):
            items = content[i+1].split()
            # ignore global nets (pin count above the threshold) and single-pin nets
            if (len(items) > self.global_net_threshold or len(items) == 1):
                continue
            hyperedge = [int(item) - 1 for item in items]
            self.hyperedges.append(hyperedge)
            self.AddMacroPinToMacro(hyperedge[0])
            # the first vertex of each hyperedge is treated as the source;
            # the remaining vertices are added to its fanouts
            for j in range(1, len(hyperedge)):
                self.vertices[hyperedge[0]].fanouts.add(hyperedge[j])

        f = open(self.hypergraph_file, "w")
        line = str(len(self.hyperedges)) + " " + str(len(self.vertices)) + "\n"
        f.write(line)
        for hyperedge in self.hyperedges:
            line = ""
            for vertex in hyperedge:
                line += str(vertex + 1) + " "
            f.write(line + "\n")
        f.close()

        # read boundary information
        with open(self.outline_file) as f:
            content = f.read().splitlines()

        items = content[0].split()
        floorplan_lx = float(items[0])
        floorplan_ly = float(items[1])
        floorplan_ux = float(items[2])
        floorplan_uy = float(items[3])

        canvas_width = floorplan_ux - floorplan_lx
        canvas_height = floorplan_uy - floorplan_ly
        self.floorplan_width = canvas_width
        self.floorplan_height = canvas_height

        self.breakup_threshold = sqrt(canvas_width * canvas_height / 16.0)
        self.closeness = self.breakup_threshold / 2.0
        self.soft_macro_width = canvas_width / self.n_cols

        self.fixed_group_id = -1
        with open(self.fixed_file) as f:
            content = f.read().splitlines()

        for i in range(len(content)):
            group_id = int(content[i])
            self.fixed_group_id = max(self.fixed_group_id, group_id)
            self.vertices[i].group_id = group_id

        self.fixed_group_id += 1
        self.Nparts += self.fixed_group_id

        print("[INFO] Breakup threshold : ", self.breakup_threshold)
        print("[INFO] Merge closeness : ", self.closeness)

    def hMetisPartitioner(self):
        # Partition the hypergraph using hMetis.
        # The parameter configuration is the same as in the Google Brain paper:
        # UBfactor = 5
        # Nruns    = 10
        # CType    = 5
        # RType    = 3
        # Vcycle   = 3
        # The random seed is 0 by default (in our implementation).
        # We invoke the hMetis executable directly with these parameters.
        cmd = self.hmetis_exe + " " + self.hypergraph_file + " "
        cmd += self.fixed_file + " "
        cmd += str(self.Nparts) + " 5 10 5 3 3 0 0"
        os.system(cmd)

        solution_file = self.hypergraph_file + ".part." + str(self.Nparts)
        # read solution vector
        with open(solution_file) as f:
            content = f.read().splitlines()

        for i in range(len(self.vertices)):
            if (self.vertices[i].group_id > -1):
                continue
            group_id = int(content[i]) + self.fixed_group_id
            self.vertices[i].group_id = group_id

        for vertex in self.vertices:
            group_id = vertex.group_id
            self.group_id = max(self.group_id, group_id)
            if group_id not in self.group_id_list:
                self.group_id_list.append(group_id)

        self.merge_threshold = len(self.vertices) // len(self.group_id_list) // 4
        print("[INFO] Merge threshold : ", self.merge_threshold)


    def GetBoundingBox(self, group_id):
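        """Return the bounding box [lx, ly, ux, uy] of all vertices in the group."""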
        lx = 1e9
        ly = 1e9
        ux = -1e9
        uy = -1e9

        for vertex in self.vertices:
            if (vertex.group_id == group_id):
                lx = min(lx, vertex.x)
                ly = min(ly, vertex.y)
                ux = max(ux, vertex.x)
                uy = max(uy, vertex.y)

        return [lx, ly, ux, uy]


    def GetWeightedCenter(self, group_id):
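        """Return the area-weighted center (x, y) of the standard cells in the group."""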
        x_weighted_sum = 0.0
        y_weighted_sum = 0.0
        divisor = 0.0
        for vertex in self.vertices:
            if (vertex.group_id == group_id):
                if (vertex.type != "stdcell"):
                    continue

                area = vertex.width * vertex.height
                x_weighted_sum += vertex.x * area
                y_weighted_sum += vertex.y * area
                divisor += area

        if divisor == 0.0:
            return 0.0, 0.0

        return x_weighted_sum / divisor, y_weighted_sum / divisor

    def Bucket(self, x, threshold, center, min_x, max_x):
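        """Map coordinate x to an integer bucket index relative to the group center.

        Buckets are one threshold wide; coordinates within roughly half a
        threshold of the center map to bucket 0.  If the group extent
        (max_x - min_x) is already below the threshold, no breakup is needed
        and bucket 0 is returned for every coordinate.
        """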
        if (max_x - min_x < threshold):
            return 0

        if x > center:
            return int(0.5 + (x - center) / threshold)
        else:
            return int(-0.5 + (x - center) / threshold)


    def BreakClusters(self):
        # In this step, we break up clusters that spread across the canvas.
        self.group_vertices = {  }
        for i in range(len(self.vertices)):
            group_id = self.vertices[i].group_id
            if group_id not in self.group_vertices:
                self.group_vertices[group_id] = [i]
            else:
                self.group_vertices[group_id].append(i)

        self.group_id_list.sort()
        for group_id in self.group_id_list:
            gcell_vs_group_id = {  }
            group_box = self.GetBoundingBox(group_id)
            if (group_box[2] - group_box[0] > self.breakup_threshold or
                    group_box[3] - group_box[1] > self.breakup_threshold):
                group_x, group_y = self.GetWeightedCenter(group_id)
                for vertex_id in self.group_vertices[group_id]:
                    vertex = self.vertices[vertex_id]
                    xb = self.Bucket(vertex.x, self.breakup_threshold, group_x, group_box[0], group_box[2])
                    yb = self.Bucket(vertex.y, self.breakup_threshold, group_y, group_box[1], group_box[3])

                    if xb == 0 and yb == 0:
                        continue

                    if (xb, yb) not in gcell_vs_group_id:
                        self.group_id += 1
                        gcell_vs_group_id[(xb, yb)] = self.group_id

                    self.vertices[vertex_id].group_id = gcell_vs_group_id[(xb, yb)]


        self.group_id_list = []
        for vertex in self.vertices:
            if vertex.group_id not in self.group_id_list:
                self.group_id_list.append(vertex.group_id)

        print("After break up clusters : ", len(self.group_id_list))


    def IsClose(self, group_loc_a, group_loc_b):
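        """Return True if the Manhattan distance between two group centers is within self.closeness."""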
        dist = abs(group_loc_a[0] - group_loc_b[0])
        dist += abs(group_loc_a[1] - group_loc_b[1])
        return dist <= self.closeness



    def MergeClusters(self):
        # In this step, we merge each small cluster into its most adjacent cluster
        # if the two cluster centers are within a certain distance.
        # self.merge_threshold is the minimum number of vertices in a cluster.
        # self.closeness is the distance used to determine whether two clusters are close.
        adj_matrix = [0] * (self.group_id + 1) * (self.group_id + 1)
        for vertex in self.vertices:
            groups = set()
            groups.add(vertex.group_id)
            for vertex_id in vertex.fanouts:
                groups.add(self.vertices[vertex_id].group_id)
            for i in groups:
                for j in groups:
                    if i == j:
                        continue
                    adj_matrix[i * (self.group_id + 1) + j] += 1

        group_locs = [None] * (self.group_id + 1)
        for group_id in self.group_id_list:
            group_x, group_y = self.GetWeightedCenter(group_id)
            group_locs[group_id] = [group_x, group_y]

        group_size = [0] * (self.group_id + 1)
        for vertex in self.vertices:
            group_size[vertex.group_id] += 1


        group_vertices = [ [] for i in range(self.group_id + 1) ]
        for i in range(len(self.vertices)):
            group_vertices[self.vertices[i].group_id].append(i)

        finished = True
        # For each small cluster, find the most strongly connected cluster
        # within self.closeness and merge into it.
        self.group_id_list.sort()
        for group_id in self.group_id_list:
            if (group_size[group_id] > self.merge_threshold):
                continue

            max_adj_grp = -1
            max_adj = 0
            for i in self.group_id_list:
                if (group_id == i):
                    continue
                # get number of connections
                adj = adj_matrix[i * (self.group_id + 1) + group_id]
                # check if connected
                if (adj == 0):
                    continue

                # check if close to each other
                if (self.IsClose(group_locs[group_id], group_locs[i]) and max_adj < adj):
                    max_adj = adj
                    max_adj_grp = i

            if (max_adj_grp > -1):
                for vertex_id in group_vertices[group_id]:
                    self.vertices[vertex_id].group_id = max_adj_grp
                    group_vertices[max_adj_grp].append(vertex_id)
                    group_size[max_adj_grp] += 1

                group_size[group_id] = 0
                group_vertices[group_id] = [  ]
                if (group_size[max_adj_grp] <= self.merge_threshold):
                    finished = False

        self.group_id_list = []
        for vertex in self.vertices:
            if vertex.group_id not in self.group_id_list:
                self.group_id_list.append(vertex.group_id)


        return finished




    def AddMacroPins(self, macro_pin_vertex_id):
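        """Add the macro-pin vertex to the macro_pins set of its parent macro.

        Same name matching as AddMacroPinToMacro, but fills macro_pins (used
        when generating the protocol buffer netlist) instead of fanouts.
        """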
        if (self.vertices[macro_pin_vertex_id].type != "macro_pin"):
            return

        macro_name = ""
        items = self.vertices[macro_pin_vertex_id].name.split('/')
        for i in range(len(items) - 1):
            macro_name += items[i] + '/'
        macro_name = macro_name[:-1]

        for vertex in self.vertices:
            if (vertex.name == macro_name):
                vertex.macro_pins.add(macro_pin_vertex_id)
                return

    def GenerateProtocolBufferNetlist(self):
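        """Write the clustered netlist in protocol buffer format.

        Emits all ports, all hard macros with their pins, and one soft macro
        per standard-cell cluster; each soft macro gets a single input pin and
        weighted output pins derived from the cluster's external fanouts.
        """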
        self.soft_macro_area_bloating_ratio = 2.0
        # Get the bounding box first
        die_lx = 1e9
        die_ly = 1e9
        die_ux = -1e9
        die_uy = -1e9

        for vertex in self.vertices:
            die_lx = min(die_lx, vertex.x)
            die_ly = min(die_ly, vertex.y)
            die_ux = max(die_ux, vertex.x)
            die_uy = max(die_uy, vertex.y)
            self.AddMacroPins(vertex.vertex_id)

        # check all the ports
        self.ports = []
        self.macros = []
        for vertex in self.vertices:
            if vertex.type != "port":
                continue

            side = "NONE"
            if (vertex.x == die_lx):
                side = "LEFT"
            elif (vertex.x == die_ux):
                side = "RIGHT"
            elif (vertex.y == die_ly):
                side = "BOTTOM"
            else:
                side = "TOP"

            sinks = set()
            for vertex_id in vertex.fanouts:
                if (self.vertices[vertex_id].type != "stdcell"):
                    continue
                sinks.add("Grp_" + str(self.vertices[vertex_id].group_id) + "/Pinput")

            self.ports.append(Port(vertex.name, vertex.x, vertex.y, side))
            self.ports[-1].AddSinks(sinks)

        # check each hard macro and its macro pins
        for vertex in self.vertices:
            if vertex.type != "macro":
                continue

            macro_x = vertex.x
            macro_y = vertex.y
            macro_name = vertex.name
            self.macros.append(Macro(vertex.name, vertex.width, vertex.height, macro_x, macro_y, vertex.orient))

            # check all the macro pins
            for macro_pin_id in vertex.macro_pins:
                pin_vertex = self.vertices[macro_pin_id]
                macro_pin = MacroPin(pin_vertex.name, macro_name, pin_vertex.x - vertex.x, pin_vertex.y - vertex.y, "MACRO", pin_vertex.x, pin_vertex.y)
                sinks = set()
                for vertex_id in pin_vertex.fanouts:
                    if (self.vertices[vertex_id].type != "stdcell"):
                        continue
                    sinks.add("Grp_" + str(self.vertices[vertex_id].group_id) + "/Pinput")
                macro_pin.AddSinks(sinks)
                self.macros[-1].AddOutputPin(macro_pin)


        self.soft_macros = []
        # check all the soft macros
        std_cell_groups = {  }
        for vertex in self.vertices:
            if vertex.type != "stdcell":
                continue
            if vertex.group_id not in std_cell_groups:
                std_cell_groups[vertex.group_id] = [vertex.vertex_id]
            else:
                std_cell_groups[vertex.group_id].append(vertex.vertex_id)

        std_cell_groups = dict(sorted(std_cell_groups.items()))

        # convert the fanouts into group fanouts
        for vertex in self.vertices:
            if (vertex.type != "stdcell"):
                continue
            for vertex_id in vertex.fanouts:
                if (self.vertices[vertex_id].type != "stdcell"):
                    vertex.group_fanouts.add(self.vertices[vertex_id].name)
                elif (self.vertices[vertex_id].group_id != vertex.group_id):
                    vertex.group_fanouts.add("Grp_" + str(self.vertices[vertex_id].group_id) + "/Pinput")

        for group_id, vertices_list in std_cell_groups.items():
            group_x, group_y = self.GetWeightedCenter(group_id)
            macro_name = 'Grp_' + str(group_id)
            area = 0.0
            for vertex_id in vertices_list:
                area += self.vertices[vertex_id].width * self.vertices[vertex_id].height
            area *= self.soft_macro_area_bloating_ratio
            macro_height = area / self.soft_macro_width
            self.soft_macros.append(Macro(macro_name, self.soft_macro_width, macro_height, group_x, group_y))
            self.soft_macros[-1].IsSoft()  # set to soft macro
            # add the input pins
            self.soft_macros[-1].AddInputPin(MacroPin(macro_name + "/Pinput", macro_name, 0.0, 0.0, "macro", group_x, group_y))
            # add output pins based on weight
            single_sink = sortedcontainers.SortedDict()
            p_index = 0
            for vertex_id in vertices_list:
                group_fanouts = self.vertices[vertex_id].group_fanouts
                if (len(group_fanouts) == 0):
                    continue
                elif (len(group_fanouts) == 1):
                    fanout = list(group_fanouts)[0]
                    if fanout not in single_sink:
                        single_sink[fanout] = 1
                    else:
                        single_sink[fanout] += 1
                else:
                    pin_name = "Grp_" + str(group_id) + "/Poutput_multi_" + str(p_index)
                    p_index += 1
                    macro_pin = MacroPin(pin_name, macro_name, 0.0, 0.0, "macro", group_x, group_y)
                    macro_pin.AddSinks(group_fanouts)
                    self.soft_macros[-1].AddOutputPin(macro_pin)

            p_index = 0
            single_sinks = dict(sorted(single_sink.items(), reverse = True))
            for pin, weight in single_sinks.items():
                pin_name = "Grp_" + str(group_id) + "/Poutput_single_" + str(p_index)
                p_index += 1
                macro_pin = MacroPin(pin_name, macro_name, 0.0, 0.0, "macro", group_x, group_y)
                macro_pin.AddSink(pin)
                macro_pin.SpecifyWeight(weight)
                self.soft_macros[-1].AddOutputPin(macro_pin)

        self.WriteProtocolBufferNetlist()

    def WriteProtocolBufferNetlist(self):
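        """Dump the metadata node, ports, hard macros, and soft macros (with their pins) to the .pb.txt file."""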
        f = open(self.pb_netlist_file, "w")
        line  = "node {\n"
        line += '  name: "__metadata__"\n'
        line += '  attr {\n'
        line += '    key: "soft_macro_area_bloating_ratio"\n'
        line += '    value {\n'
        line += '      f: ' + str(self.soft_macro_area_bloating_ratio) + '\n'
        line += '    }\n'
        line += '  }\n'
        line += '}\n'
        f.write(line)
        for port in self.ports:
            f.write(str(port))

        for macro in self.macros:
            f.write(str(macro))
            for macro_pin in macro.GetPins():
                f.write(str(macro_pin))

        for macro in self.soft_macros:
            f.write(str(macro))
            for macro_pin in macro.GetOutputPins():
                f.write(str(macro_pin))
            for macro_pin in macro.GetInputPins():
                f.write(str(macro_pin))

        f.close()

    def WritePlacementFile(self):
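        """Write the initial placement (.plc) file.

        The header comments record the grid and canvas statistics and the
        node-type counts; the body has one placement line per port, hard
        macro, and soft macro (pins only advance the node index).
        """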
        f = open(self.plc_file, "w")
        line = "# Columns : " + str(self.n_cols) + "  Rows : " + str(self.n_rows)  + "\n"
        line += "# Width : " + str(round(self.floorplan_width, 2)) + "  Height : " + str(round(self.floorplan_height, 2)) + "\n"
        line += "# Area : " + str(round(self.floorplan_width * self.floorplan_height, 2)) + "\n"
        f.write(line)


        node_index = 0
        line = "# node_index x y orientation fixed\n"
        for port in self.ports:
            line += str(node_index) + " " + str(round(port.x, 2)) + " " + str(round(port.y, 2)) + " - 0\n"
            node_index += 1
        num_hard_macro_pins = 0
        for macro in self.macros:
            line += str(node_index) + " " + str(round(macro.x, 2)) + " " + str(round(macro.y, 2)) + " " + macro.orientation + " 0\n"
            node_index += 1
            for pin in macro.GetPins():
                node_index += 1
                num_hard_macro_pins += 1

        num_soft_macro_pins = 0
        for macro in self.soft_macros:
            line += str(node_index) + " " + str(round(macro.x, 2)) + " " + str(round(macro.y, 2)) + " " + macro.orientation + " 0\n"
            node_index += 1
            for pin in macro.GetPins():
                node_index += 1
                num_soft_macro_pins += 1
        f.write("# Counts of node types:\n")
        f.write("# HARD_MACROs     :       " + str(len(self.macros)) + "\n")
        f.write("# HARD_MACRO_PINs :       " + str(num_hard_macro_pins) + "\n")
        f.write("# MACROs          :       " + str(len(self.macros) + len(self.soft_macros)) + "\n")
        f.write("# MACRO_PINs      :       " + str(num_hard_macro_pins + num_soft_macro_pins) + "\n")
        f.write("# PORTs           :       " + str(len(self.ports)) + "\n")
        f.write("# SOFT_MACROs     :       " + str(len(self.soft_macros)) + "\n")
        f.write("# SOFT_MACRO_PINs :       " + str(num_soft_macro_pins)  + "\n")
        f.write("# STDCELLs        :        0\n")
        f.write(line)
        f.close()
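

# Minimal usage sketch. The design name comes from the constructor docstring;
# the paths below are illustrative placeholders, not values shipped with this code.
if __name__ == "__main__":
    design = "ariane"
    src_dir = "../src"                   # hypothetical: must contain utils/hmetis and utils/extract_hypergraph.tcl
    fixed_file = "./" + design + ".fix"  # hypothetical: fixed file produced by the grouping step
    Clustering(design, src_dir, fixed_file,
               n_cols = 27, n_rows = 23,
               global_net_threshold = 500,
               Nparts = 500,
               setup_file = "setup.tcl",
               openroad_exe = "openroad")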