Unverified Commit 1ec5eb50 by Chi Zhang Committed by GitHub

[dataproto] fix: add assertion for uneven chunk (#115)

- forbid uneven chunk for DataProto
parent 5a94e14d
......@@ -108,6 +108,9 @@ def test_chunk_concat():
labels = ['a', 'b', 'c', 'd', 'e', 'f']
data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'})
with pytest.raises(AssertionError):
data.chunk(5)
data_split = data.chunk(2)
assert len(data_split) == 2
assert torch.all(torch.eq(data_split[0].batch['obs'], torch.tensor([1, 2, 3])))
......@@ -237,3 +240,23 @@ def test_torch_save_data_proto():
import os
os.remove('test_data.pt')
def test_len():
    """Check ``len()`` of a DataProto for tensor-backed, non-tensor-only and empty inputs."""
    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    labels = np.array(['a', 'b', 'c'], dtype=object)

    # Built from both tensors and non-tensors: three rows are expected.
    proto = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
    assert len(proto) == 3

    # Without a tensor batch: (non_tensor_batch, expected length) pairs,
    # checked in the same order as before.
    batchless_cases = (
        ({'labels': labels}, 3),
        ({}, 0),
        (None, 0),
    )
    for non_tensors, expected_len in batchless_cases:
        proto = DataProto(batch=None, non_tensor_batch=non_tensors, meta_info={'info': 'test_info'})
        assert len(proto) == expected_len
......@@ -178,7 +178,13 @@ class DataProto:
self.check_consistency()
def __len__(self):
    """Return the length of this object along its first dimension.

    Resolution order:
      1. If ``self.batch`` is not None, use ``self.batch.batch_size[0]``.
      2. Otherwise, if ``self.non_tensor_batch`` is non-empty, use
         ``shape[0]`` of its first entry (entries are expected to expose
         ``.shape``, e.g. numpy arrays).
      3. Otherwise return 0.
    """
    # NOTE: the former unconditional ``return self.batch.batch_size[0]`` is
    # gone; it raised AttributeError when ``batch`` is None and made the
    # fallbacks below unreachable.
    if self.batch is not None:
        return self.batch.batch_size[0]
    if self.non_tensor_batch:
        # Any entry will do here, so take the first one.
        first_key = next(iter(self.non_tensor_batch))
        return self.non_tensor_batch[first_key].shape[0]
    # Neither tensor nor non-tensor data is present.
    return 0
def __getitem__(self, item):
tensor_data = self.batch[item]
......@@ -240,7 +246,11 @@ class DataProto:
if self.batch is not None:
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'
if len(self.non_tensor_batch) != 0:
if self.non_tensor_batch is not None:
for key, val in self.non_tensor_batch.items():
assert isinstance(val, np.ndarray)
if self.batch is not None and len(self.non_tensor_batch) != 0:
# TODO: we can actually lift this restriction if needed
assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'
......@@ -478,6 +488,9 @@ class DataProto:
Returns:
List[DataProto]: a list of DataProto after splitting
"""
assert len(
self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
if self.batch is not None:
batch_lst = self.batch.chunk(chunks=chunks, dim=0)
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment