Commit 708b9eee by ZhangXiaoyun

cancel rollout

parent d4c74075
@@ -98,8 +98,10 @@ def _qwen_math_gold_infer_fn(input_str: str, model, tokenizer, device, acc):
        # "stop_token_ids": [151643, 151643],
    }
    sampling_params = SamplingParams(**inference_params)
    step_scores = []
    # Fetch from Redis all step scores that have already been computed
    for i in range(len(inputs)):
        step_score = get_shared_value(inputs[i])
        if step_score is None:
@@ -112,6 +114,12 @@ def _qwen_math_gold_infer_fn(input_str: str, model, tokenizer, device, acc):
    inputs = inputs[len(step_scores):]
    # If an earlier step already scored 0, no further rollout is needed
    if len(step_scores) != 0 and step_scores[-1] == 0:
        step_scores.extend([0] * (len(inputs) - len(step_scores)))
        return step_scores
    global lock
    with lock:
        outputs = model.generate(inputs, sampling_params)
...
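Taken together, the new lines short-circuit the scoring path: cached step scores are pulled from Redis via get_shared_value, and if the last cached step already scored 0, the remaining steps are padded with zeros and the vLLM rollout through model.generate is skipped entirely. The following is a minimal sketch of that control flow, not the repository's code: score_with_model is a hypothetical stand-in for the generate-and-score logic the diff truncates, and the sketch simply pads every remaining uncached step with 0, whereas the diff computes the padding count against the already-sliced inputs.

```python
from typing import Callable, List, Optional


def collect_step_scores(
    inputs: List[str],
    get_shared_value: Callable[[str], Optional[float]],
    score_with_model: Callable[[List[str]], List[float]],
) -> List[float]:
    """Sketch of the cached-then-cancel rollout logic shown in the diff."""
    step_scores: List[float] = []

    # Reuse every step score that the Redis-backed cache already holds.
    for prompt in inputs:
        cached = get_shared_value(prompt)
        if cached is None:
            break  # first uncached step; everything after it would need a rollout
        step_scores.append(cached)

    # If the last cached step already scored 0, cancel the rollout:
    # later steps cannot improve the result, so pad with zeros and return early.
    if step_scores and step_scores[-1] == 0:
        step_scores.extend([0] * (len(inputs) - len(step_scores)))
        return step_scores

    # Otherwise, score only the uncached suffix with the model.
    remaining = inputs[len(step_scores):]
    if remaining:
        step_scores.extend(score_with_model(remaining))
    return step_scores
```

The early return is what the commit title refers to: whenever a zero-scored prefix makes the later steps moot, the expensive model.generate call is avoided altogether.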