Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
V
verl
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ZhangXiaoyun
verl
Commits
1ec5eb50
Unverified
Commit
1ec5eb50
authored
Jan 18, 2025
by
Chi Zhang
Committed by
GitHub
Jan 18, 2025
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[dataproto] fix: add assertion for uneven chunk (#115)
- forbid uneven chunk for DataProto
parent
5a94e14d
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
2 deletions
+38
-2
tests/utility/test_tensor_dict_utilities.py
+23
-0
verl/protocol.py
+15
-2
No files found.
tests/utility/test_tensor_dict_utilities.py
View file @
1ec5eb50
...
...
@@ -108,6 +108,9 @@ def test_chunk_concat():
labels
=
[
'a'
,
'b'
,
'c'
,
'd'
,
'e'
,
'f'
]
data
=
DataProto
.
from_dict
(
tensors
=
{
'obs'
:
obs
},
non_tensors
=
{
'labels'
:
labels
},
meta_info
=
{
'name'
:
'abdce'
})
with
pytest
.
raises
(
AssertionError
):
data
.
chunk
(
5
)
data_split
=
data
.
chunk
(
2
)
assert
len
(
data_split
)
==
2
assert
torch
.
all
(
torch
.
eq
(
data_split
[
0
]
.
batch
[
'obs'
],
torch
.
tensor
([
1
,
2
,
3
])))
...
...
@@ -237,3 +240,23 @@ def test_torch_save_data_proto():
import
os
os
.
remove
(
'test_data.pt'
)
def
test_len
():
obs
=
torch
.
tensor
([[
1
,
2
],
[
3
,
4
],
[
5
,
6
]])
labels
=
np
.
array
([
'a'
,
'b'
,
'c'
],
dtype
=
object
)
data
=
DataProto
.
from_dict
(
tensors
=
{
'obs'
:
obs
},
non_tensors
=
{
'labels'
:
labels
},
meta_info
=
{
'info'
:
'test_info'
})
assert
len
(
data
)
==
3
data
=
DataProto
(
batch
=
None
,
non_tensor_batch
=
{
'labels'
:
labels
},
meta_info
=
{
'info'
:
'test_info'
})
assert
len
(
data
)
==
3
data
=
DataProto
(
batch
=
None
,
non_tensor_batch
=
{},
meta_info
=
{
'info'
:
'test_info'
})
assert
len
(
data
)
==
0
data
=
DataProto
(
batch
=
None
,
non_tensor_batch
=
None
,
meta_info
=
{
'info'
:
'test_info'
})
assert
len
(
data
)
==
0
verl/protocol.py
View file @
1ec5eb50
...
...
@@ -178,7 +178,13 @@ class DataProto:
self
.
check_consistency
()
def
__len__
(
self
):
return
self
.
batch
.
batch_size
[
0
]
if
self
.
batch
is
not
None
:
return
self
.
batch
.
batch_size
[
0
]
elif
self
.
non_tensor_batch
is
not
None
and
len
(
self
.
non_tensor_batch
)
>
0
:
random_key
=
list
(
self
.
non_tensor_batch
.
keys
())[
0
]
return
self
.
non_tensor_batch
[
random_key
]
.
shape
[
0
]
else
:
return
0
def
__getitem__
(
self
,
item
):
tensor_data
=
self
.
batch
[
item
]
...
...
@@ -240,7 +246,11 @@ class DataProto:
if
self
.
batch
is
not
None
:
assert
len
(
self
.
batch
.
batch_size
)
==
1
,
'only support num_batch_dims=1'
if
len
(
self
.
non_tensor_batch
)
!=
0
:
if
self
.
non_tensor_batch
is
not
None
:
for
key
,
val
in
self
.
non_tensor_batch
.
items
():
assert
isinstance
(
val
,
np
.
ndarray
)
if
self
.
batch
is
not
None
and
len
(
self
.
non_tensor_batch
)
!=
0
:
# TODO: we can actually lift this restriction if needed
assert
len
(
self
.
batch
.
batch_size
)
==
1
,
'only support num_batch_dims=1 when non_tensor_batch is not empty.'
...
...
@@ -478,6 +488,9 @@ class DataProto:
Returns:
List[DataProto]: a list of DataProto after splitting
"""
assert
len
(
self
)
%
chunks
==
0
,
f
'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
if
self
.
batch
is
not
None
:
batch_lst
=
self
.
batch
.
chunk
(
chunks
=
chunks
,
dim
=
0
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment