Unverified commit 3165d988 by Lumeng Wu, committed by GitHub

fix: (1) skipped last step (2) redundant validation and logging (#409)

This PR fixes the following two problems.

1. Last step skipped

`self.global_steps += 1` placed before the check `if self.global_steps >= self.total_training_steps` causes the last step to be skipped.

We start from step 1 and expect `self.total_training_steps` steps in total.


https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L999-L1001

   When `self.global_steps == self.total_training_steps - 1`:

   * only `self.total_training_steps - 1` steps have been executed so far
   * `self.global_steps` is incremented to `self.total_training_steps`
   * `self.global_steps >= self.total_training_steps` is now satisfied, so training ends one step early

   Therefore, `self.global_steps += 1` should be moved to the end of the loop body (a minimal sketch of the off-by-one follows).
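Here is a minimal, self-contained sketch of the off-by-one (toy step counts, not verl's actual trainer loop):

```python
# Toy reproduction of the bug: 3 expected steps, but only 2 execute
# when the increment happens before the termination check.
total_training_steps = 3

global_steps = 1  # we start from step 1
executed = []
while True:
    executed.append(global_steps)   # run step `global_steps`
    global_steps += 1               # increment BEFORE the check (old code)
    if global_steps >= total_training_steps:
        break
print(executed)  # [1, 2] -- the last step never runs

global_steps = 1
executed = []
while True:
    executed.append(global_steps)
    if global_steps >= total_training_steps:  # check BEFORE increment (fixed)
        break
    global_steps += 1
print(executed)  # [1, 2, 3]
```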

2. Redundant validation and logging

If `self.total_training_steps % self.config.trainer.test_freq == 0` (see the toy sketch after this list):

   * `self._validate()` will be executed twice:

1.
https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L984

2.
https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L1005

   * logging will also be executed twice:

1.
https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L985
and
https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L997
2.
https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L1007
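A toy sketch of the duplication (hypothetical values; `validate()` stands in for `self._validate()`):

```python
# With total_training_steps = 4 and test_freq = 2, the periodic check
# fires at step 4 AND the post-loop "final validation" fires at step 4,
# so the last step is validated (and logged) twice.
total_training_steps = 4
test_freq = 2

def validate():
    print('validated')

global_steps = 1
while True:
    if global_steps % test_freq == 0:        # in-loop periodic validation
        validate()                           # fires at steps 2 and 4
    if global_steps >= total_training_steps:
        validate()                           # final validation fires again at step 4
        break
    global_steps += 1
```

Running it prints `validated` three times: the last step is validated both by the periodic check and by the final-validation branch.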
```diff
@@ -53,6 +53,7 @@ def fit(self):
         # we start from step 1
         self.global_steps += 1
+        last_val_metrics = None
 
         for epoch in range(self.config.trainer.total_epochs):
             for batch_dict in self.train_dataloader:
@@ -63,6 +64,7 @@ def fit(self):
                 # pop those keys for generation
                 gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
+                is_last_step = self.global_steps >= self.total_training_steps
 
                 with _timer('step', timing_raw):
                     # generate a batch
@@ -168,13 +170,15 @@ def fit(self):
                 # validate
                 if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
-                        self.global_steps % self.config.trainer.test_freq == 0:
+                        (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                     with _timer('testing', timing_raw):
                         val_metrics: dict = self._validate()
+                    if is_last_step:
+                        last_val_metrics = val_metrics
                     metrics.update(val_metrics)
 
-                if self.config.trainer.save_freq > 0 and \
-                        self.global_steps % self.config.trainer.save_freq == 0:
+                if self.config.trainer.save_freq > 0 and (is_last_step or \
+                        self.global_steps % self.config.trainer.save_freq == 0):
                     with _timer('save_checkpoint', timing_raw):
                         self._save_checkpoint()
@@ -185,13 +189,8 @@ def fit(self):
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)
 
-                self.global_steps += 1
-
                 if self.global_steps >= self.total_training_steps:
-                    # perform validation after training
-                    if self.val_reward_fn is not None:
-                        val_metrics = self._validate()
-                        pprint(f'Final validation metrics: {val_metrics}')
-                        logger.log(data=val_metrics, step=self.global_steps)
+                    pprint(f'Final validation metrics: {last_val_metrics}')
                     return
 
+                self.global_steps += 1
```
```diff
 #!/bin/bash
 pip3 install --upgrade yapf
-python3 -m yapf -ir -vv --style ./.style.yapf verl tests single_controller examples
+python3 -m yapf -ir -vv --style ./.style.yapf verl tests examples
\ No newline at end of file
```
```diff
@@ -468,11 +468,11 @@ class FSDPSFTTrainer(object):
             for data in tqdm(self.train_dataloader,
                              total=self.steps_per_epoch,
                              desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"):
-                global_step += 1
                 data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
                 metric = self.training_step(data)
                 if rank == 0:
                     tracking.log(data=metric, step=global_step)
+                global_step += 1
 
                 # for early exit validation
                 if global_step >= self.total_training_steps:
```
```diff
@@ -878,6 +878,7 @@ class RayPPOTrainer(object):
         # we start from step 1
         self.global_steps += 1
+        last_val_metrics = None
 
         for epoch in range(self.config.trainer.total_epochs):
             for batch_dict in self.train_dataloader:
@@ -898,6 +899,8 @@ class RayPPOTrainer(object):
                     non_tensor_batch_keys=['raw_prompt_ids'],
                 )
+                is_last_step = self.global_steps >= self.total_training_steps
+
                 with _timer('step', timing_raw):
                     # generate a batch
                     with _timer('gen', timing_raw):
@@ -996,13 +999,15 @@ class RayPPOTrainer(object):
                 # validate
                 if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
-                        self.global_steps % self.config.trainer.test_freq == 0:
+                        (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                     with _timer('testing', timing_raw):
                         val_metrics: dict = self._validate()
+                    if is_last_step:
+                        last_val_metrics = val_metrics
                     metrics.update(val_metrics)
 
-                if self.config.trainer.save_freq > 0 and \
-                        self.global_steps % self.config.trainer.save_freq == 0:
+                if self.config.trainer.save_freq > 0 and (is_last_step or \
+                        self.global_steps % self.config.trainer.save_freq == 0):
                     with _timer('save_checkpoint', timing_raw):
                         self._save_checkpoint()
@@ -1018,17 +1023,8 @@ class RayPPOTrainer(object):
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)
 
-                self.global_steps += 1
-
-                if self.global_steps >= self.total_training_steps:
-                    # perform validation after training
-                    if self.val_reward_fn is not None:
-                        val_metrics = self._validate()
-                        pprint(f'Final validation metrics: {val_metrics}')
-                        logger.log(data=val_metrics, step=self.global_steps)
-                    if self.config.trainer.save_freq > 0 and \
-                            (self.global_steps - 1) % self.config.trainer.save_freq != 0:
-                        with _timer('save_checkpoint', timing_raw):
-                            self._save_checkpoint()
+                if is_last_step:
+                    pprint(f'Final validation metrics: {last_val_metrics}')
                     return
 
+                self.global_steps += 1
```
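Condensed into a runnable toy (hypothetical numbers and stub `validate`/`log` functions, not verl's real trainer), the fixed control flow behaves like this:

```python
from pprint import pprint

total_training_steps = 4
test_freq = 2

def validate():
    return {'val/reward': 1.0}  # stub for self._validate()

def log(data, step):
    print(f'step {step}: {data}')  # stub for logger.log()

global_steps = 1  # we start from step 1
last_val_metrics = None

while True:
    is_last_step = global_steps >= total_training_steps

    metrics = {'train/step': global_steps}  # ... training step ...

    # validation runs once per qualifying step, including the last one
    if is_last_step or global_steps % test_freq == 0:
        val_metrics = validate()
        if is_last_step:
            last_val_metrics = val_metrics
        metrics.update(val_metrics)

    log(data=metrics, step=global_steps)  # exactly one log per step

    if is_last_step:
        pprint(f'Final validation metrics: {last_val_metrics}')
        break

    global_steps += 1  # increment moved to the very end of the loop body
```

Every step from 1 through `total_training_steps` now executes, and the final step is validated and logged exactly once.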