import argparse
from pathlib import Path
import random
from codecritic.utils.json import load_jsonl, save_jsonl


def add_mask_and_score(messages):
    """Tag each turn in ``messages`` with a ``mask`` flag and, for one turn, ``score``.

    Every turn receives a boolean ``"mask"`` key: ``False`` for the turn at
    position 3 and ``True`` for every other position. The turn at position 5,
    if the list is that long, additionally receives ``"score": True``. The
    ``"score"`` key of any other turn is left untouched.

    Turns are mutated in place and the same list object is returned.

    NOTE(review): what positions 3 and 5 represent depends on the
    conversation layout produced upstream. Confirm before reusing this
    on data with a different turn structure.
    """
    for position, turn in enumerate(messages):
        turn["mask"] = position != 3
        if position == 5:
            turn["score"] = True
    return messages


if __name__ == "__main__":
    # Entry point: annotate every record's ``messages`` via add_mask_and_score,
    # then either overwrite the input file or write a shuffled train/test split
    # next to it.
    parser = argparse.ArgumentParser(
        description=(
            "Add mask/score flags to each record's messages and either "
            "overwrite the input JSONL file or split it into train/test files."
        )
    )
    parser.add_argument(
        "--path",
        type=str,
        required=True,  # the script cannot do anything without an input file
        help="Input JSONL file. Overwritten in place unless --split is given.",
    )
    parser.add_argument(
        "--split",
        action="store_true",
        help="Shuffle and write train.jsonl / test.jsonl into --path's directory.",
    )
    parser.add_argument(
        "--test_ratio",
        type=float,
        default=0.01,
        help=(
            "Fraction of records placed in test.jsonl when --split is given "
            "(default: 0.01). The count is truncated, so small datasets may "
            "get an empty test set."
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for the shuffle so the split is reproducible (default: unseeded).",
    )
    args = parser.parse_args()

    if not 0.0 <= args.test_ratio < 1.0:
        parser.error("--test_ratio must be in the range [0, 1)")

    dataset = load_jsonl(args.path)

    for item in dataset:
        item["messages"] = add_mask_and_score(item["messages"])

    if args.split:
        # Use a dedicated RNG so a given --seed always yields the same split
        # and global random state is left alone. seed=None seeds from system
        # entropy, which matches the previous unseeded behaviour.
        rng = random.Random(args.seed)
        rng.shuffle(dataset)
        split_len = int(len(dataset) * args.test_ratio)
        test = dataset[:split_len]
        train = dataset[split_len:]

        dataset_path = Path(args.path).parent
        save_jsonl(train, dataset_path / "train.jsonl")
        save_jsonl(test, dataset_path / "test.jsonl")
    else:
        # Without --split the annotated records replace the original file.
        save_jsonl(dataset, args.path)
