Commit 9c8fbf86 by nanziyuan

Merge branch 'main' of http://62.234.201.16/nzy/codecritic

parents ca962f5c 83a77cdd
import argparse
import os
from pathlib import Path

from tqdm import tqdm
import torch
import torch.nn.functional as F
import transformers
import accelerate

from codecritic.utils.json import load_jsonl, save_jsonl
from codecritic.data.verify import get_score_token_id

@torch.inference_mode()
def hf_score(accelerator, model, tokenizer, prompts):
    """Score each prompt as the probability the model assigns to the score token."""
    score_token = get_score_token_id(tokenizer)
    with accelerator.split_between_processes(prompts) as partial_prompts:
        results = []
        for item in tqdm(partial_prompts):
            # apply_chat_template with return_tensors="pt" returns a tensor of
            # token ids, so it is passed positionally rather than unpacked.
            input_ids = tokenizer.apply_chat_template(
                item["messages"], add_generation_prompt=True, return_tensors="pt"
            ).to(model.device)
            output = model(input_ids)
            # Probability of the score token at the next-token position.
            next_token_logits = output.logits[0, -1, :]
            score = F.softmax(next_token_logits, dim=0)[score_token].item()
            results.append({**item, "score": score})
    return accelerate.utils.gather_object(results)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--prompts", type=str)
    parser.add_argument("--out", type=str)
    args = parser.parse_args()

    home_path = Path(args.model).parent
    result_dir = home_path / "hf_eval"
    result_dir.mkdir(exist_ok=True)

    prompts = load_jsonl(args.prompts)

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    accelerator = accelerate.Accelerator()

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.model, device_map={"": accelerator.process_index}
    )
    # model.generation_config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
    # model.generation_config.eos_token_id = tokenizer.eos_token_id
    model.eval()

    accelerator.wait_for_everyone()
    results = hf_score(accelerator, model, tokenizer, prompts)

    if accelerator.is_main_process:
        save_jsonl(results, args.out)


if __name__ == "__main__":
    main()
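
# Usage sketch: the script builds an accelerate.Accelerator and gathers results
# with accelerate.utils.gather_object, so it is meant to be started through the
# `accelerate launch` CLI with one process per GPU. The file name and the
# argument values below are placeholders, not part of the commit:
#
#   accelerate launch hf_score.py --model <model_dir> --prompts <prompts.jsonl> --out <scores.jsonl>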