Commit e86afa63 by nanziyuan

fix bug & add a new cli tool to view data

parent 984c7815
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
import json
import sys
def display_messages(messages, console):
for message in messages:
role = message["role"]
content = message["content"]
text = Text(content, style="bold")
panel = Panel(text, title=f"[{role.capitalize()}]")
console.print(panel)
if __name__ == "__main__":
json_string = sys.stdin.read()
console = Console()
messages = json.loads(json_string)
display_messages(messages, console)
import argparse
from functools import partial
import json
from tqdm import tqdm
import requests
......@@ -51,18 +50,17 @@ if __name__ == "__main__":
result_dir.mkdir(exist_ok=True)
# compute score
score_path = result_dir / "scores.jsonl"
test_dataset = load_jsonl(args.test)
server_url = "http://0.0.0.0:5000/get_reward"
tokenizer = AutoTokenizer.from_pretrained(args.model)
fun = partial(test_reward_model, server_url=server_url, tokenizer=tokenizer)
results = [fun(server_url, item) for item in tqdm(test_dataset)]
results = [test_reward_model(server_url, item, tokenizer) for item in tqdm(test_dataset)]
score_path = result_dir / "scores.jsonl"
save_jsonl(results, score_path)
# compute pass@k
eval_result_path = result_dir / "passk.jsonl"
# results = load_jsonl(result_path)
results = load_jsonl(score_path)
groups = group_results(results, args.apps)
eval_results = [score_pass_at_k(groups, k, home_path.stem) for k in range(1, 16)]
eval_result_path = result_dir / "passk.jsonl"
save_jsonl(eval_results, eval_result_path)
pprint.pp(eval_results)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment