Commit c3b216bd by 苏舞仙

filter flags

Persist a per-prompt filter_flags dict with each checkpoint and use it to skip already-flagged prompts when assembling generation batches.

parent 014a39f4
@@ -313,6 +313,10 @@ class RayPPOTrainer(object):
        self._validate_config()
        self._create_dataloader()

        # effective dapo
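        # maps a sample's filter_key value -> 1 once the prompt is flagged and should be skipped in later batches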
        self.filter_flags = {}

    def _validate_config(self):
        config = self.config
        # number of GPUs total
@@ -811,6 +815,11 @@ class RayPPOTrainer(object):
        with open(local_latest_checkpointed_iteration, 'w') as f:
            f.write(str(self.global_steps))

        # save filter flags
        local_latest_filter_flags_path = os.path.join(local_global_step_folder, 'latest_filter_flags.json')
        with open(local_latest_filter_flags_path, 'w', encoding='utf-8') as f:
            json.dump(self.filter_flags, f)

    def _load_checkpoint(self):
        if self.config.trainer.resume_mode == 'disable':
            return 0
@@ -864,6 +873,15 @@ class RayPPOTrainer(object):
        else:
            print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")

        # load filter flags
        local_latest_filter_flags_path = os.path.join(global_step_folder, 'latest_filter_flags.json')
        try:
            with open(local_latest_filter_flags_path, 'r', encoding='utf-8') as f:
                self.filter_flags = json.load(f)
        except Exception as e:
            print(f'Failed to load filter flags from {local_latest_filter_flags_path}: {e}')
            self.filter_flags = {}

    def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
        """Reorder the data on single controller such that each dp rank gets similar total tokens"""
        attention_mask = batch.batch['attention_mask']
@@ -925,6 +943,26 @@ class RayPPOTrainer(object):
                metrics = {}

                new_batch: DataProto = DataProto.from_single_dict(batch_dict)

                if 'problem_id' in new_batch.non_tensor_batch.keys():
                    filter_key = 'problem_id'
                elif 'question' in new_batch.non_tensor_batch.keys():
                    filter_key = 'question'
                else:
                    filter_key = 'raw_prompt_ids'
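                # skip samples whose prompt was already flagged in an earlier step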
                select_idx = []
                for idx, item in enumerate(new_batch.non_tensor_batch[filter_key]):
                    if item in self.filter_flags and self.filter_flags[item] == 1:
                        continue
                    select_idx.append(idx)
                if len(select_idx) == 0:
                    # every sample in this batch is flagged, move on to the next batch and keep generating
                    continue
                new_batch = new_batch.select_idxs(select_idx)

                num_gen_batches += 1
                # pop those keys for generation
                if 'multi_modal_inputs' in new_batch.non_tensor_batch.keys():
@@ -1038,6 +1076,18 @@ class RayPPOTrainer(object):
                        if traj_from_prompt_uid in kept_prompt_uids:
                            kept_traj_idxs.append(idx)

                    # filter uids: flag prompts whose rollouts all received the same nonzero score
                    filter_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std == 0 and prompt_uid2metric_vals[uid][0] != 0]
                    filter_traj_idxs = []
                    for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
                        if traj_from_prompt_uid in filter_prompt_uids:
                            filter_traj_idxs.append(idx)
                    for idx in filter_traj_idxs:
                        print(new_batch.non_tensor_batch[filter_key][idx])
                        self.filter_flags[new_batch.non_tensor_batch[filter_key][idx]] = 1

                    new_batch = new_batch[kept_traj_idxs]

                if batch is None:
                    batch = new_batch
...
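Taken together, the hunks above persist a per-prompt filter_flags dict next to each checkpoint and use it to drop flagged prompts from later generation batches. A minimal standalone sketch of that round trip, assuming problem_id is the filter key; the path, the toy batch, and the variable names below are illustrative, not the trainer's actual ones:

import json
import os

# Toy stand-ins for the trainer state (illustrative only).
filter_flags = {}                                  # problem_id -> 1 once the prompt is flagged
checkpoint_dir = '/tmp/ckpt/global_step_10'        # hypothetical checkpoint folder

# Save: same shape as the _save_checkpoint addition above.
os.makedirs(checkpoint_dir, exist_ok=True)
flags_path = os.path.join(checkpoint_dir, 'latest_filter_flags.json')
with open(flags_path, 'w', encoding='utf-8') as f:
    json.dump(filter_flags, f)

# Load: mirrors the _load_checkpoint addition (any failure falls back to an empty dict).
try:
    with open(flags_path, 'r', encoding='utf-8') as f:
        filter_flags = json.load(f)
except Exception:
    filter_flags = {}

# Batch filtering: keep only samples whose problem_id has not been flagged.
batch_problem_ids = ['p1', 'p2', 'p3']             # hypothetical batch contents
filter_flags['p2'] = 1                             # pretend p2 was flagged in an earlier step
select_idx = [i for i, pid in enumerate(batch_problem_ids) if filter_flags.get(pid) != 1]
assert select_idx == [0, 2]

One detail the sketch makes visible: json.dump turns non-string keys into strings, so if problem_id values are integers the flags reloaded after a resume would be keyed by strings and stop matching the in-memory ids.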