Unverified Commit 44a65f95 by Jiawei Liu Committed by GitHub

fix: slicing returns DataProto not DataProtoItem (#718)

parent db1d3251
......@@ -197,7 +197,8 @@ class DataProto:
def __getitem__(self, item):
# Index the tensor batch and every non-tensor entry with the same `item`.
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
# NOTE(review): the next line appears to be the pre-change side of this diff hunk (the +/- markers
# were lost in rendering). Read as plain code it returns first and leaves the two lines after it unreachable.
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
# Presumably the post-change side: a slice index yields a DataProto, any other index a DataProtoItem.
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self):
import io
......@@ -603,7 +604,7 @@ import ray
class DataProtoFuture:
"""
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
for data so that asynchronous execution becomes possible.
for data so that asynchronous execution becomes possible.
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
- collect_fn is a Callable that reduces the list of futures to a DataProto
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment