import subprocess
import re
import os
import argparse


# Wall-time limit written into the job script for each known QOS.
_QOS_TIME_LIMITS = {
    "gpu-normal": "1-05:59:59",
    "gpu-long": "2-23:59:59",
    "normal": "6-23:59:59",
}


def _split_hostlist(hostlist):
    """Split a Slurm hostlist on the commas that sit outside ``[...]``."""
    entries = []
    current = []
    depth = 0
    for ch in hostlist:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth = max(depth - 1, 0)
        if ch == "," and depth == 0:
            entries.append("".join(current))
            current = []
        else:
            current.append(ch)
    entries.append("".join(current))
    return [entry.strip() for entry in entries if entry.strip()]


def _expand_host_entry(entry, node_type):
    """Return the node names in one hostlist entry, or [] if it is not of ``node_type``."""
    # node_type is escaped so that it is matched literally, and the whole
    # entry must match so that e.g. "r8l40" never picks up "r8l40s-a01".
    pattern = (
        rf"(?P<prefix>{re.escape(node_type)}(?:-[a-zA-Z])?)"
        r"(?:\[(?P<ranges>[0-9,-]+)\]|(?P<single>[0-9]+))"
    )
    match = re.fullmatch(pattern, entry)
    if match is None:
        return []
    if match.group("single") is not None:
        return [entry]

    prefix = match.group("prefix")
    nodes = []
    for item in match.group("ranges").split(","):
        first, dash, last = item.partition("-")
        if not first.isdigit():
            continue  # stray comma or malformed item
        if not dash:
            nodes.append(f"{prefix}{first}")
        elif last.isdigit():
            # Keep the zero padding used in the hostlist itself ("001-004"
            # expands to 3-digit names) instead of assuming two digits.
            width = len(first)
            nodes.extend(
                f"{prefix}{i:0{width}d}" for i in range(int(first), int(last) + 1)
            )
    return nodes


def submit_slurm_job(node_count, gpus_per_node, node_type, partition, qos, exclude_nodes=None):
    """
    Pick idle nodes, patch a copy of the Slurm script and submit it.

    The template ``train-multigpu.slurm`` is read from the current working
    directory and copied to ``train-multigpu.slurm.copy``. ``sinfo`` is
    queried for idle nodes of ``node_type``; if at least ``node_count`` of
    them remain after exclusion, the copy is rewritten (the
    ``#SBATCH --constraint=`` line becomes ``#SBATCH --nodelist=``, and the
    node count, GPU count, time limit, partition and QOS lines are updated)
    and submitted with ``sbatch``. The copy is left on disk afterwards.

    :param node_count: number of nodes to request
    :param gpus_per_node: number of GPUs per node (4 or 8)
    :param node_type: node name prefix (r8l40/r8l40s/r8a100)
    :param partition: partition name
    :param qos: QOS name; "gpu-normal", "gpu-long" and "normal" select a
        time limit, any other value leaves the template's time limit as is
    :param exclude_nodes: iterable of node names that must not be selected.
        When omitted, a module-level ``exclude_nodes`` variable is used if
        one exists (this is how the script entry point passes it);
        otherwise nothing is excluded.
    :return: None; progress and failures are reported with ``print``
    :raises OSError: if the template cannot be read or the copy cannot be
        written
    """
    if exclude_nodes is None:
        # Backward compatibility: the parameter shadows the legacy global of
        # the same name, so it has to be looked up through globals().
        exclude_nodes = globals().get("exclude_nodes") or ()
    excluded = set(exclude_nodes)

    # Template path and the name of the working copy.
    slurm_script = "train-multigpu.slurm"
    slurm_copy = f"{slurm_script}.copy"

    # Start from a fresh, unmodified copy of the template.
    with open(slurm_script, 'r') as src:
        content = src.read()
    with open(slurm_copy, 'w') as dst:
        dst.write(content)

    # Ask Slurm for "<state> <nodelist>" lines.
    try:
        sinfo_output = subprocess.check_output(['sinfo', '-h', '-o', '%t %N']).decode('utf-8')
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"获取 sinfo 信息失败: {e}")
        return

    # Collect the idle nodes of the requested type.
    all_nodes = []
    for line in sinfo_output.splitlines():
        fields = line.split(None, 1)
        if len(fields) != 2 or 'idle' not in fields[0]:
            continue
        found = []
        for entry in _split_hostlist(fields[1]):
            found.extend(_expand_host_entry(entry, node_type))
        if found:
            print(line)
            print(found)
        all_nodes.extend(found)

    all_nodes = [node for node in all_nodes if node not in excluded]

    # Bail out when there are not enough usable nodes.
    if len(all_nodes) < node_count:
        print(f"没有足够的 {node_type} 空闲节点可用。")
        return

    selected_nodes = ','.join(all_nodes[:node_count])

    def replace_line(text, pattern, replacement):
        # A function replacement keeps backslashes in user-supplied values
        # from being interpreted as regex group references.
        return re.sub(pattern, lambda _m: replacement, text, flags=re.M)

    # Pin the job to the selected nodes in place of the constraint line.
    content = replace_line(content, r'^#SBATCH --constraint=.*$', f'#SBATCH --nodelist={selected_nodes}')
    content = replace_line(content, r'^#SBATCH --nodes=[0-9]+', f'#SBATCH --nodes={node_count}')
    content = replace_line(content, r'^#SBATCH --gres=gpu:[0-9]+', f'#SBATCH --gres=gpu:{gpus_per_node}')
    content = replace_line(content, r'^export USER_GPUS_PER_NODE=[0-9]+', f'export USER_GPUS_PER_NODE={gpus_per_node}')

    target_time = _QOS_TIME_LIMITS.get(qos)
    if target_time is None:
        print(f"未知的 QOS 类型 {qos}，保留脚本中原有的时间限制。")
    else:
        content = replace_line(
            content,
            r'^(#SBATCH -t) ([0-9]+-)[0-9]{2}:[0-9]{2}:[0-9]{2}',
            f'#SBATCH -t {target_time}',
        )

    content = replace_line(content, r'^#SBATCH -p [a-zA-Z0-9-]+', f'#SBATCH -p {partition}')
    content = replace_line(content, r'^#SBATCH --qos=[a-zA-Z-]+', f'#SBATCH --qos={qos}')

    with open(slurm_copy, 'w') as file:
        file.write(content)

    # Submit the patched copy.
    try:
        sbatch_output = subprocess.check_output(['sbatch', slurm_copy]).decode('utf-8')
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"任务提交失败: {e}")
    else:
        print(f"任务提交成功！使用节点: {selected_nodes}，副本: {slurm_copy}")
        print(sbatch_output)


if __name__ == "__main__":
    # Command-line interface, declared as (flag, type, default, help) rows.
    parser = argparse.ArgumentParser(description="Submit a Slurm job with specified parameters.")
    cli_options = (
        ("--node_count", int, 2, "Number of nodes required."),
        ("--gpus_per_node", int, 8, "Number of GPUs per node (4 or 8)."),
        ("--node_type", str, "r8l40s", "Node type (r8l40/r8l40s/r8a100)."),
        ("--partition", str, None, "Partition name. (r8nv-gpu-dedicated needs to be specified)"),
        ("--qos", str, None, "QOS type. (gpu-long needs to be specified)"),
    )
    for flag, value_type, fallback, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback, help=help_text)

    args = parser.parse_args()

    # Partition: an explicit --partition wins; otherwise multi-node jobs go to
    # the "dist" partition and single-node jobs are routed by node type.
    if args.partition is not None:
        partition = args.partition
    elif args.node_count >= 2:
        partition = "r8nv-gpu-dist"
    elif args.node_type in ("r8a100-c", "r8a100-d"):
        partition = "r8nv-gpu-hw-80g"
    else:
        partition = "r8nv-gpu-hw"

    # QOS: the dedicated partition always uses "normal"; otherwise use the
    # value given on the command line, defaulting to "gpu-normal".
    if partition == "r8nv-gpu-dedicated":
        qos = "normal"
    elif args.qos is None:
        qos = "gpu-normal"
    else:
        qos = args.qos

    # Read by submit_slurm_job as a module-level variable, so it must be
    # assigned before the call below.
    exclude_nodes = ['r8l40s-a02', 'r8l40s-a03', 'r8l40s-a04']

    submit_slurm_job(
        node_count=args.node_count,
        gpus_per_node=args.gpus_per_node,
        node_type=args.node_type,
        partition=partition,
        qos=qos,
    )