Unverified Commit 44a65f95 by Jiawei Liu Committed by GitHub

fix: slicing returns DataProto not DataProtoItem (#718)

parent db1d3251
...@@ -197,7 +197,8 @@ class DataProto: ...@@ -197,7 +197,8 @@ class DataProto:
def __getitem__(self, item): def __getitem__(self, item):
tensor_data = self.batch[item] tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self): def __getstate__(self):
import io import io
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment