Commit 2a43e44e by nanziyuan

fix apps bugs & filter rationale

parent 97085438
 import argparse
 from collections import defaultdict
 from functools import partial
+import pprint
+import random
 from vllm import SamplingParams
 from datasets import load_dataset
@@ -35,28 +37,30 @@ if __name__ == "__main__":
     ds[item["task_id"]][item["solution_id"]] = item
     # Step1 Generate hints
-    # hint_prompts = []
-    # for pair in pairinfo:
-    #     task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
-    #     chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
-    #     prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.level)
-    #     hint_prompts.append(prompt)
-    # sampling_params = SamplingParams(
-    #     n=1,
-    #     temperature=0,
-    #     top_p=0.95,
-    #     max_tokens=2048,
-    # )
-    # worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
-    # hint_responses = model_map(worker, hint_prompts, args.tp)
-    # hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
-    # # hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...}
-    # save_jsonl(hints, args.output + ".hints")
-    hints = load_jsonl(args.output + ".hints")
+    hint_prompts = []
+    for pair in pairinfo:
+        task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
+        chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
+        prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.level)
+        hint_prompts.append(prompt)
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=0,
+        top_p=0.95,
+        max_tokens=2048,
+    )
+    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
+    hint_responses = model_map(worker, hint_prompts, args.tp)
+    pprint.pp(hint_responses[0])
+    hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
+    # hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...}
+    # save_jsonl(hint_responses, args.output + ".hint_responses")
+    save_jsonl(hints, args.output + ".hints")
+    # hints = load_jsonl(args.output + ".hints")
     hints_dict = defaultdict(dict)
     for item in hints:
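Note: the uncommented Step1 block follows the usual vLLM batch-generation pattern: build every hint prompt up front, decode greedily (`temperature=0`) so the hints are reproducible, then fan the batch out to workers. A minimal sketch of that pattern against vLLM's public API, assuming a Hugging Face model id and a made-up prompt (the repo's `generate_worker`/`model_map` wrappers are not shown in this diff):

```python
# Sketch only: a stand-in for what generate_worker/model_map presumably wrap.
from vllm import LLM, SamplingParams

prompts = [
    "Explain why the chosen solution passes the tests while the rejected one fails.",
]
# Same decoding settings as the diff: one greedy sample, capped at 2048 tokens.
sampling_params = SamplingParams(n=1, temperature=0, top_p=0.95, max_tokens=2048)

llm = LLM(model="Qwen/Qwen2.5-Coder-7B-Instruct", tensor_parallel_size=1)
outputs = llm.generate(prompts, sampling_params)
print(outputs[0].outputs[0].text)
```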
@@ -87,6 +91,7 @@ if __name__ == "__main__":
     worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
     reason_responses = model_map(worker, reason_prompts, args.tp)
+    pprint.pp(reason_responses[0])
     save_jsonl(reason_responses, args.output + ".reason")
     # Step3 Verify reasoning results
@@ -116,6 +121,7 @@ if __name__ == "__main__":
     worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
     verify_responses = model_map(worker, reason_responses, args.tp)
+    pprint.pp(verify_responses[0])
     print("verify response size: {}".format(len(verify_responses)))
     # postprocess verify_response.
@@ -150,8 +156,7 @@ if __name__ == "__main__":
print("Corrects (execution consistent) size: {}".format(len(corrects))) print("Corrects (execution consistent) size: {}".format(len(corrects)))
print("Incorrects (execution consistent) size: {}".format(len(incorrects))) print("Incorrects (execution consistent) size: {}".format(len(incorrects)))
# Step4 Remove hints and Reformat to a SFT dataset # Step4 Reformat to a SFT dataset
# extract reasoning sets
sft = [] sft = []
for item in verify_passed: for item in verify_passed:
@@ -171,4 +176,26 @@ if __name__ == "__main__":
         sft.append(line)
     print("Size of sft dataset: {}".format(len(sft)))
+    pprint.pp(sft[0])
     save_jsonl(sft, args.output)
+    # Step5 keep 1 rationale for 1 solution
+    task_solution_map = defaultdict(lambda: defaultdict(list))
+    for entry in sft:
+        task_id = entry["task_id"]
+        solution_id = entry["solution_id"]
+        task_solution_map[task_id][solution_id.split("_")[0]].append(entry)
+    # Keep only one reasoning for each solution
+    processed_dataset = []
+    for task_id, solution_map in task_solution_map.items():
+        for solution, reasoning_list in solution_map.items():
+            if len(reasoning_list) > 1:
+                # sample uniformly over all candidates, including index 0
+                selected_index = random.randrange(len(reasoning_list))
+                processed_dataset.append(reasoning_list[selected_index])
+            else:
+                processed_dataset.append(reasoning_list[0])
+    save_jsonl(processed_dataset, args.output.split('.')[0] + "-filtered.jsonl")
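Note: the Step5 filter groups SFT lines first by task, then by the base solution id (the part of `solution_id` before the first underscore), and keeps one randomly chosen rationale per solution. A small self-contained illustration with made-up entries (only the keys the filter reads are populated; the `response` field is hypothetical):

```python
import random
from collections import defaultdict

sft = [
    {"task_id": "t1", "solution_id": "7_0", "response": "rationale A"},
    {"task_id": "t1", "solution_id": "7_1", "response": "rationale B"},
    {"task_id": "t1", "solution_id": "8_0", "response": "rationale C"},
]

task_solution_map = defaultdict(lambda: defaultdict(list))
for entry in sft:
    base_id = entry["solution_id"].split("_")[0]  # "7_1" -> "7"
    task_solution_map[entry["task_id"]][base_id].append(entry)

# One rationale survives per (task, base solution): one of A/B, plus C.
filtered = [random.choice(group)
            for solutions in task_solution_map.values()
            for group in solutions.values()]
print(len(filtered))  # 2
```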
@@ -84,7 +84,7 @@ def evaluate_code_samples(code_samples, apps):
     cpu_num = multiprocessing.cpu_count() // 2
     chunksize = max(len(code_samples) // (cpu_num * 10), 1)
     results = process_map(
-        test_generation, args, max_workers=cpu_num, chunksize=chunksize
+        test_generation, args, max_workers=cpu_num, chunksize=1
     )
     return results
@@ -100,7 +100,7 @@ def evaluate(code_samples, apps):
     The 'loop_num' parameter controls the number of times the function will be retried until the test framework obtains a consistent result.
     """
     all_results = []
-    for _ in range(2):
+    for _ in range(1):
         results = evaluate_code_samples(code_samples, apps)
         all_results.append(results)
...
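Note: both evaluation changes trade throughput for robustness: the retry loop now runs once, and `chunksize=1` makes `tqdm`'s `process_map` dispatch one sample per worker task instead of batching, so a single slow or hanging test case no longer delays the whole chunk queued behind it. A sketch of the dispatch difference, with a toy function standing in for `test_generation`:

```python
from tqdm.contrib.concurrent import process_map

def square(x):  # toy stand-in for the per-sample test function
    return x * x

if __name__ == "__main__":
    # chunksize=1: workers pull items one at a time, so no item waits
    # behind a slow neighbour that happened to share its chunk.
    results = process_map(square, range(100), max_workers=4, chunksize=1)
    print(results[:5])  # [0, 1, 4, 9, 16]
```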
@@ -350,9 +350,9 @@ def run_test(sample, test=None, debug=False):
                 # try by converting the stuff into split up list
                 if isinstance(in_outs["outputs"][index], list):
                     for tmp_index, i in enumerate(in_outs["outputs"][index]):
-                        in_outs["outputs"][index][tmp_index] = set(i.split())
+                        in_outs["outputs"][index][tmp_index] = list(i.split())
                 else:
-                    in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
+                    in_outs["outputs"][index] = list(in_outs["outputs"][index].split())
                 try:
                     tmp_result = (output == in_outs["outputs"][index])
@@ -371,14 +371,14 @@ def run_test(sample, test=None, debug=False):
                     output[tmp_index] = i.split()
                 output = list(filter(len, output))
                 for tmp_index, i in enumerate(output):
-                    output[tmp_index] = set(i)
+                    output[tmp_index] = list(i)
             else:
                 output = output.split()
                 output = list(filter(len, output))
-                output = set(output)
+                output = list(output)
             try:
-                tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
+                tmp_result = (list(list(s) for s in output) == list(list(s) for s in in_outs["outputs"][index]))
             except Exception as e:
                 if debug:
                     print(f"Failed check5 exception = {e}")
@@ -386,8 +386,8 @@ def run_test(sample, test=None, debug=False):
             # if they are all numbers, round so that similar numbers are treated as identical
             try:
-                tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
-                    set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
+                tmp_result = tmp_result or (list(list(round(float(t),3) for t in s) for s in output) ==\
+                    list(list(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
             except Exception as e:
                 if debug:
                     print(f"Failed check6 exception = {e}")
...
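Note: the set-to-list replacements are the "apps bugs" half of the commit: comparing judged output as sets throws away token order and duplicate counts, so genuinely different outputs can compare equal. A minimal demonstration of the failure mode the change removes:

```python
expected = "1 2 2"
produced = "2 1"

# Set comparison wrongly accepts this pair: order and multiplicity vanish.
print(set(produced.split()) == set(expected.split()))  # True

# List comparison, as in the fixed code, keeps both and rejects it.
print(produced.split() == expected.split())            # False
```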
@@ -3,55 +3,56 @@ set -xe
 model="/lustre/S/huangdi/open_for_out/models/Qwen2.5-Coder-7B-Instruct/"
 project="/lustre/S/nanziyuan/projects/ccc"
 modelname="qwen25_coder_inst"
+data="${project}/data"
-trainset="${project}/data/train/${modelname}-apps-train.jsonl"
-testset="${project}/data/test/${modelname}-apps-test.jsonl"
-train_selected_pairs="${project}/data/train/${modelname}-apps-train-selected_pairs.jsonl"
+trainset="${data}/train/${modelname}-apps-train.jsonl"
+testset="${data}/test/${modelname}-apps-test.jsonl"
+train_selected_pairs="${data}/train/${modelname}-apps-train-selected_pairs.jsonl"
 apps="/lustre/S/nanziyuan/datasets/apps/"
-sft="${project}/data/train/${modelname}-sft.jsonl"
+sft="${data}/train/${modelname}-sft.jsonl"
 ftmodel="${project}/model/qwen25_coder_inst_7b-algolr"
-testset="/lustre/S/nanziyuan/projects/ccc/data/test/qwen25_coder_inst-apps-test.jsonl"
-evalresults="/lustre/S/nanziyuan/projects/ccc/data/eval/qwen25_code_inst-apps-test-algolr-score.jsonl"
+testset="${data}/test/qwen25_coder_inst-apps-test.jsonl"
+evalresults="${data}/eval/qwen25_code_inst-apps-test-algolr-score.jsonl"
 # export CUDA_VISIBLE_DEVICES=0,1,2,3
-# python -m codecritic.cli.algolr \
-#     --model ${model} \
-#     --dataset ${trainset} \
-#     --pairinfo ${train_selected_pairs} \
-#     --apps ${apps} \
-#     --output ${sft} \
-#     --level beginner \
-#     --tp 1
+python -m codecritic.cli.algolr \
+    --model ${model} \
+    --dataset ${trainset} \
+    --pairinfo ${train_selected_pairs} \
+    --apps ${apps} \
+    --output ${sft} \
+    --level beginner \
+    --tp 1
-# deepspeed --module \
-#     openrlhf.cli.train_sft \
-#     --max_len 4096 \
-#     --dataset ${sft} \
-#     --input_key question \
-#     --output_key response \
-#     --apply_chat_template \
-#     --train_batch_size 256 \
-#     --micro_train_batch_size 2 \
-#     --max_samples 500000 \
-#     --pretrain ${model} \
-#     --save_path ${ftmodel} \
-#     --save_steps -1 \
-#     --logging_steps 1 \
-#     --eval_steps -1 \
-#     --zero_stage 2 \
-#     --max_epochs 1 \
-#     --bf16 \
-#     --flash_attn \
-#     --learning_rate 5e-6 \
-#     --load_checkpoint \
-#     --gradient_checkpointing \
-#     --use_tensorboard "${ftmodel}_log"
+deepspeed --module \
+    openrlhf.cli.train_sft \
+    --max_len 4096 \
+    --dataset ${sft} \
+    --input_key question \
+    --output_key response \
+    --apply_chat_template \
+    --train_batch_size 256 \
+    --micro_train_batch_size 2 \
+    --max_samples 500000 \
+    --pretrain ${model} \
+    --save_path ${ftmodel} \
+    --save_steps -1 \
+    --logging_steps 1 \
+    --eval_steps -1 \
+    --zero_stage 2 \
+    --max_epochs 1 \
+    --bf16 \
+    --flash_attn \
+    --learning_rate 5e-6 \
+    --load_checkpoint \
+    --gradient_checkpointing \
+    --use_tensorboard "${ftmodel}_log"
 python -m codecritic.cli.test_genrm \
...
@@ -12,7 +12,7 @@ train_selected_pairs="${project}/data/train/${modelname}-apps-train-selected_pai
 reward_ds="${project}/data/train/${modelname}-apps-train-reward_dataset.jsonl"
-export CUDA_VISIBLE_DEVICES=0,1,2,3
+# export CUDA_VISIBLE_DEVICES=0,1,2,3
 ## Sampling
 ## APPS
...
@@ -9,29 +9,29 @@ ftmodel="${project}/model/qwen25_coder_inst_7b-orm"
 testset="/lustre/S/nanziyuan/projects/ccc/data/test/qwen25_coder_inst-apps-test.jsonl"
 evalresults="/lustre/S/nanziyuan/projects/ccc/data/eval/qwen25_code_inst-apps-test-orm-score.jsonl"
-# deepspeed --module \
-#     openrlhf.cli.train_rm \
-#     --save_path ${ftmodel} \
-#     --save_steps -1 \
-#     --logging_steps 1 \
-#     --eval_steps -1 \
-#     --train_batch_size 256 \
-#     --micro_train_batch_size 1 \
-#     --pretrain ${model} \
-#     --bf16 \
-#     --max_epochs 1 \
-#     --max_len 8192 \
-#     --zero_stage 3 \
-#     --learning_rate 9e-6 \
-#     --dataset ${dataset} \
-#     --apply_chat_template \
-#     --prompt_key messages \
-#     --chosen_key chosen \
-#     --rejected_key rejected \
-#     --flash_attn \
-#     --load_checkpoint \
-#     --gradient_checkpointing \
-#     --use_tensorboard "${ftmodel}_log"
+deepspeed --module \
+    openrlhf.cli.train_rm \
+    --save_path ${ftmodel} \
+    --save_steps -1 \
+    --logging_steps 1 \
+    --eval_steps -1 \
+    --train_batch_size 256 \
+    --micro_train_batch_size 1 \
+    --pretrain ${model} \
+    --bf16 \
+    --max_epochs 1 \
+    --max_len 8192 \
+    --zero_stage 3 \
+    --learning_rate 9e-6 \
+    --dataset ${dataset} \
+    --apply_chat_template \
+    --prompt_key messages \
+    --chosen_key chosen \
+    --rejected_key rejected \
+    --flash_attn \
+    --load_checkpoint \
+    --gradient_checkpointing \
+    --use_tensorboard "${ftmodel}_log"
 start_server() {
...