Commit 1a8e110a by nzy

additional experiments: Chain of Verification (CoV)

TODO: refactor vllm_chatcomplete
parent 64012d3d
# Additional Experiment:
# Does reasoning really work? Let's verify step by step.
import argparse
from itertools import chain
from pathlib import Path
from utils import load_json
from utils_dataset import mk_critic_qa, mk_sft_item, mk_sft_dataset_info, save_dataset
from utils_vllm import vllm_chatcomplete
verify_prompt = "Let's verify step by step."


def mk_cov_prompt(is_correct):
    """Build the one-turn Chain-of-Verification user message.

    Args:
        is_correct: whether the code under review is labeled as correct.

    Returns:
        A single-element chat message list (OpenAI-style role/content dict)
        stating the verdict followed by the step-by-step verification ask.
    """
    verdict = "This code is correct." if is_correct else "This code is incorrect."
    return [{"role": "user", "content": f"{verdict} {verify_prompt}"}]
def convert_preference_to_vot_prompt(item):
    """Turn one preference pair into two verification (CoV) SFT items.

    The chosen answer is framed as correct and the rejected answer as
    incorrect; each gets the "verify step by step" instruction appended.

    Args:
        item: a preference record with "messages", "chosen" and "rejected"
            keys, each holding chat messages with a "content" field.

    Returns:
        A (chosen_item, rejected_item) tuple of SFT items.
    """
    question = item["messages"][0]["content"]
    answers = [
        (item["chosen"]["content"], True),
        (item["rejected"]["content"], False),
    ]
    chosen_item, rejected_item = (
        mk_sft_item(mk_critic_qa(question, answer) + mk_cov_prompt(flag))
        for answer, flag in answers
    )
    return chosen_item, rejected_item
def convert_cov_to_cov_dataset(item):
    """Reset the third chat turn of a sampled CoV record to the bare prompt.

    The message at index 2 is overwritten with the verification prompt so
    every training example asks exactly "Let's verify step by step."

    Args:
        item: an SFT item whose "messages" list has at least three entries.

    Returns:
        The same item, mutated in place.
    """
    turns = item["messages"]
    turns[2]["content"] = verify_prompt
    return item
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--preference_dataset", type=str)
    parser.add_argument("--llamafactory", type=str)
    parser.add_argument("--dataset_name", type=str)
    parser.add_argument("--output_dir", type=str)
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    preference_dataset = load_json(args.preference_dataset)

    # BUG FIX: convert_preference_to_vot_prompt expects a single preference
    # item, not the whole dataset — map it over every record and flatten the
    # resulting (chosen, rejected) pairs into one prompt list.
    cov_prompts = list(
        chain.from_iterable(map(convert_preference_to_vot_prompt, preference_dataset))
    )

    sampling_params = dict(n=1, temperature=0.8, max_tokens=2048)
    reason_path = output_dir / "cov.jsonl"
    # Sample one verification chain per prompt; raw completions are also
    # persisted to cov.jsonl by vllm_chatcomplete.
    covs = vllm_chatcomplete(args.model, cov_prompts, reason_path, sampling_params)

    # Normalize each sampled record, then register it as a llama-factory
    # SFT dataset under the given name.
    dataset = list(map(convert_cov_to_cov_dataset, covs))
    dataset_info = mk_sft_dataset_info(args.dataset_name)
    save_dataset(args.llamafactory, dataset_info, dataset)
......@@ -78,8 +78,11 @@ def score_worker(cuda_device, prompts, model_path, score_token):
return result
def vllm_chatcomplete(model_path, prompt_path, output_path, sampling_params):
prompts = load_jsonl(prompt_path)
def vllm_chatcomplete(model_path, prompts, output_path, sampling_params):
if isinstance(prompts, str):
prompts = load_jsonl(prompts)
else:
assert isinstance(prompts, list)
# Respect the slurm's gpu allocation
cuda_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
......@@ -99,6 +102,7 @@ def vllm_chatcomplete(model_path, prompt_path, output_path, sampling_params):
results = list(chain(*nested_results))
print(f"size of dataset: {len(results)}")
save_jsonl(results, output_path)
return results
def vllm_score(model_path, prompt_path, output_path, score_token):
......@@ -122,3 +126,4 @@ def vllm_score(model_path, prompt_path, output_path, score_token):
results = list(chain(*nested_results))
print(f"size of dataset: {len(results)}")
save_jsonl(results, output_path)
return results
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment