Commit ec19c45c by Yaoyu Zhu

fix some bugs in reward computation and oversized log output; some problems still remain

parent 037861b1
@@ -106,7 +106,7 @@ ssh -L 8266:localhost:49931 zhuyaoyu@r8l40-a02 -N
It turned out to be stuck inside `single_compute_score` in `verl/workers/reward_manager/prime.py`, presumably because the pickled payload is too large (the largest is around `29M`); the bulk comes from a few competitive-programming problems whose IO holds tens to hundreds of thousands of test entries.
**Solution**: at first (without really understanding why it worked) we put a pickle+print call in front of `single_compute_score` as a buffer, which made the error go away. Later we found a somewhat more elegant approach: subclass `ProcessPoolExecutor` so that timed-out tasks are terminated promptly. This seems to work better, showing up as higher accuracy than before.
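The idea behind the subclassed executor, as a minimal hedged sketch (`run_with_deadline` is a hypothetical helper; the actual implementation in this commit is the `TrackableProcessPoolExecutor` further down):

```python
# Minimal sketch, assuming psutil is available: run a task with a deadline,
# and on timeout kill the pool's worker processes so a stuck task (e.g. an
# infinite loop) cannot starve the pool.
import asyncio
import functools
from concurrent.futures import ProcessPoolExecutor

import psutil


async def run_with_deadline(executor: ProcessPoolExecutor, fn, *args, timeout: float = 30.0):
    loop = asyncio.get_running_loop()
    try:
        return await asyncio.wait_for(
            loop.run_in_executor(executor, functools.partial(fn, *args)),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # `_processes` is a CPython-internal {pid: Process} dict; killing every
        # worker is blunt but reclaims a starved pool.
        for pid in list(getattr(executor, "_processes", None) or {}):
            try:
                psutil.Process(pid).kill()
            except psutil.NoSuchProcess:
                pass
        return None
```

Killing the whole pool on one timeout is blunt; the class below instead tracks per-future worker PIDs so only the stuck worker dies.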
#### are_equal_under_sympy timeouts
@@ -116,9 +116,8 @@ ssh -L 8266:localhost:49931 zhuyaoyu@r8l40-a02 -N
While running we hit a huge blob of log output starting with `"[ASSESS]\n\n`. **The procedure for locating which code emitted a given chunk of log is as follows:**
1. First redirect the sandbox's stdout and stderr to separate files to confirm the blob comes from inside verl (see the sketch right after this list).
2. Override the logging and print functions so that log lines carry file names and line numbers. The same override can be applied inside SandboxFusion as well.
3. This pinpointed **the `print(sequences_str)` call in the `num_examine` printing path of `verl/workers/reward_manager/prime.py`** as the source of the blob. In addition, the line `logger.error(f"Case {case_index}: input: {str(stdin_data)}")` in `verl/utils/reward_score/sandbox_fusion/utils.py` also prints enormous output when the IO is large. Changing these two fixes it.
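A hedged sketch of step 1's redirection (the launch command here is purely hypothetical):

```python
# Hedged sketch: start the sandbox with stdout and stderr split into separate
# files, so it is clear which stream a log blob arrives on. The module path
# "sandbox.server" is a placeholder for the real launch command.
import subprocess

with open("sandbox.out", "w") as out, open("sandbox.err", "w") as err:
    proc = subprocess.Popen(
        ["python", "-m", "sandbox.server"],
        stdout=out,
        stderr=err,
    )
```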
The code for the overrides is as follows:
@@ -192,6 +191,8 @@ builtins.print = traced_print
#### Finally, the (presumably) normal initial validation accuracy
**Warning: this result may be skewed low by the sympy timeouts!!!**
```bash
Initial validation metrics: {'val-aux/numina_synthetic_math/reward/mean@1':
0.19642857142857142, 'val-core/numina_synthetic_math/reward/mean@2': 0.5,
......
```
@@ -14,9 +14,14 @@
import logging
import sys
import structlog
import os
os.environ["VERL_LIGHT_INIT"] = "1"
import verl
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING)  # set the global log level
from sandbox.configs.run_config import RunConfig
config = RunConfig.get_instance_sync()
......
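One caveat about the `VERL_LIGHT_INIT` switch above, shown as a hedged sketch: Python runs a package's `__init__.py` only on the first import, so the flag is silently ignored if anything in the process has already imported verl.

```python
# Hedged sketch: the env var must be set before the *first* `import verl`
# anywhere in the process, because __init__.py only executes once.
import os
import sys

assert "verl" not in sys.modules, "verl already imported; VERL_LIGHT_INIT would have no effect"
os.environ["VERL_LIGHT_INIT"] = "1"

import verl  # the heavy imports guarded in verl/__init__.py are now skipped
```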
@@ -19,40 +19,129 @@ import pkg_resources
from packaging.version import parse as parse_version
from pkg_resources import DistributionNotFound
if os.environ.get("VERL_LIGHT_INIT", "0") != "1":
    from .protocol import DataProto
    from .utils.device import is_npu_available
    from .utils.logging_utils import set_basic_config

    version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

    with open(os.path.join(version_folder, "version/version")) as f:
        __version__ = f.read().strip()

    set_basic_config(level=logging.WARNING)

    __all__ = ["DataProto", "__version__"]

    if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true":
        import importlib

        if importlib.util.find_spec("modelscope") is None:
            raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`")
        # Patch hub to download models from modelscope to speed up.
        from modelscope.utils.hf_util import patch_hub

        patch_hub()

    if is_npu_available:
        package_name = "transformers"
        required_version_spec = "4.51.0"
        try:
            installed_version = pkg_resources.get_distribution(package_name).version
            installed = parse_version(installed_version)
            required = parse_version(required_version_spec)

            if not installed >= required:
                raise ValueError(f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.")
        except DistributionNotFound as e:
            raise ImportError(f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}") from e
# put this at program entry (ideally before you create ProcessPoolExecutor)
import pickle, sys, os, inspect, logging, pprint, builtins, traceback


def remove_useless_path(path):
    ...  # body elided in this diff view


def setup_enhanced_logging():
    # 1. Enhance the log-record factory to attach file name and line number info
    original_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = original_factory(*args, **kwargs)
        # Walk back up the stack to the actual business code (skipping framework frames)
        frame = inspect.currentframe()
        depth = 0
        while frame and depth < 15:  # generous depth bound to get past framework code
            frame = frame.f_back
            depth += 1
            if frame and "/site-packages/" not in (frame.f_code.co_filename or "") and "/lib/python" not in (frame.f_code.co_filename or ""):
                # First caller outside site-packages, i.e. business code
                # record.filename = os.path.basename(frame.f_code.co_filename)
                record.filename = frame.f_code.co_filename
                record.lineno = frame.f_lineno
                return record
        # Fall back to the default location if nothing was found
        return record

    logging.setLogRecordFactory(record_factory)

    # 2. Configure a log format that includes file name and line number
    formatter = logging.Formatter(
        '%(asctime)s [%(levelname)s] [LOGGING @ %(filename)s:%(lineno)d] - %(message)s'
    )

    # 3. Apply it to every existing log handler
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)  # global log level

    # Swap the formatter on all handlers
    for handler in root_logger.handlers[:]:
        # Drop handlers that are not plain StreamHandlers (e.g. Ray's, when identifiable)
        if not isinstance(handler, logging.StreamHandler):
            root_logger.removeHandler(handler)
            continue
        handler.setFormatter(formatter)

    # Add a handler if none is present
    if not any(isinstance(h, logging.StreamHandler) for h in root_logger.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)

    # 4. Disable Ray's coloring and log prefixes (optional)
    os.environ["RAY_DISABLE_COLOR"] = "1"
    os.environ["RAY_LOG_PREFIX"] = "0"


# Run the setup on import
setup_enhanced_logging()

# Optional: also wrap print() (covers libraries that call print directly)
import builtins

original_print = builtins.print


def traced_print(*args, **kwargs):
    # Report the immediate caller's file and line number
    frame = inspect.currentframe().f_back
    depth = 0
    while frame and depth < 15:
        # filename = os.path.basename(frame.f_code.co_filename)
        filename = frame.f_code.co_filename
        lineno = frame.f_lineno
        if depth == 0:
            original_print(f"[PRINT @ {filename}:{lineno} DEPTH={depth}]", *args, **kwargs)
            break
        frame = frame.f_back
        depth += 1


builtins.print = traced_print

import faulthandler, signal

# Enable faulthandler
faulthandler.enable(file=sys.stderr, all_threads=True)
# On SIGUSR1, dump every thread's stack to stderr
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)
\ No newline at end of file
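With the `faulthandler.register(signal.SIGUSR1, ...)` line in place, a hung run can be asked for a stack dump from the outside; a minimal usage sketch (the PID is hypothetical):

```python
# Minimal sketch: poke a (possibly hung) trainer process from another Python
# shell; faulthandler then dumps every thread's stack to that process's stderr.
import os
import signal

trainer_pid = 12345  # hypothetical PID, e.g. taken from `ps` or `ray status`
os.kill(trainer_pid, signal.SIGUSR1)
```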
@@ -21,97 +21,6 @@ import ray
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.reward import load_reward_manager
# put this at program entry (ideally before you create ProcessPoolExecutor)
import pickle, sys, os, inspect, logging, pprint, builtins, traceback
import verl.utils.reward_score.prime_math as pm

# logging.basicConfig(
#     level=os.getenv("VERL_LOGGING_LEVEL", "INFO"),
#     format="%(asctime)s %(levelname)s %(processName)s: %(message)s",
#     datefmt="%F %T",
#     handlers=[logging.StreamHandler(sys.stdout)],
#     force=True,
# )


def setup_enhanced_logging():
    # 1. Enhance the log-record factory to attach file name and line number info
    original_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = original_factory(*args, **kwargs)
        # Walk back up the stack to the actual business code (skipping framework frames)
        frame = inspect.currentframe()
        depth = 0
        while frame and depth < 10:  # generous depth bound to get past framework code
            frame = frame.f_back
            depth += 1
            if frame and "site-packages" not in (frame.f_code.co_filename or ""):
                # First caller outside site-packages, i.e. business code
                record.filename = os.path.basename(frame.f_code.co_filename)
                record.lineno = frame.f_lineno
                return record
        # Fall back to the default location if nothing was found
        return record

    logging.setLogRecordFactory(record_factory)

    # 2. Configure a log format that includes file name and line number
    formatter = logging.Formatter(
        '%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s'
    )

    # 3. Apply it to every existing log handler
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)  # global log level

    # Swap the formatter on all handlers
    for handler in root_logger.handlers[:]:
        # Drop handlers that are not plain StreamHandlers (e.g. Ray's, when identifiable)
        if not isinstance(handler, logging.StreamHandler):
            root_logger.removeHandler(handler)
            continue
        handler.setFormatter(formatter)

    # Add a handler if none is present
    if not any(isinstance(h, logging.StreamHandler) for h in root_logger.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(formatter)
        root_logger.addHandler(handler)

    # 4. Disable Ray's coloring and log prefixes (optional)
    os.environ["RAY_DISABLE_COLOR"] = "1"
    os.environ["RAY_LOG_PREFIX"] = "0"


# Run the setup on import
setup_enhanced_logging()

# Optional: also wrap print() (covers libraries that call print directly)
import builtins

original_print = builtins.print


def traced_print(*args, **kwargs):
    # Report the immediate caller's file and line number
    frame = inspect.currentframe().f_back
    depth = 0
    while frame and depth < 15:
        filename = os.path.basename(frame.f_code.co_filename)
        lineno = frame.f_lineno
        if depth == 0:
            original_print(f"[PRINT @ {filename}:{lineno} DEPTH={depth}]", *args, **kwargs)
            break
        frame = frame.f_back
        depth += 1


builtins.print = traced_print

import faulthandler, signal

# Enable faulthandler
faulthandler.enable(file=sys.stderr, all_threads=True)
# On SIGUSR1, dump every thread's stack to stderr
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
......
@@ -273,7 +273,8 @@ if __name__ == '__main__':
# Log code and input only on error for brevity
generation_to_log = generation[:200] + "..." if len(generation) > 200 else generation
logger.error(f"Case {case_index}: code: {generation_to_log}")
# logger.error(f"Case {case_index}: input: {str(stdin_data)}")
logger.error(f"Case {case_index}: input: {str(stdin_data)[:200]}")
elif api_response:
# --- Add debug logging ---
logger.debug(f"Case {case_index}: API Response: {api_response}")
......
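The same cap-the-payload pattern recurs at several log sites; a hedged sketch of a shared helper (`truncate_for_log` is a hypothetical name, not part of this commit):

```python
# Hedged sketch: one helper so every log call site truncates unbounded
# payloads consistently instead of repeating slice literals inline.
def truncate_for_log(obj, limit: int = 200) -> str:
    s = str(obj)
    if len(s) <= limit:
        return s
    return s[:limit] + f"... [{len(s) - limit} chars truncated]"


# e.g. logger.error(f"Case {case_index}: input: {truncate_for_log(stdin_data)}")
```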
@@ -29,49 +29,9 @@ import threading, pickle
from multiprocessing.managers import BaseProxy
import tempfile, os
# class TrackableProcessPoolExecutor(ProcessPoolExecutor):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         # Track: Future -> the worker process ID handling it
#         self.future_to_pid = {}
#         # Thread lock to avoid concurrent modification
#         self.lock = threading.Lock()

#     def submit(self, fn, *args, **kwargs):
#         # Inspect the current worker processes when submitting (via internals)
#         # Note: this relies on CPython internals and may change across versions
#         with self.lock:
#             # Submit the task and get its Future
#             future = super().submit(fn, *args, **kwargs)
#             # executor._processes is an internal dict {pid: process}
#             if hasattr(self, '_processes') and self._processes:
#                 # Naive mapping: assume the newest process takes the new task
#                 # (scheduling is actually round-robin; this is a simplification)
#                 pid = next(reversed(self._processes.keys()))
#                 self.future_to_pid[future] = pid
#             return future

#     def terminate_future_worker(self, future):
#         """Hardened version: use psutil to make sure the process really dies"""
#         with self.lock:
#             pid = self.future_to_pid.get(future)
#             if pid:
#                 # Try to terminate the process and all of its children via psutil
#                 try:
#                     process = psutil.Process(pid)
#                     # Terminate all child processes
#                     for child in process.children(recursive=True):
#                         child.terminate()
#                     # Terminate the main process
#                     process.terminate()
#                     # Wait for it to exit
#                     process.wait(timeout=1.0)
#                     print(f"[Debug] Terminated worker process {pid} and its children")
#                 except (psutil.NoSuchProcess, psutil.TimeoutExpired) as e:
#                     print(f"[Warning] Failed to terminate process {pid}: {e}")
#                 finally:
#                     # Clean up the mapping
#                     self.future_to_pid.pop(future, None)
from multiprocessing import Manager
import queue
import psutil
# def save_large_object_to_file(obj):
@@ -151,13 +111,133 @@ import tempfile, os
# print(f"[Warning] Failed to delete temp file: {e}")
def _pid_capturing_wrapper(fn, args, kwargs, future_id, pid_queue):
    """Runs in the worker process: report its PID through the queue, then run the task."""
    pid = psutil.Process().pid
    pid_queue.put((future_id, pid))
    return fn(*args, **kwargs)
class TrackableProcessPoolExecutor(ProcessPoolExecutor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.manager = Manager()
        self.pid_queue = self.manager.Queue()  # cross-process queue
        self.stop_event = threading.Event()
        self.async_future_to_pid = {}   # async_future -> pid
        self.concurrent_to_async = {}   # concurrent future -> async_future
        self.lock = threading.Lock()
        self._wrapped_by_asyncio = False
        self._start_listener_thread()

    def _start_listener_thread(self):
        """Listener thread: pull (future_id, pid) off the queue and update the mapping."""
        def listener():
            while not self.stop_event.is_set():
                try:
                    future_id, pid = self.pid_queue.get(block=True, timeout=1)
                    with self.lock:
                        # Match the corresponding async_future by identity
                        for async_fut in list(self.async_future_to_pid.keys()):
                            if id(async_fut) == future_id:
                                self.async_future_to_pid[async_fut] = pid
                                break
                    self.pid_queue.task_done()
                except queue.Empty:
                    continue  # ignore timeouts
                except Exception:
                    continue
            # thread exits here

        self.listener_thread = threading.Thread(target=listener)
        self.listener_thread.start()

    def _set_wrapped_by_asyncio(self):
        self._wrapped_by_asyncio = True

    def _submit_async(self, fn, args, kwargs):
        """Submit a task asynchronously, associating a future_id with the async_future."""
        loop = asyncio.get_event_loop()
        async_future = loop.create_future()
        future_id = id(async_future)  # use the async_future's id as the key

        # Submit the task to the process pool
        concur_future = super().submit(
            _pid_capturing_wrapper,
            fn, args, kwargs,
            future_id=future_id,
            pid_queue=self.pid_queue
        )

        # Record the mapping
        with self.lock:
            self.concurrent_to_async[concur_future] = async_future
            self.async_future_to_pid[async_future] = None  # PID placeholder

        # Callback that mirrors the concurrent future's state onto the asyncio one
        def callback(f):
            if async_future.done():
                return
            try:
                result = f.result()
                if not async_future.done():
                    loop.call_soon_threadsafe(async_future.set_result, result)
            except Exception as exc:
                # Propagate failures; silently dropping them would leave the
                # awaiting coroutine pending forever.
                if not async_future.done():
                    loop.call_soon_threadsafe(async_future.set_exception, exc)

        concur_future.add_done_callback(callback)
        return async_future

    def submit(self, fn, *args, **kwargs):
        if self._wrapped_by_asyncio:
            return self._submit_async(fn, args, kwargs)
        else:
            return super().submit(fn, *args, **kwargs)

    def terminate_future_worker(self, future):
        """Terminate the worker process associated with `future`."""
        with self.lock:
            pid = self.async_future_to_pid.get(future)
        if not pid:
            return False
        # Carry out the termination
        try:
            process = psutil.Process(pid)
            # Terminate children first, then the worker itself
            for child in process.children(recursive=True):
                child.terminate()
            process.terminate()
            process.wait(timeout=0.5)
            # Cancel the associated futures (reacquire the lock only for the lookup)
            with self.lock:
                concur_future = next((cf for cf, af in self.concurrent_to_async.items() if af == future), None)
            if concur_future:
                concur_future.cancel()
            if not future.done():
                future.cancel()
            return True
        except Exception:
            return False

    def shutdown(self, wait=True):
        self.stop_event.set()
        super().shutdown(wait=wait)
        self.listener_thread.join(timeout=0.5)
        self.manager.shutdown()
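How the class is meant to be driven, as a hedged sketch (`score_with_deadline` is a hypothetical wrapper; the real call sites are `single_compute_score` and `parallel_compute_score_async` below):

```python
# Hedged usage sketch: route submissions through the asyncio path, enforce a
# deadline, and kill only the stuck worker instead of the whole pool.
import asyncio


async def score_with_deadline(executor, evaluation_func, *args, timeout=30.0):
    executor._set_wrapped_by_asyncio()            # submit() now returns asyncio futures
    future = executor.submit(evaluation_func, *args)
    try:
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        executor.terminate_future_worker(future)  # reclaim the hung worker process
        return None
```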
# Note: in the sandbox scenario each task here just submits an API request, so batching requests should in principle beat this per-sample approach
async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=30.0):
    loop = asyncio.get_running_loop()
    # sometimes (completion, reference, task, task_extra_info) is large, especially a `reference` containing large-size IO
    payload_size = len(pickle.dumps((task, completion, reference, task_extra_info)))
    # print(f"[single_compute_score] eval={evaluation_func} pickle={payload_size} bytes", end="\t")
    try:
        # Ensure process_completion is called properly
        future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info))
@@ -176,8 +256,10 @@ async def parallel_compute_score_async(evaluation_func, completions, references,
    if extra_info is None:
        extra_info = [None] * len(tasks)
    scores = []
    print('In parallel_compute_score_async!!!')
    # with ProcessPoolExecutor(max_workers=num_processes) as executor:
    with TrackableProcessPoolExecutor(max_workers=num_processes) as executor:
        executor._set_wrapped_by_asyncio()
        # to prevent the occasional starvation caused by anomalous programs (e.g. infinite loops), any exception in the async tasks instantly halts the evaluation, and all spawned processes are killed.
        try:
            # Create tasks for all rows
@@ -186,23 +268,10 @@ async def parallel_compute_score_async(evaluation_func, completions, references,
        except Exception as e:
            print(f"[Exception] async gather failed: {e}")
            raise
        finally:
            terminated_count = 0
            for pid, proc in executor._processes.items():
                try:
                    p = psutil.Process(pid)
                    p.terminate()
                    try:
                        p.wait(timeout=5)
                    except psutil.TimeoutExpired:
                        p.kill()
                    terminated_count += 1
                except Exception:
                    pass
            print(f"[Shutdown] {terminated_count} subprocess(es) terminated.")

    # Process results
    for result, completion, reference, task in zip(results, completions, references, tasks):
        print('reward score is', result)
        if isinstance(result, Exception) or result is None:
            # Handle failed or timed-out tasks
            scores.append(0.0)
......
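A hedged driver sketch for the function above (signature and return value inferred from this diff; `dummy_eval` is a hypothetical scorer, defined at module level so it is picklable):

```python
# Hedged sketch: score a small batch; rows whose worker timed out or crashed
# fall back to 0.0 inside parallel_compute_score_async.
import asyncio


def dummy_eval(task, completion, reference, extra_info):  # hypothetical scorer
    return 1.0 if completion.strip() == reference.strip() else 0.0


scores = asyncio.run(
    parallel_compute_score_async(
        dummy_eval,
        completions=["42", "7"],
        references=["42", "8"],
        tasks=["math", "math"],
        extra_info=None,
        num_processes=2,
    )
)
```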