@@ -10,17 +10,19 @@ use cas_client::{CacheConfig, FileProvider, OutputProvider, CHUNK_CACHE_SIZE_BYT
10
10
use cas_object:: CompressionScheme ;
11
11
use deduplication:: DeduplicationMetrics ;
12
12
use dirs:: home_dir;
13
- use parutils:: { tokio_par_for_each, ParallelError } ;
14
13
use progress_tracking:: item_tracking:: ItemProgressUpdater ;
15
14
use progress_tracking:: TrackingProgressUpdater ;
16
15
use tracing:: { info, info_span, instrument, Instrument , Span } ;
17
16
use ulid:: Ulid ;
18
17
use utils:: auth:: { AuthConfig , TokenRefresher } ;
19
18
use utils:: normalized_path_from_user_string;
19
+ use xet_threadpool:: utils:: run_constrained_with_semaphore;
20
+ use xet_threadpool:: { global_semaphore_handle, GlobalSemaphoreHandle , ThreadPool } ;
20
21
21
22
use crate :: configurations:: * ;
22
- use crate :: constants:: { INGESTION_BLOCK_SIZE , MAX_CONCURRENT_DOWNLOADS , MAX_CONCURRENT_FILE_INGESTION } ;
23
+ use crate :: constants:: { INGESTION_BLOCK_SIZE , MAX_CONCURRENT_DOWNLOADS } ;
23
24
use crate :: errors:: DataProcessingError ;
25
+ use crate :: file_upload_session:: CONCURRENT_FILE_INGESTION_LIMITER ;
24
26
use crate :: { errors, FileDownloader , FileUploadSession , XetFileInfo } ;
25
27
26
28
utils:: configurable_constants! {
@@ -125,30 +127,22 @@ pub async fn upload_bytes_async(
125
127
let config = default_config ( endpoint. unwrap_or ( DEFAULT_CAS_ENDPOINT . clone ( ) ) , None , token_info, token_refresher) ?;
126
128
Span :: current ( ) . record ( "session_id" , & config. session_id ) ;
127
129
130
+ let semaphore = ThreadPool :: current ( ) . global_semaphore ( * CONCURRENT_FILE_INGESTION_LIMITER ) ;
128
131
let upload_session = FileUploadSession :: new ( config, progress_updater) . await ?;
129
- let blobs_with_spans = add_spans ( file_contents, || info_span ! ( "clean_task" ) ) ;
130
-
131
- // clean the bytes
132
- let files = tokio_par_for_each ( blobs_with_spans, * MAX_CONCURRENT_FILE_INGESTION , |( blob, span) , _| {
133
- async {
134
- let ( xf, _metrics) = clean_bytes ( upload_session. clone ( ) , blob) . await ?;
135
- Ok ( xf)
136
- }
137
- . instrument ( span. unwrap_or_else ( || info_span ! ( "unexpected_span" ) ) )
138
- } )
139
- . await
140
- . map_err ( |e| match e {
141
- ParallelError :: JoinError => DataProcessingError :: InternalError ( "Join error" . to_string ( ) ) ,
142
- ParallelError :: TaskError ( e) => e,
143
- } ) ?;
132
+ let clean_futures = file_contents. into_iter ( ) . map ( |blob| {
133
+ let upload_session = upload_session. clone ( ) ;
134
+ async move { clean_bytes ( upload_session, blob) . await . map ( |( xf, _metrics) | xf) }
135
+ . instrument ( info_span ! ( "clean_task" ) )
136
+ } ) ;
137
+ let files = run_constrained_with_semaphore ( clean_futures, semaphore) . await ?;
144
138
145
139
// Push the CAS blocks and flush the mdb to disk
146
140
let _metrics = upload_session. finalize ( ) . await ?;
147
141
148
142
Ok ( files)
149
143
}
150
144
151
- #[ instrument( skip_all, name = "data_client::upload_files" ,
145
+ #[ instrument( skip_all, name = "data_client::upload_files" ,
152
146
fields( session_id = tracing:: field:: Empty ,
153
147
num_files=file_paths. len( ) ,
154
148
new_bytes = tracing:: field:: Empty ,
@@ -157,7 +151,7 @@ pub async fn upload_bytes_async(
157
151
new_chunks = tracing:: field:: Empty ,
158
152
deduped_chunks = tracing:: field:: Empty ,
159
153
defrag_prevented_dedup_chunks = tracing:: field:: Empty
160
- ) ) ]
154
+ ) ) ]
161
155
pub async fn upload_async (
162
156
file_paths : Vec < String > ,
163
157
endpoint : Option < String > ,
@@ -201,6 +195,11 @@ pub async fn download_async(
201
195
token_refresher : Option < Arc < dyn TokenRefresher > > ,
202
196
progress_updaters : Option < Vec < Arc < dyn TrackingProgressUpdater > > > ,
203
197
) -> errors:: Result < Vec < String > > {
198
+ lazy_static ! {
199
+ static ref CONCURRENT_FILE_DOWNLOAD_LIMITER : GlobalSemaphoreHandle =
200
+ global_semaphore_handle!( * MAX_CONCURRENT_DOWNLOADS ) ;
201
+ }
202
+
204
203
if let Some ( updaters) = & progress_updaters {
205
204
if updaters. len ( ) != file_infos. len ( ) {
206
205
return Err ( DataProcessingError :: ParameterError (
@@ -212,30 +211,19 @@ pub async fn download_async(
212
211
default_config ( endpoint. unwrap_or ( DEFAULT_CAS_ENDPOINT . to_string ( ) ) , None , token_info, token_refresher) ?;
213
212
Span :: current ( ) . record ( "session_id" , & config. session_id ) ;
214
213
214
+ let processor = Arc :: new ( FileDownloader :: new ( config) . await ?) ;
215
215
let updaters = match progress_updaters {
216
216
None => vec ! [ None ; file_infos. len( ) ] ,
217
217
Some ( updaters) => updaters. into_iter ( ) . map ( Some ) . collect ( ) ,
218
218
} ;
219
- let file_with_progress = file_infos. into_iter ( ) . zip ( updaters) . collect :: < Vec < _ > > ( ) ;
220
- let extended_file_info_list = add_spans ( file_with_progress, || info_span ! ( "download_file" ) ) ;
221
-
222
- let processor = & Arc :: new ( FileDownloader :: new ( config) . await ?) ;
223
- let paths = tokio_par_for_each (
224
- extended_file_info_list,
225
- * MAX_CONCURRENT_DOWNLOADS ,
226
- |( ( ( file_info, file_path) , updater) , span) , _| {
227
- async move {
228
- let proc = processor. clone ( ) ;
229
- smudge_file ( & proc, & file_info, & file_path, updater) . await
230
- }
231
- . instrument ( span. unwrap_or_else ( || info_span ! ( "unexpected_span" ) ) )
232
- } ,
233
- )
234
- . await
235
- . map_err ( |e| match e {
236
- ParallelError :: JoinError => DataProcessingError :: InternalError ( "Join error" . to_string ( ) ) ,
237
- ParallelError :: TaskError ( e) => e,
238
- } ) ?;
219
+ let smudge_file_futures = file_infos. into_iter ( ) . zip ( updaters) . map ( |( ( file_info, file_path) , updater) | {
220
+ let proc = processor. clone ( ) ;
221
+ async move { smudge_file ( & proc, & file_info, & file_path, updater) . await } . instrument ( info_span ! ( "download_file" ) )
222
+ } ) ;
223
+
224
+ let semaphore = ThreadPool :: current ( ) . global_semaphore ( * CONCURRENT_FILE_DOWNLOAD_LIMITER ) ;
225
+
226
+ let paths = run_constrained_with_semaphore ( smudge_file_futures, semaphore) . await ?;
239
227
240
228
Ok ( paths)
241
229
}
@@ -298,23 +286,12 @@ async fn smudge_file(
298
286
Ok ( file_path. to_string ( ) )
299
287
}
300
288
301
- /// Adds spans to the indicated list for each element.
302
- ///
303
- /// Although a span will be added for each element, we need an Option<Span> since
304
- /// tokio_par_for_each requires the input list be Default, which Span isn't.
305
- pub fn add_spans < I , F : Fn ( ) -> Span > ( v : Vec < I > , create_span : F ) -> Vec < ( I , Option < Span > ) > {
306
- let spans: Vec < Option < Span > > = v. iter ( ) . map ( |_| Some ( create_span ( ) ) ) . collect ( ) ;
307
- v. into_iter ( ) . zip ( spans) . collect ( )
308
- }
309
-
310
289
#[ cfg( test) ]
311
290
mod tests {
312
291
use std:: env;
313
292
314
293
use serial_test:: serial;
315
294
use tempfile:: tempdir;
316
- use tracing:: info;
317
- use tracing_test:: traced_test;
318
295
319
296
use super :: * ;
320
297
@@ -402,41 +379,4 @@ mod tests {
402
379
"cache dir = {test_cache_dir:?}; does not start with {expected:?}" ,
403
380
) ;
404
381
}
405
-
406
- #[ tokio:: test( flavor = "multi_thread" ) ]
407
- #[ traced_test]
408
- async fn test_add_spans ( ) {
409
- let outer_span = info_span ! ( "outer_span" ) ;
410
- async {
411
- let v = vec ! [ "a" , "b" , "c" ] ;
412
- let expected_len = v. len ( ) ;
413
- let v_plus = add_spans ( v, || info_span ! ( "task_span" ) ) ;
414
- assert_eq ! ( v_plus. len( ) , expected_len) ;
415
- tokio_par_for_each ( v_plus, expected_len, |( s, span) , i| {
416
- async move {
417
- info ! ( "inside: {s},{i}" ) ;
418
- Ok :: < ( ) , ( ) > ( ( ) )
419
- }
420
- . instrument ( span. unwrap ( ) )
421
- } )
422
- . await
423
- . unwrap ( ) ;
424
- }
425
- . instrument ( outer_span)
426
- . await ;
427
-
428
- assert ! ( logs_contain( "inside: a,0" ) ) ;
429
- assert ! ( logs_contain( "inside: b,1" ) ) ;
430
- assert ! ( logs_contain( "inside: c,2" ) ) ;
431
- logs_assert ( |lines : & [ & str ] | {
432
- match lines
433
- . iter ( )
434
- . filter ( |line| line. contains ( "task_span" ) && line. contains ( "outer_span" ) )
435
- . count ( )
436
- {
437
- 3 => Ok ( ( ) ) ,
438
- n => Err ( format ! ( "Expected 3 lines, got {n}" ) ) ,
439
- }
440
- } ) ;
441
- }
442
382
}
0 commit comments