Commit 68a8256f by ZhangXiaoyun

fix bug

parent fa5864b0
......@@ -80,9 +80,9 @@ ulimit -u 4125556
cd /nfs_global/S/zhangxiaoyun/prm/openr
export PYTHONPATH=$(pwd)
bash reason/llm_service/create_service_qwen2.5_math_vllm_gold_prm_speed.sh --acc 1.0 --policy_model_name Qwen2.5-Math-7B-Instruct
bash reason/llm_service/create_service_qwen2.5_math_vllm_gold_prm_speed.sh --acc 1.0 --policy_model_name Qwen2.5-Math-1.5B-Instruct
sleep 30s
bash scripts/eval/beam_search.sh --acc 1.0 --policy_model_name Qwen2.5-Math-7B-Instruct
bash scripts/eval/beam_search.sh --acc 1.0 --policy_model_name Qwen2.5-Math-1.5B-Instruct
#- End
echo "Job end at $(date "+%Y-%m-%d %H:%M:%S")"
......@@ -83,7 +83,7 @@ def _qwen_math_gold_infer_fn(input_str: str, model, tokenizer, device, acc):
inputs = [steps[0]]
for i in range(1, len(steps)):
if not steps[i].isspace():
inputs.append(inputs[1] + steps[i])
inputs.append(inputs[-1] + steps[i])
steps_num = len(inputs)
# print("-----------------------------")
# print("steps:", steps)
......@@ -114,14 +114,14 @@ def _qwen_math_gold_infer_fn(input_str: str, model, tokenizer, device, acc):
step_score = 1 - step_score
step_scores.append(step_score)
inputs = inputs[len(step_scores):]
# 如果前序是0,那就不用再rollout了
if len(step_scores) != 0 and step_scores[-1] == 0:
step_scores.extend([0] * (len(inputs) - len(step_scores)))
assert steps_num == len(step_scores), f"{steps_num} != {len(step_scores)}"
return step_scores
inputs = inputs[len(step_scores):]
global lock
with lock:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment