Ziyuan Nan / codecritic · Commits

Commit 22606212
Authored Oct 14, 2024 by nzy

step3, 4: train & test critic model

Parent: c401feaf
Showing 6 changed files with 172 additions and 10 deletions (+172 −10):
example_config.toml                    +22  −2
step1_sample_code.py                    +2  −2
step3_train_critic_model.py            +60  −0
step3_train_outcome_reward_model.py     +3  −3
step4_test_critic_model.py             +24  −0
utils_vllm.py                          +61  −3
example_config.toml
@@ -46,4 +46,24 @@ train_yaml_path = ""
 test_yaml_path = ""
 minimal_test_score_path = ""
 eval_result_path = ""
-deepspeed_cfg_path = ""
\ No newline at end of file
+deepspeed_cfg_path = ""
+
+[critic]
+model_path = ""
+dataset_name = ""
+dataset_path = ""
+dataset_info_path = ""
+meta_data_path = ""
+
+[critic.train]
+train_yaml_path = ""
+deepspeed_cfg_path = ""
+
+[critic.test]
+reason_result_path = ""
+score_result_path = ""
+
+[critic.test.sampling_params]
+n = 1
+temperature = 0.0
+max_new_tokens = 512
\ No newline at end of file
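
The new [critic] tables are consumed through utils.read_config, which is not part of this diff. As a rough orientation only, a plain tomllib load reproduces the nested lookups the scripts below rely on; this sketch is a stand-in, not the repository's read_config, and the file name is illustrative.

# Minimal sketch, assuming Python 3.11+ for the stdlib tomllib; "example_config.toml"
# is just the template above, and this is not the repo's utils.read_config.
import tomllib

with open("example_config.toml", "rb") as f:
    cfg = tomllib.load(f)

critic_model_path = cfg["critic"]["model_path"]
# e.g. {"n": 1, "temperature": 0.0, "max_new_tokens": 512}
test_sampling_params = cfg["critic"]["test"]["sampling_params"]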
step1_sample_code.py
-from utils_vllm import vllm_inference
+from utils_vllm import vllm_chatcomplete
 from utils import read_config

 cfg = read_config()
-vllm_inference(
+vllm_chatcomplete(
     cfg["model"],
     cfg["sample"]["sample_prompt_path"],
     cfg["sample"]["sample_result_path"],
...
step3_train_critic_model.py (new file, 0 → 100644)
from utils import read_config

train_yaml = """\
### model
model_name_or_path: {model_path}

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: {deepspeed_config_path}

### dataset
dataset: {dataset_name}
template: deepseekcoder
cutoff_len: 4096
max_samples: 10000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: {critic_model_path}
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
"""


def mk_llamafactory_sft_yaml(cfg):
    with open(cfg["critic"]["train"]["train_yaml_path"], "w") as f:
        train_str = train_yaml.format(
            model_path=cfg["model"],
            deepspeed_config_path=cfg["critic"]["train"]["deepspeed_cfg_path"],
            dataset_name=cfg["critic"]["train"]["dataset_name"],
            critic_model_path=cfg["critic"]["model_path"],
        )
        f.write(train_str)


if __name__ == "__main__":
    cfg = read_config()
    mk_llamafactory_sft_yaml(cfg)
\ No newline at end of file
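
This step only materializes the YAML; training itself is presumably launched with LLaMA-Factory afterwards. A hedged launch sketch, assuming LLaMA-Factory is installed; "critic_sft.yaml" stands in for whatever cfg["critic"]["train"]["train_yaml_path"] points to.

# Launch sketch, not part of this commit: hand the generated YAML to
# LLaMA-Factory's CLI. "critic_sft.yaml" is an illustrative path.
import subprocess

subprocess.run(["llamafactory-cli", "train", "critic_sft.yaml"], check=True)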
step3_train_outcome_reward_model.py
@@ -48,7 +48,7 @@ template: {model_template}
 stage: rm
 """

-def mk_llamafactory_config_yaml(cfg):
+def mk_llamafactory_orm_yaml(cfg):
     orm_dataset = cfg["orm_dataset"]
     orm_cfg = cfg["orm"][orm_dataset]
     data_cfg = cfg["preference_dataset"][orm_dataset]
@@ -73,4 +73,4 @@ def mk_llamafactory_config_yaml(cfg):

 if __name__ == "__main__":
     cfg = read_config(["orm_dataset"])
-    mk_llamafactory_config_yaml(cfg)
\ No newline at end of file
+    mk_llamafactory_orm_yaml(cfg)
\ No newline at end of file
step4_test_critic_model.py (new file, 0 → 100644)
from utils_vllm import vllm_chatcomplete, vllm_score
from utils import read_config
from transformers import AutoTokenizer

cfg = read_config()

# Let the critic model generate its reasoning for every test sample.
vllm_chatcomplete(
    cfg["critic"]["model_path"],
    cfg["dataset"]["minimal_test_path"],
    cfg["critic"]["test"]["reason_result_path"],
    cfg["critic"]["test"]["sampling_params"],
)

# "Yes" must map to exactly one token id; its logprob is used as the score.
tokenizer = AutoTokenizer.from_pretrained(cfg["model"])
score_tokens = tokenizer.encode("Yes")
assert len(score_tokens) == 1
score_token = score_tokens[0]

vllm_score(
    cfg["critic"]["model_path"],
    cfg["critic"]["test"]["reason_result_path"],
    cfg["critic"]["test"]["score_result_path"],
    score_token,
)
\ No newline at end of file
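
vllm_score stores the raw log-probability of the "Yes" token in each record's "score" field (0 when that token is absent from the returned logprobs). A small post-processing sketch, assuming only that field; the file path is illustrative.

# Post-processing sketch: turn the stored "Yes" logprobs into rough probabilities
# and rank records by them. Only "score" is written by vllm_score; the path
# below is illustrative.
import math
from utils import load_jsonl

records = load_jsonl("score_result.jsonl")
for rec in records:
    # score == 0 is the fallback for "token not in top logprobs", treat it as ~0 probability
    rec["p_yes"] = math.exp(rec["score"]) if rec["score"] != 0 else 0.0

ranked = sorted(records, key=lambda r: r["p_yes"], reverse=True)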
utils_vllm.py
@@ -8,7 +8,7 @@ from functools import partial
 from utils import load_jsonl, save_jsonl


-def worker(cuda_device, prompts, model_path, sampling_params):
+def generate_worker(cuda_device, prompts, model_path, sampling_params):
     os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
     llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)
@@ -41,7 +41,41 @@ def worker(cuda_device, prompts, model_path, sampling_params):
     return result


-def vllm_inference(model_path, prompt_path, output_path, sampling_params):
+def score_worker(cuda_device, prompts, model_path, score_token):
+    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
+    llm = LLM(model=model_path, seed=42, max_model_len=8 * 1024, swap_space=16)
+    tokenizer = llm.get_tokenizer()
+    stop_tokens = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+    print(f"SUCCESS: load llm {model_path} on cuda {cuda_device}")
+
+    vllm_sampling_params = SamplingParams(n=1, temperature=0, max_tokens=1, logprobs=1000)
+    text_prompts = [
+        tokenizer.apply_chat_template(item["messages"], tokenize=False, add_generation_prompt=True)
+        for item in prompts
+    ]
+    outputs = llm.generate(text_prompts, sampling_params=vllm_sampling_params, use_tqdm=False)
+
+    result = []
+    for item, output in zip(prompts, outputs):
+        for response in output.outputs:
+            # response.logprobs: list[dict[int, Logprob]] https://github.com/vllm-project/vllm/blob/main/vllm/sequence.py
+            sample_logprobs = response.logprobs
+            logprob = sample_logprobs[0].get(score_token)
+            newitem = item.copy()
+            if logprob:
+                newitem["score"] = logprob.logprob
+            else:
+                newitem["score"] = 0
+            result.append(newitem)
+    return result
+
+
+def vllm_chatcomplete(model_path, prompt_path, output_path, sampling_params):
     prompts = load_jsonl(prompt_path)
     # Respect the slurm's gpu allocation
@@ -54,7 +88,7 @@ def vllm_inference(model_path, prompt_path, output_path, sampling_params):
         sub_prompts[i % gpu_num].append(prompt)

     args = list(zip(cuda_devices, sub_prompts))
-    worker_llm = partial(worker, model_path=model_path, sampling_params=sampling_params)
+    worker_llm = partial(generate_worker, model_path=model_path, sampling_params=sampling_params)

     with multiprocessing.Pool(gpu_num) as pool:
         nested_results = pool.starmap(worker_llm, args)
@@ -62,3 +96,26 @@ def vllm_inference(model_path, prompt_path, output_path, sampling_params):
     results = list(chain(*nested_results))
     print(f"size of dataset: {len(results)}")
     save_jsonl(results, output_path)
+
+
+def vllm_score(model_path, prompt_path, output_path, score_token):
+    prompts = load_jsonl(prompt_path)
+    # Respect the slurm's gpu allocation
+    cuda_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
+    gpu_num = len(cuda_devices)
+
+    # split data
+    sub_prompts = [[] for _ in range(gpu_num)]
+    for i, prompt in enumerate(prompts):
+        sub_prompts[i % gpu_num].append(prompt)
+
+    args = list(zip(cuda_devices, sub_prompts))
+    worker_llm = partial(score_worker, model_path=model_path, score_token=score_token)
+
+    with multiprocessing.Pool(gpu_num) as pool:
+        nested_results = pool.starmap(worker_llm, args)
+
+    results = list(chain(*nested_results))
+    print(f"size of dataset: {len(results)}")
+    save_jsonl(results, output_path)
\ No newline at end of file
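
Both workers load their prompts from JSONL and feed item["messages"] through the tokenizer's chat template, copying every other field to the output unchanged (score_worker additionally writes "score"). A sketch of the expected record shape; the "task_id" field and the prompt text are illustrative only.

# Sketch of one input JSONL record as generate_worker/score_worker expect it:
# a chat-format "messages" list is required, everything else is passed through.
example_record = {
    "task_id": "demo/0",  # illustrative passthrough field, not required by the workers
    "messages": [
        {"role": "user", "content": "<critic prompt for one code sample>"},
    ],
}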