Commit 4fdf6d7d by nanziyuan

step4: fix bugs in test_critic_model

parent b78d979f
from utils_vllm_bkp import vllm_chatcomplete, vllm_score
from utils import read_config, load_jsonl, save_jsonl
from utils_metric import group_results, score_pass_at_k
from utils_preference_dataset import mk_critic_verify
from transformers import AutoTokenizer


if __name__ == "__main__":
    cfg = read_config()

    # Step 1: generate critic "reasoning" completions for every test prompt.
    vllm_chatcomplete(
        cfg["critic"]["model_path"],
        cfg["critic"]["test"]["prompt_path"],
        cfg["critic"]["test"]["reason_result_path"],
        cfg["critic"]["test"]["sampling_params"],
    )

    # The score is read off the logit of a single "Yes" token, so the
    # encoding must yield exactly one token id. add_special_tokens=False
    # prevents BOS/EOS tokens from being prepended/appended.
    tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
    score_tokens = tokenizer.encode("Yes", add_special_tokens=False)
    assert len(score_tokens) == 1, (
        f'"Yes" must encode to exactly one token, got {score_tokens}'
    )
    score_token = score_tokens[0]

    # Step 2: append the verification question to each reasoning transcript
    # to build the scoring prompts.
    reason_results = load_jsonl(cfg["critic"]["test"]["reason_result_path"])
    score_prompts = []
    for item in reason_results:
        item["messages"] += mk_critic_verify()
        score_prompts.append(item)
    save_jsonl(score_prompts, "test_score_prompt.jsonl")

    # Step 3: score each prompt by the probability of the "Yes" token.
    vllm_score(
        cfg["critic"]["model_path"],
        "test_score_prompt.jsonl",
        cfg["critic"]["test"]["score_result_path"],
        score_token,
    )

    # Step 4: aggregate per-problem results and report pass@k for k=1..15.
    results = load_jsonl(cfg["critic"]["test"]["score_result_path"])
    groups = group_results(results, cfg["apps"])
    eval_results = [score_pass_at_k(groups, k, "critic") for k in range(1, 16)]
    save_jsonl(eval_results, cfg["critic"]["test"]["eval_result_path"])
    print(eval_results)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment