import argparse
from pathlib import Path
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def process_log(folder):
    """Parse every ``*.out`` log in *folder* into ``folder/stats.csv``.

    Each log line containing ``0m step:<N>`` is scanned for ``key:value``
    pairs. Every numeric value becomes a column, giving one row per step.
    The ``0m`` is presumably the tail of an ANSI colour-reset sequence.

    When the same step is logged more than once, the occurrence read last
    wins. This happens, for example, when a run is resumed and re-logs some
    steps. Files are read from oldest to newest modification time, and lines
    in file order, so a later run overrides an earlier one. Rows are written
    sorted by step, so rolling-window smoothing and line plots see steps in
    order.

    Args:
        folder: Directory containing the ``*.out`` log files (str or Path).

    Returns:
        Path of the written ``stats.csv``.
    """
    folder = Path(folder)
    # Oldest first so newer runs overwrite older ones; name breaks mtime ties
    # to keep the order deterministic.
    log_files = sorted(folder.glob("*.out"), key=lambda p: (p.stat().st_mtime, p.name))

    # The optional exponent group keeps values such as ``kl_loss:3.2e-05``
    # from being truncated to their mantissa.
    key_value_pattern = re.compile(
        r"([\w/%(),[\]]+):\s*(-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)"
    )
    step_line_pattern = re.compile(r'0m step:\d+.*$', re.MULTILINE)
    rows_by_step = {}  # step -> {metric name: value}, last occurrence wins

    for log_file in log_files:
        # Logs may contain stray non-UTF-8 bytes; don't let them abort parsing.
        content = log_file.read_text(encoding='utf-8', errors='replace')
        for line in step_line_pattern.findall(content):
            data = {key: float(value) for key, value in key_value_pattern.findall(line)}
            step = int(data.get('step', -1))
            if step != -1:
                rows_by_step[step] = data

    final_data = [rows_by_step[step] for step in sorted(rows_by_step)]

    df = pd.DataFrame(final_data)
    csv_file_path = folder / "stats.csv"
    df.to_csv(csv_file_path, index=False)
    return csv_file_path


def smooth_data(data, window_size='auto'):
    """Return the trailing rolling mean of *data*.

    Args:
        data: A pandas object supporting ``.rolling`` (the callers in this
            file pass a single DataFrame column).
        window_size: Window length in rows, or ``'auto'`` for one twentieth
            of the row count.

    Returns:
        The rolling mean with ``min_periods=1``, so the first rows are
        averaged over however many values are available.
    """
    if window_size == 'auto':
        # Clamp to 1: with fewer than 20 rows, n // 20 is 0, and pandas rejects
        # a window smaller than min_periods=1.
        window_size = max(1, data.shape[0] // 20)
    return data.rolling(window=window_size, min_periods=1).mean()


def plot_accuracy_vs_length(folder):
    """Plot smoothed correct ratio and mean response length against step.

    Reads ``folder/stats.csv``. If ``folder/verilog-eval-v2.csv`` exists, its
    ``Pass@1``/``Pass@5``/``Pass@20`` columns are overlaid against its ``Step``
    column.

    Axes:
        Left Y: smoothed ``reward/all_correct_ratio``, plus the optional Pass@k
            curves.
        Right Y: smoothed ``response_length/mean``.

    The figure is saved as ``folder/accuracy_vs_length.png``. Note that this
    sets the global ``plt.rcParams`` ``figure.dpi`` and ``font.size``.

    Raises:
        KeyError: if a required column is missing from either CSV.
    """
    folder = Path(folder)
    csv_file_path = folder / 'stats.csv'
    ve2_result_path = folder / 'verilog-eval-v2.csv'

    # The external evaluation results are optional.
    ve2_df = pd.read_csv(ve2_result_path) if ve2_result_path.exists() else None

    df = pd.read_csv(csv_file_path)
    for col in ('reward/all_correct_ratio', 'response_length/mean'):
        df[f'{col}_smoothed'] = smooth_data(df[col])

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['font.size'] = 14

    fig, ax1 = plt.subplots(figsize=(12, 7))

    # Left Y axis: smoothed correct ratio.
    left_color = 'tab:red'
    ax1.set_xlabel('Step', fontsize=16)
    ax1.set_ylabel('reward/all_correct_ratio', color=left_color, fontsize=16)
    ax1.plot(df['step'], df['reward/all_correct_ratio_smoothed'], color=left_color,
             label='Correct Ratio (smooth)', linewidth=2, alpha=0.8)
    ax1.tick_params(axis='both', which='major', labelsize=14)

    # Optional Pass@k curves share the left Y axis.
    if ve2_df is not None:
        pass_curves = [
            ('Pass@1', 'orange', 'o', 'Pass@1 (ve2)'),
            ('Pass@5', 'green', 's', 'Pass@5 (ve2)'),
            ('Pass@20', 'purple', '^', 'Pass@20 (ve2)'),
        ]
        for col, pass_color, marker, label in pass_curves:
            ax1.plot(ve2_df['Step'], ve2_df[col], color=pass_color, linestyle='--', label=label,
                     linewidth=3, marker=marker, markersize=10, alpha=0.8)

    # Right Y axis: smoothed mean response length.
    ax2 = ax1.twinx()
    right_color = 'tab:blue'
    ax2.set_ylabel('response_length/mean', color=right_color, fontsize=16)
    ax2.plot(df['step'], df['response_length/mean_smoothed'], color=right_color,
             label='Response Length (smooth)', linewidth=2, alpha=0.8)
    ax2.tick_params(axis='both', which='major', labelsize=14)

    # Merge both axes' entries into one legend above the plot.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles=lines + lines2, labels=labels + labels2, loc='upper center',
               bbox_to_anchor=(0.5, 1.15), ncol=5,
               edgecolor='black', facecolor='whitesmoke', fontsize=12)

    # twinx() makes ax2 the current axes, so pyplot-level plt.title/plt.grid
    # would act on the twin. In that case the grid follows the right-axis ticks
    # and is drawn over the left-axis data. Target ax1 explicitly instead.
    ax1.set_title('Smoothed Data Visualization of Model Performance Metrics', fontsize=18)
    ax1.set_facecolor('whitesmoke')
    ax1.grid(True, linestyle='--', alpha=0.7)

    fig.savefig(folder / "accuracy_vs_length.png")
    plt.close(fig)


def plot_different_accuracy_ratio(folder):
    """Plot the four smoothed ``reward/correct_*_ratio`` columns against step.

    Reads ``folder/stats.csv`` and writes
    ``folder/different_accuracy_ratio.png``. Every listed column must exist in
    the CSV, otherwise ``KeyError`` is raised. The ``--no_ratio`` command-line
    flag skips this plot.
    """
    stats_path = Path(folder) / 'stats.csv'
    bucket_columns = (
        'reward/correct_0%_ratio',
        'reward/correct_(0%,50%)_ratio',
        'reward/correct_[50%,100%)_ratio',
        'reward/correct_100%_ratio',
    )
    stats = pd.read_csv(stats_path)

    # Smooth every column before any figure is created.
    smoothed = {name: smooth_data(stats[name]) for name in bucket_columns}

    _, axes = plt.subplots(figsize=(10, 6))
    for name, values in smoothed.items():
        axes.plot(stats['step'], values, label=f'{name} (smoothed)', linewidth=2)

    axes.set_xlabel('Step', fontsize=14)
    axes.set_ylabel('Ratio', fontsize=14)
    axes.set_title('Reward Correct Ratio Visualization', fontsize=16)
    axes.legend(fontsize=10)
    axes.grid(True, linestyle='--', alpha=0.7)

    plt.savefig(stats_path.parent / "different_accuracy_ratio.png")
    plt.close()


def plot_metrics(folder, columns, colors, title, plot_file_name, log_columns=None, smooth=True):
    """Plot several ``stats.csv`` columns against step on one shared axes.

    Args:
        folder: Directory containing ``stats.csv``; the figure is written there.
        columns: Column names to plot. Names absent from the CSV are skipped.
        colors: Line colour per entry of *columns*, paired positionally.
        title: Figure title.
        plot_file_name: Output file name, relative to *folder*.
        log_columns: Optional column names to plot as ``log(x + 1e-10)``.
        smooth: If True, pass each plotted series through ``smooth_data``.

    NaNs in each plotted column are first filled with pandas' linear
    ``interpolate``. The legend labels list the transforms actually applied
    (``log`` and/or ``smoothed``). This sets the global ``plt.rcParams``
    ``figure.dpi`` and ``font.size``.

    Returns:
        Path of the saved figure. Returns None without plotting when none of
        *columns* exists in the CSV, because matplotlib cannot lay out a
        legend with zero columns.
    """
    csv_file_path = Path(folder) / 'stats.csv'
    df = pd.read_csv(csv_file_path)
    log_set = set(log_columns or ())

    # (values, legend label, colour) for each requested column present in the CSV.
    curves = []
    for col, color in zip(columns, colors):
        if col not in df.columns:
            continue
        # Presumably some metrics (e.g. validation scores) are not logged at
        # every step; interpolation keeps their lines continuous.
        values = df[col].interpolate(method='linear')
        tags = []
        if col in log_set:
            # The epsilon keeps zeros finite under the log.
            values = np.log(values + 1e-10)
            tags.append('log')
        if smooth:
            values = smooth_data(values)
            tags.append('smoothed')
        label = f"{col} ({', '.join(tags)})" if tags else col
        curves.append((values, label, color))

    if not curves:
        return None

    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['font.size'] = 14

    fig, ax_main = plt.subplots(figsize=(12, 7))
    ax_main.set_xlabel('Step')
    ax_main.set_ylabel('Value')  # all curves share one Y axis

    for values, label, color in curves:
        ax_main.plot(df['step'], values, label=label, color=color, linestyle='-')

    # One legend row above the axes, one column per curve.
    ax_main.legend(loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=len(curves),
                   edgecolor='black', facecolor='whitesmoke', fontsize=12)

    ax_main.set_title(title)
    ax_main.set_facecolor('whitesmoke')
    ax_main.grid(True, linestyle='--', alpha=0.7)

    plot_file_path = csv_file_path.parent / plot_file_name
    fig.savefig(plot_file_path)
    plt.close(fig)
    return plot_file_path


def plot_actor_loss_metrics(folder):
    """Plot ``actor/kl_loss``, ``actor/entropy_loss`` and ``actor/pg_loss``.

    Delegates to ``plot_metrics`` with smoothing on; the figure is saved as
    ``folder/actor_loss_metrics.png``.
    """
    color_by_column = {
        'actor/kl_loss': 'red',
        'actor/entropy_loss': 'green',
        'actor/pg_loss': 'blue',
    }
    plot_metrics(
        folder,
        list(color_by_column),
        list(color_by_column.values()),
        'Visualization of Smoothed Actor Loss Metrics',
        "actor_loss_metrics.png",
    )


def plot_actor_other_metrics(folder):
    """Plot ``actor/pg_clipfrac``, ``actor/ppo_kl`` and ``actor/grad_norm``.

    ``actor/grad_norm`` is drawn on a log scale (``log(x + 1e-10)``). The
    figure is saved as ``folder/actor_other_metrics.png``.
    """
    color_by_column = {
        'actor/pg_clipfrac': 'orange',
        'actor/ppo_kl': 'purple',
        'actor/grad_norm': 'brown',
    }
    plot_metrics(
        folder,
        list(color_by_column),
        list(color_by_column.values()),
        'Visualization of Smoothed Actor Other Metrics',
        "actor_other_metrics.png",
        log_columns=['actor/grad_norm'],
    )


def plot_critic_metrics(folder):
    """Plot the smoothed advantage max/min, value mean and ``vf_explained_var``.

    Columns missing from ``stats.csv`` are skipped by ``plot_metrics``. The
    figure is saved as ``folder/critic_metrics.png``.
    """
    color_by_column = {
        'critic/advantages/max': 'cyan',
        'critic/advantages/min': 'magenta',
        'critic/values/mean': 'orange',
        'critic/vf_explained_var': 'brown',
    }
    plot_metrics(
        folder,
        list(color_by_column),
        list(color_by_column.values()),
        'Visualization of Smoothed Critic Metrics',
        "critic_metrics.png",
    )


def plot_val_metrics(folder):
    """Plot the raw (unsmoothed) ``val/test_score/*`` columns.

    The figure is saved as ``folder/val_test_score.png``.
    """
    color_by_column = {
        'val/test_score/': 'blue',
        'val/test_score/deepscaler': 'green',
        'val/test_score/codev': 'red',
    }
    plot_metrics(
        folder,
        list(color_by_column),
        list(color_by_column.values()),
        'Visualization of Validation Test Scores',
        "val_test_score.png",
        smooth=False,
    )


def plot_clip_ratio(folder):
    """Plot the smoothed ``response_length/clip_ratio`` column.

    The figure is saved as ``folder/clip_ratio.png``.
    """
    clip_columns = ['response_length/clip_ratio']
    clip_colors = ['blue']
    plot_metrics(
        folder,
        clip_columns,
        clip_colors,
        'Visualization of Response Length Clip Ratio',
        "clip_ratio.png",
        smooth=True,
    )


def plot_data(folder, no_ratio=False):
    """Render every diagnostic plot from ``folder/stats.csv``.

    The accuracy-vs-length plot is best-effort. If it raises (e.g. a required
    column is missing), the error is printed and the remaining plots still
    run. Interrupts such as Ctrl-C are no longer swallowed.

    Args:
        folder: Run directory containing ``stats.csv``.
        no_ratio: Skip the correctness-bucket ratio plot.
    """
    try:
        plot_accuracy_vs_length(folder)
    except Exception as exc:  # best-effort: report and keep going
        print(f"Skipping accuracy_vs_length plot: {exc!r}")
    if not no_ratio:
        plot_different_accuracy_ratio(folder)
    plot_actor_loss_metrics(folder)
    plot_actor_other_metrics(folder)
    plot_critic_metrics(folder)
    plot_clip_ratio(folder)
    plot_val_metrics(folder)
    # Each plot function above closes its figure after saving, so this
    # typically has nothing left to display.
    plt.show()


if __name__ == '__main__':
    # Parse the logs in --folder into stats.csv, then render every plot from it.
    parser = argparse.ArgumentParser(description='Process log files and generate CSV.')
    parser.add_argument('--folder', type=str, help='Path to the folder containing the log file.',
                        default='ckpt/codev_distill_16k')
    # argparse %-formats help strings, so a literal percent sign must be '%%';
    # a bare '%-50' sequence makes `--help` raise ValueError.
    parser.add_argument('--no_ratio', action="store_true",
                        help='Not to plot correct_0%%-50%%-100%%_ratios.')
    args = parser.parse_args()
    process_log(args.folder)
    plot_data(args.folder, no_ratio=args.no_ratio)