2
2
import os
3
3
import signal
4
4
import traceback
5
+ import types
5
6
from abc import ABC , abstractmethod
6
7
from enum import Enum
7
8
from multiprocessing import Process , Queue
@@ -166,7 +167,7 @@ def __init__(
166
167
start_index : int ,
167
168
dataset_name : str ,
168
169
node_rank : int ,
169
- prepare_item : Callable ,
170
+ dataset_optimizer : "DatasetOptimizer" ,
170
171
src_dir : str ,
171
172
remote_src_dir : str ,
172
173
remote_dst_dir : Optional [str ],
@@ -185,7 +186,7 @@ def __init__(
185
186
self .start_index = start_index
186
187
self .dataset_name = dataset_name
187
188
self .node_rank = node_rank
188
- self .prepare_item = prepare_item
189
+ self .prepare_item = dataset_optimizer . prepare_item
189
190
self .src_dir = src_dir
190
191
self .remote_src_dir = remote_src_dir
191
192
self .remote_dst_dir = remote_dst_dir
@@ -207,6 +208,7 @@ def __init__(
207
208
self .uploader : Optional [Process ] = None
208
209
self ._collected_items = 0
209
210
self ._counter = 0
211
+ self ._index_counter = 0
210
212
211
213
def run (self ) -> None :
212
214
try :
@@ -250,14 +252,17 @@ def _loop(self) -> None:
250
252
return
251
253
continue
252
254
253
- item_index = index + self .start_index
254
- item_data = self .prepare_item (self .items [index ]) if self .prepare_item else self .items [index ] # type: ignore
255
- chunk_filepath = self .cache ._add_item (item_index , item_data )
256
-
257
- self ._try_upload (chunk_filepath )
255
+ item_data_or_generator = self .prepare_item (self .items [index ]) if self .prepare_item else self .items [index ] # type: ignore
256
+ if isinstance (item_data_or_generator , types .GeneratorType ):
257
+ for item_data in item_data_or_generator :
258
+ chunk_filepath = self .cache ._add_item (self ._index_counter , item_data )
259
+ self ._try_upload (chunk_filepath )
260
+ self ._index_counter += 1
261
+ else :
262
+ chunk_filepath = self .cache ._add_item (index + self .start_index , item_data_or_generator )
263
+ self ._try_upload (chunk_filepath )
258
264
259
265
self ._counter += 1
260
-
261
266
if self .progress_queue :
262
267
self .progress_queue .put ((self .worker_index , self ._counter ))
263
268
@@ -623,7 +628,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
623
628
begins [worker_idx ],
624
629
self .name ,
625
630
_get_node_rank (),
626
- self . prepare_item ,
631
+ self ,
627
632
self .src_dir ,
628
633
self .remote_src_dir ,
629
634
self .remote_dst_dir ,
@@ -632,7 +637,9 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
632
637
self .error_queue ,
633
638
self .num_downloaders ,
634
639
self .delete_cached_files ,
635
- 2 if self .fast_dev_run else self .chunk_size , # In dev run, create chunks with 2 items
640
+ (self .chunk_size if self .chunk_size else 2 )
641
+ if self .fast_dev_run
642
+ else self .chunk_size , # In dev run, use chunk_size when set, otherwise fall back to chunks of 2 items
636
643
None if self .fast_dev_run else self .chunk_bytes ,
637
644
self .compression ,
638
645
)
@@ -657,7 +664,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
657
664
begins [worker_idx ],
658
665
self .name ,
659
666
_get_node_rank (),
660
- self . prepare_item ,
667
+ self ,
661
668
self .src_dir ,
662
669
self .remote_src_dir ,
663
670
self .remote_dst_dir ,
@@ -666,7 +673,9 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
666
673
self .error_queue ,
667
674
self .num_downloaders ,
668
675
self .delete_cached_files ,
669
- 2 if self .fast_dev_run else self .chunk_size , # In dev run, create chunks with 2 items
676
+ (self .chunk_size if self .chunk_size else 2 )
677
+ if self .fast_dev_run
678
+ else self .chunk_size , # In dev run, use chunk_size when set, otherwise fall back to chunks of 2 items
670
679
None if self .fast_dev_run else self .chunk_bytes ,
671
680
self .compression ,
672
681
)
0 commit comments