from utils import read_config

# Template for an SFT training config (LLaMA-Factory-style keys, per
# mk_llamafactory_sft_yaml below), rendered with str.format.
# Placeholders filled in by mk_llamafactory_sft_yaml:
#   {model_path}            -> cfg["model"]
#   {deepspeed_config_path} -> cfg[model_type]["train"]["deepspeed_cfg_path"]
#   {dataset_name}          -> cfg[model_type]["dataset_name"]
#   {critic_model_path}     -> cfg[model_type]["model_path"] (used as output_dir)
# NOTE: because this goes through str.format, any literal "{" or "}" added
# here must be written as "{{" / "}}".
# NOTE(review): values are substituted unquoted, so a value containing YAML
# special sequences (e.g. ": " or " #") would produce invalid or altered YAML
# — confirm config values are plain paths/names.
# The backslash after the opening quotes suppresses a leading newline.
train_yaml = """\
### model
model_name_or_path: {model_path}

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: {deepspeed_config_path}

### dataset
dataset: {dataset_name}
template: deepseekcoder
cutoff_len: 4096
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
mask_history: true

### output
output_dir: {critic_model_path}
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
"""


def mk_llamafactory_sft_yaml(cfg, template=None):
    """Render the SFT training YAML from ``cfg`` and write it to disk.

    The output path is ``cfg[cfg["model_type"]]["train"]["train_yaml_path"]``.

    Args:
        cfg: Nested config mapping. Reads ``cfg["model_type"]``,
            ``cfg["model"]`` and, under ``cfg[model_type]``:
            ``["train"]["train_yaml_path"]``, ``["train"]["deepspeed_cfg_path"]``,
            ``["dataset_name"]`` and ``["model_path"]``.
        template: Optional ``str.format`` template to render instead of the
            module-level ``train_yaml``. It may use the placeholders
            ``{model_path}``, ``{deepspeed_config_path}``, ``{dataset_name}``
            and ``{critic_model_path}``.

    Raises:
        KeyError: If a required config key (or a template placeholder) is
            missing. The template is rendered before the output file is
            opened, so in that case an existing file is left untouched.
    """
    if template is None:
        template = train_yaml

    model_cfg = cfg[cfg["model_type"]]
    train_cfg = model_cfg["train"]
    out_path = train_cfg["train_yaml_path"]

    # Render fully BEFORE opening: open(..., "w") truncates immediately, so a
    # KeyError raised during formatting would otherwise leave an empty YAML.
    train_str = template.format(
        model_path=cfg["model"],
        deepspeed_config_path=train_cfg["deepspeed_cfg_path"],
        dataset_name=model_cfg["dataset_name"],
        critic_model_path=model_cfg["model_path"],
    )

    # Explicit encoding so the output does not depend on the platform locale.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(train_str)


if __name__ == "__main__":
    # Script entry point: load the project config and emit the SFT YAML.
    # NOTE(review): the ["model_type"] argument presumably names keys that
    # read_config must provide — confirm against utils.read_config.
    loaded_cfg = read_config(["model_type"])
    mk_llamafactory_sft_yaml(loaded_cfg)
