Unverified commit 1ec5eb50 by Chi Zhang, committed by GitHub

[dataproto] fix: add assertion for uneven chunk (#115)

- Forbid uneven chunking of a DataProto: chunk(chunks) now asserts that len(data) is divisible by chunks
parent 5a94e14d
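
For context, a minimal sketch of the behavior this commit enforces. The import path and the from_dict call are assumed from the tests in the diff below and may differ in your checkout:

import torch

from verl import DataProto  # import path assumed

# A DataProto holding a batch of 6 items.
obs = torch.tensor([1, 2, 3, 4, 5, 6])
data = DataProto.from_dict(tensors={'obs': obs})

# An even split still works: 6 items into 2 chunks of 3.
chunks = data.chunk(2)
assert len(chunks) == 2 and len(chunks[0]) == 3

# An uneven split now fails fast instead of silently producing unequal chunks.
try:
    data.chunk(5)
except AssertionError as err:
    print(err)  # only support equal chunk. Got size of DataProto 6 and chunk 5.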
@@ -108,6 +108,9 @@ def test_chunk_concat():
     labels = ['a', 'b', 'c', 'd', 'e', 'f']
     data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'})
+    with pytest.raises(AssertionError):
+        data.chunk(5)
     data_split = data.chunk(2)
     assert len(data_split) == 2
     assert torch.all(torch.eq(data_split[0].batch['obs'], torch.tensor([1, 2, 3])))
@@ -237,3 +240,23 @@ def test_torch_save_data_proto():
     import os
     os.remove('test_data.pt')
+
+
+def test_len():
+    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
+    labels = np.array(['a', 'b', 'c'], dtype=object)
+    data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
+    assert len(data) == 3
+    data = DataProto(batch=None, non_tensor_batch={'labels': labels}, meta_info={'info': 'test_info'})
+    assert len(data) == 3
+    data = DataProto(batch=None, non_tensor_batch={}, meta_info={'info': 'test_info'})
+    assert len(data) == 0
+    data = DataProto(batch=None, non_tensor_batch=None, meta_info={'info': 'test_info'})
+    assert len(data) == 0
@@ -178,7 +178,13 @@ class DataProto:
         self.check_consistency()
 
     def __len__(self):
-        return self.batch.batch_size[0]
+        if self.batch is not None:
+            return self.batch.batch_size[0]
+        elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
+            random_key = list(self.non_tensor_batch.keys())[0]
+            return self.non_tensor_batch[random_key].shape[0]
+        else:
+            return 0
 
     def __getitem__(self, item):
         tensor_data = self.batch[item]
@@ -240,7 +246,11 @@ class DataProto:
         if self.batch is not None:
             assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'
 
-        if len(self.non_tensor_batch) != 0:
+        if self.non_tensor_batch is not None:
+            for key, val in self.non_tensor_batch.items():
+                assert isinstance(val, np.ndarray)
+
+        if self.batch is not None and len(self.non_tensor_batch) != 0:
             # TODO: we can actually lift this restriction if needed
             assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'
@@ -478,6 +488,9 @@ class DataProto:
         Returns:
             List[DataProto]: a list of DataProto after splitting
         """
+        assert len(
+            self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
         if self.batch is not None:
             batch_lst = self.batch.chunk(chunks=chunks, dim=0)
         else:
...
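
For reference, a small sketch of the __len__ fallback exercised by test_len above. The constructor arguments mirror the test; the import path is assumed and meta_info is left empty for brevity:

import numpy as np

from verl import DataProto  # import path assumed

labels = np.array(['a', 'b', 'c'], dtype=object)

# With no tensor batch, __len__ now falls back to the leading dimension of the
# first non-tensor array instead of failing on self.batch being None.
data = DataProto(batch=None, non_tensor_batch={'labels': labels}, meta_info={})
assert len(data) == 3

# With neither tensors nor non-tensor data, the length is defined as 0.
empty = DataProto(batch=None, non_tensor_batch=None, meta_info={})
assert len(empty) == 0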