Commit 2a43e44e by nanziyuan

fix apps bugs & filter rationale

parent 97085438
import argparse
from collections import defaultdict
from functools import partial
+ import pprint
+ import random
from vllm import SamplingParams
from datasets import load_dataset

@@ -35,28 +37,30 @@ if __name__ == "__main__":
ds[item["task_id"]][item["solution_id"]] = item ds[item["task_id"]][item["solution_id"]] = item
# Step1 Generate hints # Step1 Generate hints
# hint_prompts = [] hint_prompts = []
# for pair in pairinfo: for pair in pairinfo:
# task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"] task_id, chosen_id, rejected_id = pair["task_id"], pair["chosen"], pair["rejected"]
# chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id] chosen, rejected = ds[task_id][chosen_id], ds[task_id][rejected_id]
# prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.level) prompt = promptlib.process_to_hint_prompt(chosen, rejected, args.level)
# hint_prompts.append(prompt) hint_prompts.append(prompt)
# sampling_params = SamplingParams( sampling_params = SamplingParams(
# n=1, n=1,
# temperature=0, temperature=0,
# top_p=0.95, top_p=0.95,
# max_tokens=2048, max_tokens=2048,
# ) )
# worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params) worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
# hint_responses = model_map(worker, hint_prompts, args.tp) hint_responses = model_map(worker, hint_prompts, args.tp)
pprint.pp(hint_responses[0])
# hints = [promptlib.postprocess_to_hint(x) for x in hint_responses] hints = [promptlib.postprocess_to_hint(x) for x in hint_responses]
# # hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...} # hints: {"dataset"..., "task_id": ..., "solution_id": ..., "hints": ...}
# save_jsonl(hints, args.output + ".hints") # save_jsonl(hint_responses, args.output + ".hint_responses")
hints = load_jsonl(args.output + ".hints") save_jsonl(hints, args.output + ".hints")
# hints = load_jsonl(args.output + ".hints")
hints_dict = defaultdict(dict) hints_dict = defaultdict(dict)
for item in hints: for item in hints:
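Step1 now regenerates hints with greedy decoding and writes them to "<output>.hints", keeping the load_jsonl line commented out for reruns. A minimal sketch of a cache-or-generate toggle that avoids editing the script between runs; the regenerate_hints callable is a hypothetical stand-in for the prompt building and model_map call above, and save_jsonl/load_jsonl are assumed to write and read JSON-lines files as their names suggest:

    import os

    def get_hints(output_path, regenerate_hints):
        """Reuse cached hints when present; otherwise generate and cache them."""
        hints_path = output_path + ".hints"
        if os.path.exists(hints_path):
            return load_jsonl(hints_path)   # cached from an earlier run
        hints = regenerate_hints()          # prompt building + model_map, as in the diff
        save_jsonl(hints, hints_path)       # cache for later reruns
        return hints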
@@ -87,6 +91,7 @@ if __name__ == "__main__":
    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    reason_responses = model_map(worker, reason_prompts, args.tp)
+     pprint.pp(reason_responses[0])
    save_jsonl(reason_responses, args.output + ".reason")

    # Step3 Verify reasoning results
@@ -116,6 +121,7 @@ if __name__ == "__main__":
    worker = partial(generate_worker, model_path=args.model, sampling_params=sampling_params)
    verify_responses = model_map(worker, reason_responses, args.tp)
+     pprint.pp(verify_responses[0])
    print("verify response size: {}".format(len(verify_responses)))

    # postprocess verify_response.
@@ -150,8 +156,7 @@ if __name__ == "__main__":
    print("Corrects (execution consistent) size: {}".format(len(corrects)))
    print("Incorrects (execution consistent) size: {}".format(len(incorrects)))

-     # Step4 Remove hints and Reformat to a SFT dataset
-     # extract reasoning sets
+     # Step4 Reformat to a SFT dataset
    sft = []
    for item in verify_passed:
@@ -171,4 +176,26 @@ if __name__ == "__main__":
        sft.append(line)

    print("Size of sft dataset: {}".format(len(sft)))
+     pprint.pp(sft[0])
    save_jsonl(sft, args.output)

+     # Step5 keep 1 rationale for 1 solution
+     task_solution_map = defaultdict(lambda: defaultdict(list))
+     for entry in sft:
+         task_id = entry["task_id"]
+         solution_id = entry["solution_id"]
+         task_solution_map[task_id][solution_id.split("_")[0]].append(entry)
+
+     # Step 2: Keep only one reasoning for each solution
+     processed_dataset = []
+     for task_id, solution_map in task_solution_map.items():
+         for solution, reasoning_list in solution_map.items():
+             if len(reasoning_list) > 1:
+                 selected_index = random.choice(range(1, len(reasoning_list)))
+                 processed_dataset.append(reasoning_list[selected_index])
+             else:
+                 processed_dataset.append(reasoning_list[0])
+
+     save_jsonl(processed_dataset, args.output.split('.')[0] + "-filtered.jsonl")
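One note on the Step5 block added above: random.choice(range(1, len(reasoning_list))) can never return index 0, so whenever a solution has several rationales the first one is always excluded from the draw. If the intent is simply one uniformly random rationale per solution (an assumption on my part, not what the commit implements), the selection could be the following hypothetical helper:

    import random

    def pick_one_rationale(reasoning_list):
        # hypothetical variant: uniform over all rationales, index 0 included
        # (the committed code draws from range(1, len(...)) and skips index 0)
        return random.choice(reasoning_list)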
@@ -84,7 +84,7 @@ def evaluate_code_samples(code_samples, apps):
    cpu_num = multiprocessing.cpu_count() // 2
    chunksize = max(len(code_samples) // (cpu_num * 10), 1)
    results = process_map(
-         test_generation, args, max_workers=cpu_num, chunksize=chunksize
+         test_generation, args, max_workers=cpu_num, chunksize=1
    )
    return results
@@ -100,7 +100,7 @@ def evaluate(code_samples, apps):
    The 'loop_num' parameter controls the number of times the function will be retried until the test framework obtains a consistent result.
    """
    all_results = []
-     for _ in range(2):
+     for _ in range(1):
        results = evaluate_code_samples(code_samples, apps)
        all_results.append(results)
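Two behavioral notes on the hunks above, assuming process_map here is tqdm.contrib.concurrent.process_map (the max_workers/chunksize keywords match that API): chunksize=1 hands each worker one sample at a time, so a single slow APPS problem no longer stalls a whole pre-assigned chunk, at the price of more inter-process overhead; and shrinking the retry loop from range(2) to range(1) means results are no longer cross-checked across two execution passes. A small self-contained sketch of the chunksize trade-off:

    from tqdm.contrib.concurrent import process_map

    def uneven_work(x):
        # stand-in for test_generation: a few samples dominate the runtime
        return sum(range(10_000_000 if x % 50 == 0 else 10_000))

    if __name__ == "__main__":
        # chunksize=1: work is dealt out one item at a time, so stragglers
        # do not drag an entire pre-assigned chunk with them
        scores = process_map(uneven_work, range(200), max_workers=4, chunksize=1)
        print(len(scores))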
...
# copy from codeparrot/apps_metric/testing_util.py
# https://huggingface.co/spaces/codeparrot/apps_metric/blob/main/testing_util.py
# Log: Replace pyext with importlib
import json
import sys

@@ -66,7 +66,7 @@ def run_test(sample, test=None, debug=False):
"""
# Disable functionalities that can make destructive changes to the test.
reliability_guard()

if debug:
print(f"start = {datetime.now().time()}")
@@ -84,7 +84,7 @@ def run_test(sample, test=None, debug=False):
if debug:
print(f"loaded input_output = {datetime.now().time()}")

if test is None:
return in_outs
elif test is not None:

@@ -92,7 +92,7 @@ def run_test(sample, test=None, debug=False):
sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
if debug:
print(f"loading test code = {datetime.now().time()}")

if which_type == CODE_TYPE.call_based:
sol += test
if debug:

@@ -124,7 +124,7 @@ def run_test(sample, test=None, debug=False):
else:
new_test.append(x + "\n")
tmp_test = new_test

new_test = ""
started = False
for i in tmp_test:

@@ -133,7 +133,7 @@ def run_test(sample, test=None, debug=False):
new_test += "def code():\n"
new_test += i
started = True
elif started and ((i.startswith("from ")) or (i.startswith("import "))):
new_test += "\t" + i
else:
new_test += i
@@ -157,7 +157,7 @@ def run_test(sample, test=None, debug=False):
signal.alarm(0)
if debug:
print(f"get method = {datetime.now().time()}")

try:
method = getattr(tmp, method_name) # get_attr second arg must be str
except:

@@ -196,7 +196,7 @@ def run_test(sample, test=None, debug=False):
# ground truth sequences are not tuples
if isinstance(output, tuple):
output = list(output)

tmp_result = output == in_outs["outputs"][index]
if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
tmp_result = tmp_result or (output == in_outs["outputs"][index][0])

@@ -278,7 +278,7 @@ def run_test(sample, test=None, debug=False):
print(f"Failed check1 exception = {e}")
pass

if tmp_result == True:
results.append(tmp_result)
continue

@@ -312,10 +312,10 @@ def run_test(sample, test=None, debug=False):
if debug:
nl = "\n"
if not isinstance(inputs, list):
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
else:
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")

if tmp_result == True:
results.append(tmp_result)
continue
@@ -350,9 +350,9 @@ def run_test(sample, test=None, debug=False):
# try by converting the stuff into split up list
if isinstance(in_outs["outputs"][index], list):
for tmp_index, i in enumerate(in_outs["outputs"][index]):
- in_outs["outputs"][index][tmp_index] = set(i.split())
+ in_outs["outputs"][index][tmp_index] = list(i.split())
else:
- in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
+ in_outs["outputs"][index] = list(in_outs["outputs"][index].split())

try:
tmp_result = (output == in_outs["outputs"][index])

@@ -363,7 +363,7 @@ def run_test(sample, test=None, debug=False):
if tmp_result == True:
results.append(tmp_result)
continue

# try by converting the output into a split up list too
if isinstance(output, list):
@@ -371,14 +371,14 @@ def run_test(sample, test=None, debug=False):
output[tmp_index] = i.split()
output = list(filter(len, output))
for tmp_index, i in enumerate(output):
- output[tmp_index] = set(i)
+ output[tmp_index] = list(i)
else:
output = output.split()
output = list(filter(len, output))
- output = set(output)
+ output = list(output)

try:
- tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
+ tmp_result = (list(list(s) for s in output) == list(list(s) for s in in_outs["outputs"][index]))
except Exception as e:
if debug:
print(f"Failed check5 exception = {e}")
@@ -386,30 +386,30 @@ def run_test(sample, test=None, debug=False):
# if they are all numbers, round so that similar numbers are treated as identical
try:
- tmp_result = tmp_result or (set(frozenset(round(float(t),3) for t in s) for s in output) ==\
-     set(frozenset(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
+ tmp_result = tmp_result or (list(list(round(float(t),3) for t in s) for s in output) ==\
+     list(list(round(float(t),3) for t in s) for s in in_outs["outputs"][index]))
except Exception as e:
if debug:
print(f"Failed check6 exception = {e}")

if tmp_result == True and debug:
print("PASSED")

results.append(tmp_result)

if debug:
nl = "\n"
if not isinstance(inputs, list):
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
else:
print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")

return results


def custom_compare_(output, ground_truth):
if isinstance(output, list):
output_1 = "\n".join(output)
if stripped_string_compare(output_1, ground_truth):
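The set-to-list changes in the hunks above tighten the comparison semantics: a set comparison treats outputs as unordered collections of unique tokens, so reordered or duplicated tokens still match, while a list comparison requires the same tokens in the same order and multiplicity. A small self-contained illustration (plain Python, not part of the diff):

    expected = "1 2 2 3"
    produced = "3 2 1 2"   # same tokens, different order and multiplicity

    # old behaviour: order and duplicates are ignored
    print(set(produced.split()) == set(expected.split()))   # True

    # new behaviour: order and multiplicity must match
    print(produced.split() == expected.split())             # False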
@@ -451,7 +451,7 @@ def call_method(method, inputs):
pass
finally:
pass

return _inner_call_method(method)

...
@@ -3,55 +3,56 @@ set -xe
model="/lustre/S/huangdi/open_for_out/models/Qwen2.5-Coder-7B-Instruct/"
project="/lustre/S/nanziyuan/projects/ccc"
modelname="qwen25_coder_inst"
+ data="${project}/data"

- trainset="${project}/data/train/${modelname}-apps-train.jsonl"
- testset="${project}/data/test/${modelname}-apps-test.jsonl"
- train_selected_pairs="${project}/data/train/${modelname}-apps-train-selected_pairs.jsonl"
+ trainset="${data}/train/${modelname}-apps-train.jsonl"
+ testset="${data}/test/${modelname}-apps-test.jsonl"
+ train_selected_pairs="${data}/train/${modelname}-apps-train-selected_pairs.jsonl"
apps="/lustre/S/nanziyuan/datasets/apps/"

- sft="${project}/data/train/${modelname}-sft.jsonl"
+ sft="${data}/train/${modelname}-sft.jsonl"
ftmodel="${project}/model/qwen25_coder_inst_7b-algolr"

- testset="/lustre/S/nanziyuan/projects/ccc/data/test/qwen25_coder_inst-apps-test.jsonl"
- evalresults="/lustre/S/nanziyuan/projects/ccc/data/eval/qwen25_code_inst-apps-test-algolr-score.jsonl"
+ testset="${data}/test/qwen25_coder_inst-apps-test.jsonl"
+ evalresults="${data}/eval/qwen25_code_inst-apps-test-algolr-score.jsonl"

# export CUDA_VISIBLE_DEVICES=0,1,2,3
- # python -m codecritic.cli.algolr \
- #     --model ${model} \
- #     --dataset ${trainset} \
- #     --pairinfo ${train_selected_pairs} \
- #     --apps ${apps} \
- #     --output ${sft} \
- #     --level beginner \
- #     --tp 1
+ python -m codecritic.cli.algolr \
+     --model ${model} \
+     --dataset ${trainset} \
+     --pairinfo ${train_selected_pairs} \
+     --apps ${apps} \
+     --output ${sft} \
+     --level beginner \
+     --tp 1
- # deepspeed --module \
- #     openrlhf.cli.train_sft \
- #     --max_len 4096 \
- #     --dataset ${sft} \
- #     --input_key question \
- #     --output_key response \
- #     --apply_chat_template \
- #     --train_batch_size 256 \
- #     --micro_train_batch_size 2 \
- #     --max_samples 500000 \
- #     --pretrain ${model} \
- #     --save_path ${ftmodel} \
- #     --save_steps -1 \
- #     --logging_steps 1 \
- #     --eval_steps -1 \
- #     --zero_stage 2 \
- #     --max_epochs 1 \
- #     --bf16 \
- #     --flash_attn \
- #     --learning_rate 5e-6 \
- #     --load_checkpoint \
- #     --gradient_checkpointing \
- #     --use_tensorboard "${ftmodel}_log"
+ deepspeed --module \
+     openrlhf.cli.train_sft \
+     --max_len 4096 \
+     --dataset ${sft} \
+     --input_key question \
+     --output_key response \
+     --apply_chat_template \
+     --train_batch_size 256 \
+     --micro_train_batch_size 2 \
+     --max_samples 500000 \
+     --pretrain ${model} \
+     --save_path ${ftmodel} \
+     --save_steps -1 \
+     --logging_steps 1 \
+     --eval_steps -1 \
+     --zero_stage 2 \
+     --max_epochs 1 \
+     --bf16 \
+     --flash_attn \
+     --learning_rate 5e-6 \
+     --load_checkpoint \
+     --gradient_checkpointing \
+     --use_tensorboard "${ftmodel}_log"
python -m codecritic.cli.test_genrm \
...
@@ -12,7 +12,7 @@ train_selected_pairs="${project}/data/train/${modelname}-apps-train-selected_pai
reward_ds="${project}/data/train/${modelname}-apps-train-reward_dataset.jsonl"

- export CUDA_VISIBLE_DEVICES=0,1,2,3
+ # export CUDA_VISIBLE_DEVICES=0,1,2,3

## Sampling
## APPS
...
@@ -9,29 +9,29 @@ ftmodel="${project}/model/qwen25_coder_inst_7b-orm"
testset="/lustre/S/nanziyuan/projects/ccc/data/test/qwen25_coder_inst-apps-test.jsonl"
evalresults="/lustre/S/nanziyuan/projects/ccc/data/eval/qwen25_code_inst-apps-test-orm-score.jsonl"
- # deepspeed --module \
- #     openrlhf.cli.train_rm \
- #     --save_path ${ftmodel} \
- #     --save_steps -1 \
- #     --logging_steps 1 \
- #     --eval_steps -1 \
- #     --train_batch_size 256 \
- #     --micro_train_batch_size 1 \
- #     --pretrain ${model} \
- #     --bf16 \
- #     --max_epochs 1 \
- #     --max_len 8192 \
- #     --zero_stage 3 \
- #     --learning_rate 9e-6 \
- #     --dataset ${dataset} \
- #     --apply_chat_template \
- #     --prompt_key messages \
- #     --chosen_key chosen \
- #     --rejected_key rejected \
- #     --flash_attn \
- #     --load_checkpoint \
- #     --gradient_checkpointing \
- #     --use_tensorboard "${ftmodel}_log"
+ deepspeed --module \
+     openrlhf.cli.train_rm \
+     --save_path ${ftmodel} \
+     --save_steps -1 \
+     --logging_steps 1 \
+     --eval_steps -1 \
+     --train_batch_size 256 \
+     --micro_train_batch_size 1 \
+     --pretrain ${model} \
+     --bf16 \
+     --max_epochs 1 \
+     --max_len 8192 \
+     --zero_stage 3 \
+     --learning_rate 9e-6 \
+     --dataset ${dataset} \
+     --apply_chat_template \
+     --prompt_key messages \
+     --chosen_key chosen \
+     --rejected_key rejected \
+     --flash_attn \
+     --load_checkpoint \
+     --gradient_checkpointing \
+     --use_tensorboard "${ftmodel}_log"
start_server() {
...