Commit 22606212 by nzy

step3, 4: train & test critic model

parent c401feaf
@@ -46,4 +46,24 @@ train_yaml_path = ""
test_yaml_path = ""
minimal_test_score_path = ""
eval_result_path = ""
deepspeed_cfg_path = ""
\ No newline at end of file
deepspeed_cfg_path = ""

[critic]
model_path = ""
dataset_name = ""
dataset_path = ""
dataset_info_path = ""
meta_data_path = ""

[critic.train]
train_yaml_path = ""
deepspeed_cfg_path = ""

[critic.test]
reason_result_path = ""
score_result_path = ""

[critic.test.sampling_params]
n = 1
temperature = 0.0
max_new_tokens = 512
\ No newline at end of file
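The new [critic] tables are consumed through read_config in the scripts below. That helper is not part of this diff; the following is a minimal sketch of the access pattern, assuming read_config simply parses this TOML file (the file name and function body here are illustrative only):

import tomllib  # Python 3.11+; older interpreters can use the tomli backport

# Hypothetical stand-in for utils.read_config, shown only to illustrate how the
# nested tables above map onto the dictionary lookups used in the new scripts.
def read_config(path="config.toml"):
    with open(path, "rb") as f:
        return tomllib.load(f)

cfg = read_config()
critic_model = cfg["critic"]["model_path"]                  # [critic]
train_yaml = cfg["critic"]["train"]["train_yaml_path"]      # [critic.train]
sampling_params = cfg["critic"]["test"]["sampling_params"]  # {"n": 1, "temperature": 0.0, "max_new_tokens": 512}
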
from utils_vllm import vllm_inference
from utils_vllm import vllm_chatcomplete
from utils import read_config
cfg = read_config()
vllm_inference(
vllm_chatcomplete(
cfg["model"],
cfg["sample"]["sample_prompt_path"],
cfg["sample"]["sample_result_path"],
......
from utils import read_config
train_yaml = """\
### model
model_name_or_path: {model_path}
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: {deepspeed_config_path}
### dataset
dataset: {dataset_name}
template: deepseekcoder
cutoff_len: 4096
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: {critic_model_path}
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
"""
def mk_llamafactory_sft_yaml(cfg):
    with open(cfg["critic"]["train"]["train_yaml_path"], "w") as f:
        train_str = train_yaml.format(
            model_path=cfg["model"],
            deepspeed_config_path=cfg["critic"]["train"]["deepspeed_cfg_path"],
            dataset_name=cfg["critic"]["dataset_name"],
            critic_model_path=cfg["critic"]["model_path"],
        )
        f.write(train_str)


if __name__ == "__main__":
    cfg = read_config()
    mk_llamafactory_sft_yaml(cfg)
\ No newline at end of file
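The YAML emitted by mk_llamafactory_sft_yaml follows LLaMA-Factory's SFT config format, so the actual training run is a separate LLaMA-Factory launch. A hedged sketch of how the generated file would typically be consumed (the launch step itself is not part of this commit):

import subprocess
from utils import read_config

cfg = read_config()
# Assumes LLaMA-Factory is installed and `llamafactory-cli` is on PATH;
# equivalent to running: llamafactory-cli train <train_yaml_path>
subprocess.run(
    ["llamafactory-cli", "train", cfg["critic"]["train"]["train_yaml_path"]],
    check=True,
)
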
@@ -48,7 +48,7 @@ template: {model_template}
stage: rm
"""
def mk_llamafactory_config_yaml(cfg):
def mk_llamafactory_orm_yaml(cfg):
orm_dataset = cfg["orm_dataset"]
orm_cfg = cfg["orm"][orm_dataset]
data_cfg = cfg["preference_dataset"][orm_dataset]
@@ -73,4 +73,4 @@ def mk_llamafactory_config_yaml(cfg):
if __name__ == "__main__":
cfg = read_config(["orm_dataset"])
mk_llamafactory_config_yaml(cfg)
\ No newline at end of file
mk_llamafactory_orm_yaml(cfg)
\ No newline at end of file
from utils_vllm import vllm_chatcomplete, vllm_score
from utils import read_config
from transformers import AutoTokenizer
cfg = read_config()
vllm_chatcomplete(
    cfg["critic"]["model_path"],
    cfg["dataset"]["minimal_test_path"],
    cfg["critic"]["test"]["reason_result_path"],
    cfg["critic"]["test"]["sampling_params"]
)

tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
# "Yes" must map to exactly one token id so score_worker can look up its logprob
# at the first generated position; exclude special tokens (e.g. a BOS) when encoding.
score_tokens = tokenizer.encode("Yes", add_special_tokens=False)
assert len(score_tokens) == 1
score_token = score_tokens[0]
vllm_score(
    cfg["critic"]["model_path"],
    cfg["critic"]["test"]["reason_result_path"],
    cfg["critic"]["test"]["score_result_path"],
    score_token
)
\ No newline at end of file
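The score written by vllm_score is the raw log-probability of the "Yes" token at the first generated position, with 0 used as a fallback when that token is missing from the returned top logprobs. A small sketch, assuming the JSONL layout produced above, of converting it into a probability:

import math
from utils import load_jsonl, read_config

cfg = read_config()
scored = load_jsonl(cfg["critic"]["test"]["score_result_path"])
for item in scored:
    # exp(logprob) maps the "Yes" log-probability back into [0, 1]; note that the
    # fallback score of 0 maps to 1.0, so such items may need separate handling.
    item["p_yes"] = math.exp(item["score"])
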
@@ -8,7 +8,7 @@ from functools import partial
from utils import load_jsonl, save_jsonl
def worker(cuda_device, prompts, model_path, sampling_params):
def generate_worker(cuda_device, prompts, model_path, sampling_params):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)
@@ -41,7 +41,41 @@ def worker(cuda_device, prompts, model_path, sampling_params):
return result
def vllm_inference(model_path, prompt_path, output_path, sampling_params):
def score_worker(cuda_device, prompts, model_path, score_token):
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)
tokenizer = llm.get_tokenizer()
stop_tokens = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")
vllm_sampling_params = SamplingParams(
n=1,
temperature=0,
max_tokens=1,
logprobs=1000
)
text_prompts = [tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True) for item in prompts]
outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=False)
result = []
for item, output in zip(prompts, outputs):
for response in output.outputs:
# response.logprobs: list[dict[int, Logprob]] https://github.com/vllm-project/vllm/blob/main/vllm/sequence.py
sample_logprobs = response.logprobs
logprob = sample_logprobs[0].get(score_token)
newitem = item.copy()
if logprob:
newitem["score"] = logprob.logprob
else:
newitem["score"] = 0
result.append(newitem)
return result
def vllm_chatcomplete(model_path, prompt_path, output_path, sampling_params):
    prompts = load_jsonl(prompt_path)
    # Respect the slurm's gpu allocation
@@ -54,7 +88,7 @@ def vllm_inference(model_path, prompt_path, output_path, sampling_params):
        sub_prompts[i % gpu_num].append(prompt)
    args = list(zip(cuda_devices, sub_prompts))
    worker_llm = partial(worker, model_path=model_path, sampling_params=sampling_params)
    worker_llm = partial(generate_worker, model_path=model_path, sampling_params=sampling_params)
    with multiprocessing.Pool(gpu_num) as pool:
        nested_results = pool.starmap(worker_llm, args)
@@ -62,3 +96,26 @@ def vllm_inference(model_path, prompt_path, output_path, sampling_params):
    results = list(chain(*nested_results))
    print(f"size of dataset: {len(results)}")
    save_jsonl(results, output_path)
def vllm_score(model_path, prompt_path, output_path, score_token):
    prompts = load_jsonl(prompt_path)
    # Respect the slurm's gpu allocation
    cuda_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
    gpu_num = len(cuda_devices)

    # split data
    sub_prompts = [[] for _ in range(gpu_num)]
    for i, prompt in enumerate(prompts):
        sub_prompts[i % gpu_num].append(prompt)

    args = list(zip(cuda_devices, sub_prompts))
    worker_llm = partial(score_worker, model_path=model_path, score_token=score_token)

    with multiprocessing.Pool(gpu_num) as pool:
        nested_results = pool.starmap(worker_llm, args)

    results = list(chain(*nested_results))
    print(f"size of dataset: {len(results)}")
    save_jsonl(results, output_path)
\ No newline at end of file
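Both vllm_chatcomplete and vllm_score shard the prompts round-robin across the devices listed in CUDA_VISIBLE_DEVICES and run one worker process per GPU. A standalone illustration of that split, with a hypothetical device list and prompt count:

# Mirrors the split logic in utils_vllm.py; the values below are made up for illustration.
cuda_devices = "0,1,2".split(",")             # as if CUDA_VISIBLE_DEVICES="0,1,2"
gpu_num = len(cuda_devices)
prompts = [f"prompt-{i}" for i in range(7)]

sub_prompts = [[] for _ in range(gpu_num)]
for i, prompt in enumerate(prompts):
    sub_prompts[i % gpu_num].append(prompt)

print([len(chunk) for chunk in sub_prompts])  # [3, 2, 2] -> one chunk per worker process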