Commit 96e8a313 by nanziyuan

step 2-4: fix sftorm bugs

parent 1e69b079
......@@ -4,8 +4,8 @@
# 2. Using SFT (Supervised Fine-Tuning) directly
# This experiment aims to fairly compare these two approaches.
from utils import load_json, save_json
from utils_preference_dataset import mk_critic_qa, mk_critic_verify, mk_sft_item
from utils import load_json, save_json, read_config
from utils_preference_dataset import mk_critic_qa, mk_critic_verify, mk_sft_item, mk_sft_dataset_info
def convert_preference_to_sft(item):
......@@ -19,10 +19,15 @@ def convert_preference_to_sft(item):
if __name__ == "__main__":
cfg = read_config()
preference_path = None
preference_path = cfg["preference_dataset"]["min_edit_distance"]["preference_dataset_path"]
preference_dataset = load_json(preference_path)
sft_dataset = []
for item in preference_dataset:
sft_dataset.extend(convert_preference_to_sft(item))
\ No newline at end of file
sft_dataset.extend(convert_preference_to_sft(item))
dataset_info = mk_sft_dataset_info(cfg["sftorm"]["dataset_name"])
save_json(sft_dataset, cfg["sftorm"]["dataset_path"])
save_json(dataset_info, cfg["sftorm"]["dataset_info_path"])
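For context, mk_sft_dataset_info presumably builds the dataset_info.json entry that LLaMA-Factory needs in order to locate and parse the exported SFT file. A minimal sketch, assuming a sharegpt-formatted dataset with OpenAI-style role/content messages (the exact fields are an assumption, not shown in this diff):

# Hypothetical sketch -- the real helper lives in utils_preference_dataset.
def mk_sft_dataset_info(dataset_name):
    return {
        dataset_name: {
            "file_name": f"{dataset_name}.json",
            "formatting": "sharegpt",
            "columns": {"messages": "messages"},
            "tags": {
                "role_tag": "role",
                "content_tag": "content",
                "user_tag": "user",
                "assistant_tag": "assistant",
            },
        }
    }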
......@@ -17,6 +17,7 @@ cutoff_len: 4096
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
mask_history: true
### output
output_dir: {critic_model_path}
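The new mask_history: true flag tells LLaMA-Factory to compute the loss on the last turn of each conversation only, which presumably is the point here: every SFT item ends with the "Is the code correct (Yes/No)?" turn, and only that final Yes/No reply should be supervised. A rough illustration of the masking idea (not LLaMA-Factory's actual implementation):

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

def mask_all_but_last_turn(turn_label_ids):
    # turn_label_ids: one list of token ids per assistant turn, in order.
    # Every turn except the last is replaced by IGNORE_INDEX, so only the
    # final Yes/No answer contributes to the training loss.
    labels = []
    for i, ids in enumerate(turn_label_ids):
        is_last = i == len(turn_label_ids) - 1
        labels.extend(ids if is_last else [IGNORE_INDEX] * len(ids))
    return labels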
......@@ -49,12 +50,12 @@ def mk_llamafactory_sft_yaml(cfg):
train_str = train_yaml.format(
model_path=cfg["model"],
deepspeed_config_path=cfg[model_type]["train"]["deepspeed_cfg_path"],
dataset_name=cfg[model_type]["train"]["dataset_name"],
critic_model_path=cfg[model_type]["model_path"],
dataset_name=cfg[model_type]["dataset_name"],
critic_model_path=cfg[model_type]["model_path"],
)
f.write(train_str)
if __name__ == "__main__":
cfg = read_config(["model_type"])
mk_llamafactory_sft_yaml(cfg)
\ No newline at end of file
mk_llamafactory_sft_yaml(cfg)
......@@ -21,7 +21,7 @@ if __name__ == "__main__":
save_jsonl(test_dataset, cfg["sftorm"]["test"]["prompt_path"])
tokenizer = AutoTokenizer.from_pretrained(cfg["sftorm"]["model_path"])
score_tokens = tokenizer.encode("Yes")
score_tokens = tokenizer.encode("Yes", add_special_tokens=False)
assert len(score_tokens) == 1
score_token = score_tokens[0]
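The add_special_tokens=False fix keeps the tokenizer from silently prepending special tokens (typically BOS) to "Yes", which would either trip the len(score_tokens) == 1 assert or make score_token point at the wrong id. A quick way to check the difference; "path/to/critic_model" is a placeholder, not the repo's actual path:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/critic_model")
print(tok.encode("Yes"))                            # may be [bos_id, yes_id]
print(tok.encode("Yes", add_special_tokens=False))  # [yes_id], what the scorer needs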
......@@ -36,4 +36,4 @@ if __name__ == "__main__":
groups = group_results(results, cfg["apps"])
eval_results = [score_pass_at_k(groups, k, "sft-orm") for k in range(1, 16)]
save_jsonl(eval_results, cfg["sftorm"]["test"]["eval_result_path"])
print(eval_results)
\ No newline at end of file
print(eval_results)
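score_pass_at_k itself is not shown in this diff; for an outcome reward model the usual metric is best-of-k: among k sampled solutions per problem, pick the one the critic scores highest and count the problem as solved only if that pick passes. A hedged sketch under that assumption (the "score" and "pass" field names are assumptions as well):

import numpy as np

def best_of_k_accuracy(groups, k, rng=None):
    # groups: problem_id -> list of candidate dicts carrying "score" and "pass".
    rng = rng or np.random.default_rng(0)
    hits = []
    for candidates in groups.values():
        picked = rng.choice(len(candidates), size=min(k, len(candidates)), replace=False)
        best = max((candidates[i] for i in picked), key=lambda c: c["score"])
        hits.append(bool(best["pass"]))
    return float(np.mean(hits))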
......@@ -211,4 +211,5 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
\ No newline at end of file
print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
......@@ -68,7 +68,8 @@ def mk_critic_qa(instruction, code):
]
def mk_critic_verify(answer: bool | None = None):
def mk_critic_verify(answer=None):
# answer: bool or None
message = [{"role": "user", "content": "Is the code correct (Yes/No)?"}]
if answer is not None:
response = "Yes" if answer else "No"
......
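The rest of mk_critic_verify is elided in this diff; a plausible reading, stated purely as an assumption, is that it appends the assistant reply and returns the message list:

# Assumed shape of the returned messages -- not copied from the repo.
def mk_critic_verify(answer=None):
    # answer: True / False / None (None means prompt only, no gold reply)
    message = [{"role": "user", "content": "Is the code correct (Yes/No)?"}]
    if answer is not None:
        response = "Yes" if answer else "No"
        message.append({"role": "assistant", "content": response})
    return message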
......@@ -6,7 +6,7 @@ from itertools import chain
from functools import partial
from utils import load_jsonl, save_jsonl
import numpy as np
def generate_worker(cuda_device, prompts, model_path, sampling_params):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
......@@ -53,8 +53,8 @@ def score_worker(cuda_device, prompts, model_path, score_token):
vllm_sampling_params = SamplingParams(
n=1,
temperature=0,
max_tokens=1,
logprobs=1000
max_tokens=5,
logprobs=20
)
text_prompts = [tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True) for item in prompts]
......@@ -65,12 +65,15 @@ def score_worker(cuda_device, prompts, model_path, score_token):
for response in output.outputs:
# response.logprobs: list[dict[int, Logprob]] https://github.com/vllm-project/vllm/blob/main/vllm/sequence.py
sample_logprobs = response.logprobs
logprob = sample_logprobs[0].get(score_token)
logprob = sample_logprobs[1].get(score_token)
newitem = item.copy()
# the model always returns 4 tokens: ['\n', 'Yes'/'No', '\n', <EOT>]
if logprob:
newitem["score"] = logprob.logprob
newitem["score"] = np.exp(logprob.logprob)
newitem["critic_text"] = response.text
else:
newitem["score"] = 0
newitem["score"] = 0
newitem["critic_text"] = response.text
result.append(newitem)
return result
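The scoring changes belong together: with the critic's chat template the model answers with a leading '\n' before 'Yes'/'No', so the relevant logprob sits at position 1 rather than 0 (hence also max_tokens=5 instead of 1), and np.exp turns the stored log-probability into a probability in [0, 1] so scores are directly comparable across samples; logprobs=20 is presumably enough because only the top candidates at that single position matter. Condensed, the extraction amounts to this sketch:

import numpy as np

def yes_probability(response, score_token):
    # response.logprobs: one {token_id: Logprob} dict per generated position;
    # position 0 is the leading '\n', position 1 is the 'Yes'/'No' token.
    entry = response.logprobs[1].get(score_token)
    # If 'Yes' is not among the returned top-k logprobs, treat it as probability 0.
    return float(np.exp(entry.logprob)) if entry is not None else 0.0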
......@@ -118,4 +121,4 @@ def vllm_score(model_path, prompt_path, output_path, score_token):
results = list(chain(*nested_results))
print(f"size of dataset: {len(results)}")
save_jsonl(results, output_path)
\ No newline at end of file
save_jsonl(results, output_path)