from utils import read_config

# LLaMA-Factory config for full-parameter reward-model ("rm" stage) training.
# Placeholders are filled by mk_llamafactory_orm_yaml via str.format, so any
# literal braces added to this template must be doubled ("{{" / "}}").
train_yaml = """\
### model
model_name_or_path: {model_path}

### method
stage: rm
do_train: true
finetuning_type: full
deepspeed: {deepspeed_config_path}

### dataset
dataset: {dataset_name}
template: {model_template}
cutoff_len: 4096
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: {orm_model_path}
logging_steps: 10
save_steps: 100
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-5
num_train_epochs: 1.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.01
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
"""

# LLaMA-Factory config for loading the trained reward model (the training
# output_dir) with the same "rm" stage and chat template used for training.
test_yaml = """\
model_name_or_path: {orm_model_path}
template: {model_template}
stage: rm
"""


def mk_llamafactory_orm_yaml(cfg):
    """Write the LLaMA-Factory train and test YAML files for the selected ORM dataset.

    Args:
        cfg: Config mapping. Keys read:
            ``cfg["orm_dataset"]`` - key selecting the dataset entries below;
            ``cfg["model"]`` - base model path used for training;
            ``cfg["llamafactory_model_template"]`` - chat template for both files;
            ``cfg["orm"][<dataset>]`` - needs ``train_yaml_path``,
            ``test_yaml_path``, ``model_path`` (training output dir and the
            model loaded at test time) and ``deepspeed_cfg_path``;
            ``cfg["preference_dataset"][<dataset>]["dataset_name"]``.

    Raises:
        KeyError: if a required config key is missing. Every lookup happens
            before any file is opened, so a bad config leaves no partially
            written output behind.
    """
    orm_dataset = cfg["orm_dataset"]
    orm_cfg = cfg["orm"][orm_dataset]
    data_cfg = cfg["preference_dataset"][orm_dataset]
    # A single template value feeds both files so the reward model is always
    # queried with the same chat template it was trained with.
    model_template = cfg["llamafactory_model_template"]
    orm_model_path = orm_cfg["model_path"]

    train_str = train_yaml.format(
        model_path=cfg["model"],
        deepspeed_config_path=orm_cfg["deepspeed_cfg_path"],
        dataset_name=data_cfg["dataset_name"],
        model_template=model_template,
        orm_model_path=orm_model_path,
    )
    test_str = test_yaml.format(
        orm_model_path=orm_model_path,
        model_template=model_template,
    )
    train_path = orm_cfg["train_yaml_path"]
    test_path = orm_cfg["test_yaml_path"]

    with open(train_path, "w", encoding="utf-8") as f:
        f.write(train_str)
    with open(test_path, "w", encoding="utf-8") as f:
        f.write(test_str)


if __name__ == "__main__":
    # Script entry: load the config and emit both YAML files for it.
    # NOTE(review): ["orm_dataset"] is passed straight to read_config; presumably
    # it names the keys this script needs - confirm against utils.read_config.
    mk_llamafactory_orm_yaml(read_config(["orm_dataset"]))