Ziyuan Nan / codecritic · Commits

Commit f3dd6691
authored Jan 05, 2025 by nanziyuan

merge conflicts

Parents: 9af763c6, 79b13ce2

Showing 5 changed files with 179 additions and 104 deletions.
README.md                                    +1    −1
codecritic/cli/reformat.py                   +59   −0
codecritic/cli/select_preference_pairs.py    +117  −0
codecritic/cli/test_genrm.py                 +2    −3
codecritic/dataset/edit_distance.py          +0    −100
README.md — view file @ f3dd6691

````diff
@@ -4,7 +4,7 @@
 ```
 pip install scikit-learn
-pip install
+pip install nltk
 ```
 ## Evaluation
````
codecritic/cli/reformat.py — view file @ f3dd6691

```python
import argparse
from itertools import product, chain

from codecritic.utils.json import load_jsonl, save_jsonl


def mk_preference_pair(ds, pair):
    task_id = pair["task_id"]
    chosen = ds[task_id][pair["chosen"]]
    rejected = ds[task_id][pair["rejected"]]
    return {
        "messages": chosen["messages"][:1],
        "chosen": chosen["messages"][1:],
        "rejected": rejected["messages"][1:],
        "meta_pairinfo": pair
    }


def mk_sft(ds, pair):
    dataset_name = pair["dataset"]
    task_id = pair["task_id"]
    chosen = ds[task_id][pair["chosen"]]
    rejected = ds[task_id][pair["rejected"]]
    # TODO add judgement response
    return [
        {
            "question": chosen["messages"][:1],
            "response": chosen["messages"][1:],
            "dataset": dataset_name,
            "task_id": task_id,
            "solution_id": chosen
        },
        {
            "question": rejected["messages"][:1],
            "response": rejected["messages"][1:],
            "dataset": dataset_name,
            "task_id": task_id,
            "solution_id": rejected
        }
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, help="path/to/dataset")
    parser.add_argument("--pairs", type=str, help="path/to/selected_pairs")
    parser.add_argument("--output", type=str, help="path/to/output")
    args = parser.parse_args()

    dataset = load_jsonl(args.dataset)
    selected_pairs = load_jsonl(args.pairs)

    if args.format == "sft":
        sft_ds = list(chain.from_iterable([mk_sft(dataset, pair) for pair in selected_pairs]))
        save_jsonl(sft_ds, args.output)
    elif args.format == "reward":
        reward_ds = [mk_preference_pair(dataset, pair) for pair in selected_pairs]
        save_jsonl(reward_ds, args.output)
    else:
        raise NotImplementedError(f"Unknown format: {args.format}")
```
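One thing worth flagging about reformat.py as committed: the `__main__` block branches on `args.format`, but no `--format` argument is registered, so running the script would stop at an `AttributeError`. Below is a minimal sketch of the flag it appears to expect; this is an assumption for illustration, not part of the commit.

```python
import argparse

# Hypothetical sketch, not part of the commit: reformat.py branches on
# args.format but never registers the flag, so a definition like this would
# be needed alongside the existing --dataset/--pairs/--output arguments.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--format",
    type=str,
    choices=["sft", "reward"],  # the two branches handled in __main__
    help="output dataset format",
)
args = parser.parse_args(["--format", "reward"])
print(args.format)  # -> reward
```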
codecritic/cli/select_preference_pairs.py — new file (0 → 100644), view file @ f3dd6691

```python
import argparse
from collections import defaultdict
from itertools import product, chain
from pathlib import Path
import os
import re

from tqdm.contrib.concurrent import process_map
from rapidfuzz import fuzz

from codecritic.utils.json import load_jsonl, save_jsonl


def group_and_filter(dataset):
    grouped = defaultdict(list)
    for sample in dataset:
        grouped[sample["task_id"]].append(sample)

    # filter groups passed/failed all testcase
    for task_id, group in grouped.items():
        passes = {x["pass"] for x in group}
        if len(passes) == 2:
            yield group


# Precompile regular expressions
SINGLE_LINE_COMMENT_REGEX = re.compile(r'#.*')
MULTILINE_DOUBLE_QUOTE_REGEX = re.compile(r'^\s*""".*?"""\s*$', flags=re.DOTALL | re.MULTILINE)
MULTILINE_SINGLE_QUOTE_REGEX = re.compile(r"^\s*'''.*?'''\s*$", flags=re.DOTALL | re.MULTILINE)


def preprocess_code(code):
    # Remove single-line comments
    code = SINGLE_LINE_COMMENT_REGEX.sub('', code)

    # Remove standalone docstrings (triple-quoted strings that are not part of an expression)
    code = MULTILINE_DOUBLE_QUOTE_REGEX.sub('', code)
    code = MULTILINE_SINGLE_QUOTE_REGEX.sub('', code)

    # Remove blank lines
    code = "\n".join([line for line in code.splitlines() if line.strip()])

    return code


def compute_pair_similarity(group):
    correct_code_set, incorrect_code_set = set(), {''}
    correct_samples, incorrect_samples = [], []

    assert len(group) > 0
    dataset_name = group[0]["dataset"]
    task_id = group[0]["task_id"]

    for sample in group:
        code = preprocess_code(sample["code"])
        item = {"solution_id": sample["solution_id"], "code": code}
        if sample["pass"] and (code not in correct_code_set):
            correct_samples.append(item)
        elif (not sample["pass"]) and (code not in incorrect_code_set):
            incorrect_samples.append(item)

    results = []
    for correct, incorrect in product(correct_samples, incorrect_samples):
        score = fuzz.ratio(correct["code"], incorrect["code"])
        results.append({
            "dataset": dataset_name,
            "task_id": task_id,
            "chosen": correct["solution_id"],
            "rejected": incorrect["solution_id"],
            "similarity": score
        })
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, help="path/to/dataset")
    parser.add_argument("--output", type=str, help="path/to/output")
    args = parser.parse_args()

    dataset = load_jsonl(args.dataset)

    cache_path = Path(args.dataset + ".pairinfo")
    if not cache_path.exists():
        results = process_map(
            compute_pair_similarity,
            group_and_filter(dataset),
            max_workers=os.cpu_count(),
            chunksize=1,
        )
        pairinfo = list(chain.from_iterable(results))
        save_jsonl(pairinfo, cache_path)
    else:
        pairinfo = load_jsonl(cache_path)
        print(f"load cached similarity information from {cache_path}")

    # select pairs
    ds = defaultdict(dict)
    for item in dataset:
        ds[item["task_id"]][item["solution_id"]] = item

    sorted_pairinfo = sorted(pairinfo, key=lambda x: x["similarity"])

    task_groups = defaultdict(list)
    for item in pairinfo:
        task_groups[item["task_id"]].append(item)

    # Step 2: Select the 4 pairs with the smallest score for each task
    selected_pairs = []
    for task, items in task_groups.items():
        # Sort items for this task by score and select the top 4
        sorted_items = sorted(items, key=lambda x: x["similarity"])[:4]
        selected_pairs.extend(sorted_items)

    save_jsonl(selected_pairs, args.output)
```
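The selection step above keeps, for each task, the four chosen/rejected pairs with the lowest `similarity`, i.e. the correct and incorrect solutions that look least alike once comments, docstrings, and blank lines are stripped. Here is a small self-contained sketch of that scoring step; the code snippets are made up for illustration and only mirror the shape of the records the script builds.

```python
# Illustrative sketch: rapidfuzz's fuzz.ratio returns a 0-100 similarity, and the
# script keeps the lowest-scoring (most dissimilar) chosen/rejected pairs per task.
from itertools import product

from rapidfuzz import fuzz

correct_codes = ["def add(a, b):\n    return a + b"]
incorrect_codes = [
    "def add(a, b):\n    return a - b",               # near-miss solution
    "def add(a, b):\n    print(a)\n    return None",  # structurally different
]

pairs = [
    {"chosen": c, "rejected": r, "similarity": fuzz.ratio(c, r)}
    for c, r in product(correct_codes, incorrect_codes)
]

# Ascending sort puts the most dissimilar pairs first, mirroring the [:4] cut above.
for p in sorted(pairs, key=lambda x: x["similarity"])[:4]:
    print(round(p["similarity"], 1), repr(p["rejected"].splitlines()[-1]))
```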
codecritic/cli/test_genrm.py — view file @ f3dd6691

```diff
@@ -5,8 +5,7 @@ import os
 from transformers import AutoTokenizer
 from vllm import SamplingParams
 
-from codecritic.dataset.genrm_prompt import JUDGE_MESSAGE, JUDGE_TOEKNS
-from codecritic.dataset.legacy_genrm_prompt import COV_MESSAGE
+from codecritic.dataset.genrm_prompt import THINK_MESSAGE, JUDGE_MESSAGE, JUDGE_TOEKNS
 from codecritic.utils.inference import generate_worker, score_worker
 from codecritic.utils.parallel import model_map
 from codecritic.utils.json import load_jsonl, save_jsonl
@@ -36,7 +35,7 @@ if __name__ == "__main__":
 
     if args.reasoning:
         for item in dataset:
-            item["messages"].append(COV_MESSAGE)
+            item["messages"].append(THINK_MESSAGE)
 
     sampling_params = SamplingParams(
         n=1,
```
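The behavioural part of this change is the prompt swap: with `--reasoning` set, each sample now gets `THINK_MESSAGE` appended instead of the legacy `COV_MESSAGE` before sampling. A minimal sketch of that append step, assuming the prompt constants are ordinary chat-message dicts; their real contents live in `codecritic.dataset.genrm_prompt` and are not visible in this diff.

```python
# Assumption: THINK_MESSAGE is a chat turn like the one below; the actual text is
# defined in codecritic.dataset.genrm_prompt and is not shown in this commit.
THINK_MESSAGE = {"role": "user", "content": "Think step by step before giving a judgement."}

dataset = [
    {"messages": [
        {"role": "user", "content": "Write add(a, b)."},
        {"role": "assistant", "content": "def add(a, b):\n    return a + b"},
    ]}
]

# Mirrors the loop in test_genrm.py when args.reasoning is set.
for item in dataset:
    item["messages"].append(THINK_MESSAGE)

print(len(dataset[0]["messages"]))  # -> 3
```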
codecritic/dataset/edit_distance.py — deleted (100644 → 0), view file @ 9af763c6

```python
from codecritic.utils.json import load_jsonl
from codecritic.dataset.code import extract_code, code_template

from nltk.metrics.distance import edit_distance
from collections import defaultdict
from itertools import product, chain
import multiprocessing
from tqdm.contrib.concurrent import process_map


def mk_preference_pair(instruction, chosen_code, rejected_code):
    return {
        "messages": [
            {"role": "user", "content": instruction},
        ],
        "chosen": {"role": "assistant", "content": code_template.format(chosen_code)},
        "rejected": {
            "role": "assistant",
            "content": code_template.format(rejected_code),
        },
    }


def mk_problem_groups(train_dataset_path, n):
    train_dataset = load_jsonl(train_dataset_path)
    assert len(train_dataset) % n == 0

    problems = []
    for i in range(len(train_dataset) // n):
        problem = train_dataset[i * n : (i + 1) * n]
        problem_id = problem[0]["problem_id"]
        eval_results = [d["eval_result"] for d in problem]

        # filter all passed/failed problems
        if True in eval_results and False in eval_results:
            instruction = problem[0]["messages"][0]["content"]
            correct_codes, incorrect_codes = [], []
            for d in problem:
                assert d["problem_id"] == problem_id, "dataset is not sorted"
                code = extract_code(d["messages"][1]["content"])
                if d["eval_result"]:
                    correct_codes.append(code)
                else:
                    incorrect_codes.append(code)

            problems.append(
                dict(
                    problem_id=problem_id,
                    instruction=instruction,
                    correct_codes=correct_codes,
                    incorrect_codes=incorrect_codes,
                )
            )
    return problems


def calculate_edit_distances_for_problem(problem):
    local_pairs = []
    for pair in product(problem["correct_codes"], problem["incorrect_codes"]):
        # transform incorrect code to correct code
        distance = edit_distance(pair[1], pair[0])
        local_pairs.append(
            (distance, problem["problem_id"], problem["instruction"], pair)
        )
    return local_pairs


def calculate_edit_distances(problems):
    cpu_num = multiprocessing.cpu_count()
    results = process_map(
        calculate_edit_distances_for_problem,
        problems,
        max_workers=cpu_num,
        chunksize=32,
    )
    return list(chain.from_iterable(results))


def mk_edit_distance_dataset(all_pairs, k, n, is_max=True):
    """
    Top-k pairs with the maximum/minimum edit distance.
    Each problem can contribute no more than n pairs.
    Each code snippet can be used only once.
    """
    all_pairs.sort(reverse=is_max, key=lambda x: x[0])

    code_usages = defaultdict(set)
    problem_contributions = defaultdict(int)
    preference_pairs, pairs_metadata = [], []

    for distance, problem_id, instr, pair in all_pairs:
        if len(preference_pairs) >= k:
            break

        is_code_used = (pair[0] in code_usages[problem_id]) or (
            pair[1] in code_usages[problem_id]
        )

        if not is_code_used and problem_contributions[problem_id] < n:
            code_usages[problem_id].update(pair)
            problem_contributions[problem_id] += 1
            preference_pairs.append(mk_preference_pair(instr, pair[0], pair[1]))
            pairs_metadata.append(dict(problem_id=problem_id, edit_distance=distance))

    return preference_pairs, pairs_metadata
```

(No newline at end of file.)
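For context, this deleted module was the earlier NLTK-based pair selector that select_preference_pairs.py appears to replace: it ranked correct/incorrect solution pairs by Levenshtein edit distance rather than rapidfuzz similarity. A sketch of how its pieces were presumably chained together follows; the file name and the k/n values are illustrative, not from the commit.

```python
# Illustrative wiring of the deleted module; the path and k/n values are made up.
problems = mk_problem_groups("train.jsonl", n=16)   # n sampled solutions per problem
all_pairs = calculate_edit_distances(problems)      # (distance, problem_id, instruction, (correct, incorrect))
# is_max=False keeps the k smallest-distance pairs in this example, capped at 4 per problem.
pairs, metadata = mk_edit_distance_dataset(all_pairs, k=1000, n=4, is_max=False)
```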