Unverified Commit 44a65f95 by Jiawei Liu Committed by GitHub

fix: slicing returns DataProto not DataProtoItem (#718)

parent db1d3251
...@@ -197,7 +197,8 @@ class DataProto: ...@@ -197,7 +197,8 @@ class DataProto:
def __getitem__(self, item): def __getitem__(self, item):
tensor_data = self.batch[item] tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
def __getstate__(self): def __getstate__(self):
import io import io
...@@ -603,7 +604,7 @@ import ray ...@@ -603,7 +604,7 @@ import ray
class DataProtoFuture: class DataProtoFuture:
""" """
DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait
for data so that asynchronous execution becomes possible. for data so that asynchronous execution becomes possible.
DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
- collect_fn is a Callable that reduces the list of futures to a DataProto - collect_fn is a Callable that reduces the list of futures to a DataProto
- dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then select
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment