Unverified Commit ab525bce by Guangming Sheng Committed by GitHub

[misc]fix: pad dataproto when pad size is larger than len(dataproto) (#150)

- As titled
- Solved: #149 

Waiting for testing from @chujiezheng

---------

Co-authored-by: Chi Zhang <zhangchi.usc1992@bytedance.com>
parent 9fca71d2
......@@ -206,6 +206,20 @@ def test_dataproto_pad_unpad():
assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
assert unpadd_data.meta_info == {'info': 'test_info'}
padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7)
assert pad_size == 4
expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]])
expected_labels = ['a', 'b', 'c', 'a', 'b', 'c', 'a']
assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs))
assert (padded_data.non_tensor_batch['labels'] == expected_labels).all()
assert padded_data.meta_info == {'info': 'test_info'}
unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size)
assert torch.all(torch.eq(unpadd_data.batch['obs'], obs))
assert (unpadd_data.non_tensor_batch['labels'] == labels).all()
assert unpadd_data.meta_info == {'info': 'test_info'}
def test_dataproto_fold_unfold():
from verl.protocol import fold_batch_dim, unfold_batch_dim, DataProto
......
......@@ -51,7 +51,13 @@ def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
assert isinstance(data, DataProto), 'data must be a DataProto'
if len(data) % size_divisor != 0:
pad_size = size_divisor - len(data) % size_divisor
data_padded = DataProto.concat([data, data[:pad_size]])
padding_protos = []
remaining_pad = pad_size
while remaining_pad > 0:
take_size = min(remaining_pad, len(data))
padding_protos.append(data[:take_size])
remaining_pad -= take_size
data_padded = DataProto.concat([data] + padding_protos)
else:
pad_size = 0
data_padded = data
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment