Commit c3b216bd by 苏舞仙

filter flags

parent 014a39f4
......@@ -313,6 +313,10 @@ class RayPPOTrainer(object):
self._validate_config()
self._create_dataloader()
# effective dapo
self.filter_flags = {}
def _validate_config(self):
config = self.config
# number of GPUs total
......@@ -811,6 +815,11 @@ class RayPPOTrainer(object):
with open(local_latest_checkpointed_iteration, 'w') as f:
f.write(str(self.global_steps))
# save filter flags
local_latest_filter_flags_path = os.path.join(local_global_step_folder, 'latest_filter_flags.json')
with open(local_latest_filter_flags_path, 'w', encoding='utf-8') as f:
json.dump(self.filter_flags, f)
def _load_checkpoint(self):
if self.config.trainer.resume_mode == 'disable':
return 0
......@@ -864,6 +873,15 @@ class RayPPOTrainer(object):
else:
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
# load filter flags
local_latest_filter_flags_path = os.path.join(global_step_folder, 'latest_filter_flags.json')
try:
with open(local_latest_filter_flags_path, 'r', encoding='utf-8') as f:
self.filter_flags = json.load(f)
except Exception as e:
print(f'Failed to load filter flags from {local_latest_filter_flags_path}.')
self.filter_flags = {}
def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch['attention_mask']
......@@ -925,6 +943,26 @@ class RayPPOTrainer(object):
metrics = {}
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
if 'problem_id' in new_batch.non_tensor_batch.keys():
filter_key = 'problem_id'
elif 'question' in new_batch.batch.keys():
filter_key = 'question'
else:
filter_key = 'raw_prompt_ids'
select_idx = []
for id, item in enumerate(new_batch.non_tensor_batch[filter_key]):
if item in self.filter_flags and self.filter_flags[item] == 1:
continue
select_idx.append(id)
if len(select_idx) == 0:
# every sample in this batch has already been filtered out; move on to the next batch and keep generating
continue
new_batch = new_batch.select_idxs(select_idx)
num_gen_batches += 1
# pop those keys for generation
if 'multi_modal_inputs' in new_batch.non_tensor_batch.keys():
......@@ -1037,6 +1075,18 @@ class RayPPOTrainer(object):
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
# filter uids
filter_prompt_uids = [uid for uid, std in prompt_uid2metric_std.items() if std == 0 and prompt_uid2metric_vals[uid] != 0]
filter_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
if traj_from_prompt_uid in filter_prompt_uids:
filter_traj_idxs.append(idx)
for idx in filter_traj_idxs:
print(new_batch.non_tensor_batch[filter_key][idx])
self.filter_flags[new_batch.non_tensor_batch[filter_key][idx]] = 1
new_batch = new_batch[kept_traj_idxs]
if batch is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment