Unverified commit 3165d988 by Lumeng Wu, committed by GitHub

fix: (1) skipped last step (2) redundant validation and logging (#409)

This PR fixes the following two problems.

1. Last step skipped

Incrementing `self.global_steps += 1` before the check `if self.global_steps >= self.total_training_steps` causes the last step to be skipped.

We start from step 1 and expect `self.total_training_steps` steps in total.


https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L999-L1001

   When `self.global_steps == self.total_training_steps - 1`:

   * only `self.total_training_steps - 1` steps have been executed
   * `self.global_steps` is incremented to `self.total_training_steps`
   * `self.global_steps >= self.total_training_steps` is satisfied, and training ends

   Therefore, `self.global_steps += 1` should be moved to the end of the loop body (see the sketch below).

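To make the off-by-one concrete, here is a minimal, self-contained sketch. The loop and counters are illustrative, not the trainer's actual code; only the increment/check ordering mirrors `fit()`:

```python
# Illustrative only: counts how many "training steps" actually execute.

def run_buggy(total_training_steps: int) -> int:
    """Increment before the check, as in the old fit(): the last step is skipped."""
    executed = 0
    global_steps = 1  # we start from step 1
    while True:
        executed += 1      # one training step
        global_steps += 1  # incremented BEFORE the check ...
        if global_steps >= total_training_steps:
            return executed  # ... so training stops one step early


def run_fixed(total_training_steps: int) -> int:
    """Check first, increment last, as in the new fit(): all steps execute."""
    executed = 0
    global_steps = 1  # we start from step 1
    while True:
        executed += 1  # one training step
        if global_steps >= total_training_steps:
            return executed
        global_steps += 1  # incremented at the end of the loop body


print(run_buggy(5))  # 4 -> last step skipped
print(run_fixed(5))  # 5 -> expected number of steps
```
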
2. Redundant validation and logging

If `self.total_training_steps % self.config.trainer.test_freq == 0`:

   * `self._validate()` will be executed twice:

     1. https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L984
     2. https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L1005

   * logging will also be executed twice:

     1. https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L985 and
        https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L997
     2. https://github.com/volcengine/verl/blob/82b38e25c72e1b6de7d7d2092af6e1ed5dd2a400/verl/trainer/ppo/ray_trainer.py#L1007

   A minimal sketch of the fix follows this list.
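A simplified sketch of how the `is_last_step` flag removes the duplicate work. The function `fit_sketch` and its counter are hypothetical stand-ins for the trainer's `fit()`, `self._validate()`, and `self.config.trainer.test_freq`:

```python
# Illustrative only: validation is gated once per iteration, and the last
# step reuses the metrics it just computed instead of validating again.

def fit_sketch(total_training_steps: int, test_freq: int) -> int:
    validation_calls = 0

    def validate() -> dict:
        nonlocal validation_calls
        validation_calls += 1
        return {'val/score': 0.0}  # dummy metrics

    global_steps = 1  # we start from step 1
    last_val_metrics = None
    while True:
        is_last_step = global_steps >= total_training_steps

        # ... one training step ...

        # validate every test_freq steps, and always on the last step
        if test_freq > 0 and (is_last_step or global_steps % test_freq == 0):
            val_metrics = validate()
            if is_last_step:
                last_val_metrics = val_metrics

        if is_last_step:
            # no second validate() here: the last result is reused
            print(f'Final validation metrics: {last_val_metrics}')
            return validation_calls

        global_steps += 1


# With total_training_steps % test_freq == 0, the old flow validated both at the
# periodic check and again after training; here it runs exactly 10 // 5 = 2 times.
print(fit_sketch(total_training_steps=10, test_freq=5))  # 2
```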
parent 0cc2bdad
@@ -53,6 +53,7 @@ def fit(self):
         # we start from step 1
         self.global_steps += 1
+        last_val_metrics = None

         for epoch in range(self.config.trainer.total_epochs):
             for batch_dict in self.train_dataloader:
@@ -63,6 +64,7 @@ def fit(self):
                 # pop those keys for generation
                 gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
+                is_last_step = self.global_steps >= self.total_training_steps

                 with _timer('step', timing_raw):
                     # generate a batch
@@ -168,13 +170,15 @@ def fit(self):
                     # validate
                     if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
-                        self.global_steps % self.config.trainer.test_freq == 0:
+                            (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                         with _timer('testing', timing_raw):
                             val_metrics: dict = self._validate()
+                        if is_last_step:
+                            last_val_metrics = val_metrics
                         metrics.update(val_metrics)

-                    if self.config.trainer.save_freq > 0 and \
-                            self.global_steps % self.config.trainer.save_freq == 0:
+                    if self.config.trainer.save_freq > 0 and (is_last_step or \
+                            self.global_steps % self.config.trainer.save_freq == 0):
                         with _timer('save_checkpoint', timing_raw):
                             self._save_checkpoint()
@@ -185,13 +189,8 @@ def fit(self):
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)

-                self.global_steps += 1
-
                 if self.global_steps >= self.total_training_steps:
-                    # perform validation after training
-                    if self.val_reward_fn is not None:
-                        val_metrics = self._validate()
-                        pprint(f'Final validation metrics: {val_metrics}')
-                        logger.log(data=val_metrics, step=self.global_steps)
+                    pprint(f'Final validation metrics: {last_val_metrics}')
                     return
+
+                self.global_steps += 1
 #!/bin/bash
 pip3 install --upgrade yapf
-python3 -m yapf -ir -vv --style ./.style.yapf verl tests single_controller examples
+python3 -m yapf -ir -vv --style ./.style.yapf verl tests examples
\ No newline at end of file
@@ -468,11 +468,11 @@ class FSDPSFTTrainer(object):
             for data in tqdm(self.train_dataloader,
                              total=self.steps_per_epoch,
                              desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"):
-                global_step += 1
                 data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
                 metric = self.training_step(data)
                 if rank == 0:
                     tracking.log(data=metric, step=global_step)
+                global_step += 1

                 # for early exit validation
                 if global_step >= self.total_training_steps:
@@ -878,6 +878,7 @@ class RayPPOTrainer(object):
         # we start from step 1
         self.global_steps += 1
+        last_val_metrics = None

         for epoch in range(self.config.trainer.total_epochs):
             for batch_dict in self.train_dataloader:
@@ -898,6 +899,8 @@ class RayPPOTrainer(object):
                     non_tensor_batch_keys=['raw_prompt_ids'],
                 )

+                is_last_step = self.global_steps >= self.total_training_steps
+
                 with _timer('step', timing_raw):
                     # generate a batch
                     with _timer('gen', timing_raw):
@@ -996,13 +999,15 @@ class RayPPOTrainer(object):
                     # validate
                     if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
-                        self.global_steps % self.config.trainer.test_freq == 0:
+                            (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                         with _timer('testing', timing_raw):
                             val_metrics: dict = self._validate()
+                        if is_last_step:
+                            last_val_metrics = val_metrics
                         metrics.update(val_metrics)

-                    if self.config.trainer.save_freq > 0 and \
-                            self.global_steps % self.config.trainer.save_freq == 0:
+                    if self.config.trainer.save_freq > 0 and (is_last_step or \
+                            self.global_steps % self.config.trainer.save_freq == 0):
                         with _timer('save_checkpoint', timing_raw):
                             self._save_checkpoint()
@@ -1018,17 +1023,8 @@ class RayPPOTrainer(object):
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)

-                self.global_steps += 1
-
-                if self.global_steps >= self.total_training_steps:
-                    # perform validation after training
-                    if self.val_reward_fn is not None:
-                        val_metrics = self._validate()
-                        pprint(f'Final validation metrics: {val_metrics}')
-                        logger.log(data=val_metrics, step=self.global_steps)
-
-                    if self.config.trainer.save_freq > 0 and \
-                            (self.global_steps - 1) % self.config.trainer.save_freq != 0:
-                        with _timer('save_checkpoint', timing_raw):
-                            self._save_checkpoint()
+                if is_last_step:
+                    pprint(f'Final validation metrics: {last_val_metrics}')
                     return
+
+                self.global_steps += 1