Commit 39b8569

parutils makeover: remove async_scoped (#454)

Removes the async_scoped dependency and adds a new parallel utility to replace tokio_par_for_each, which relied on async_scoped. Every usage of tokio_par_for_each (the only function used from parutils) has been replaced with the new `tokio_run_max_concurrency_fold_result_with_semaphore`.

TODO:
- [x] add more tests
- [x] use a semaphore acquired from the global semaphore provider where relevant
1 parent 48be7b0 commit 39b8569
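For context, here is a minimal sketch of the pattern such a utility implements: bounding task concurrency with a Tokio semaphore rather than async_scoped's scoped tasks. The `run_constrained` function below is a hypothetical stand-in, not the actual xet_threadpool implementation; the real utility also folds results and surfaces its own error type (`ParutilsError`, per the errors.rs diff below).

```rust
use std::sync::Arc;

use tokio::sync::Semaphore;
use tokio::task::JoinSet;

/// Spawn every future as a task, but make each task hold a semaphore permit
/// for its whole lifetime, so at most the semaphore's permit count make
/// progress at once.
async fn run_constrained<F, T, E>(
    futures: impl IntoIterator<Item = F>,
    semaphore: Arc<Semaphore>,
) -> Result<Vec<T>, E>
where
    F: std::future::Future<Output = Result<T, E>> + Send + 'static,
    T: Send + 'static,
    E: Send + 'static,
{
    let mut tasks = JoinSet::new();
    for fut in futures {
        let semaphore = semaphore.clone();
        tasks.spawn(async move {
            // Acquired before the payload runs, released when the task ends.
            let _permit = semaphore.acquire_owned().await.expect("semaphore closed");
            fut.await
        });
    }
    // Collect results, surfacing the first task error. Results arrive in
    // completion order here; the real utility may preserve input order.
    let mut results = Vec::new();
    while let Some(joined) = tasks.join_next().await {
        results.push(joined.expect("task panicked")?);
    }
    Ok(results)
}
```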

File tree

16 files changed, +344 -493 lines changed

Cargo.lock

Lines changed: 0 additions & 24 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 0 additions & 2 deletions
@@ -12,7 +12,6 @@ members = [
     "file_utils",
     "mdb_shard",
     "merklehash",
-    "parutils",
     "progress_tracking",
     "utils",
     "xet_threadpool",
@@ -32,7 +31,6 @@ debug = 1

 [workspace.dependencies]
 anyhow = "1"
-async-scoped = { version = "0.7", features = ["use-tokio"] }
 async-trait = "0.1"
 base64 = "0.22"
 bincode = "1.3"

README.md

Lines changed: 0 additions & 1 deletion
@@ -52,7 +52,6 @@ xet-core enables huggingface_hub to utilize xet storage for uploading and downlo
 * [hf_xet](./hf_xet): Python integration with Rust code, uses maturin to build hfxet Python package. Main integration with HF Hub Python package.
 * [mdb_shard](./mdb_shard): Shard operations, including Shard format, dedupe probing, benchmarks, and utilities.
 * [merklehash](./merklehash): MerkleHash type, 256-bit hash, widely used across many crates.
-* [parutils](./parutils): Provides parallel execution utilities relying on Tokio (ex. parallel foreach).
 * [progress_reporting](./progress_reporting): offers ReportedWriter so progress for Writer operations can be displayed.
 * [utils](./utils): general utilities, including singleflight, progress, serialization_utils and threadpool.

cas_client/src/remote_client.rs

Lines changed: 5 additions & 3 deletions
@@ -54,7 +54,7 @@ utils::configurable_constants! {
 }

 lazy_static! {
-    static ref DOWNLOAD_CONCURRENCY_LIMITER: GlobalSemaphoreHandle =
+    static ref DOWNLOAD_CHUNK_RANGE_CONCURRENCY_LIMITER: GlobalSemaphoreHandle =
         global_semaphore_handle!(*NUM_CONCURRENT_RANGE_GETS);
 }

@@ -329,7 +329,8 @@ impl RemoteClient {
         let download_scheduler = DownloadSegmentLengthTuner::from_configurable_constants();
         let download_scheduler_clone = download_scheduler.clone();

-        let download_concurrency_limiter = ThreadPool::current().global_semaphore(*DOWNLOAD_CONCURRENCY_LIMITER);
+        let download_concurrency_limiter =
+            ThreadPool::current().global_semaphore(*DOWNLOAD_CHUNK_RANGE_CONCURRENCY_LIMITER);

         let queue_dispatcher: JoinHandle<Result<()>> = tokio::spawn(async move {
             let mut remaining_total_len = total_len;
@@ -479,7 +480,8 @@ impl RemoteClient {
         let term_download_client = self.http_client_with_retry.clone();
         let download_scheduler = DownloadSegmentLengthTuner::from_configurable_constants();

-        let download_concurrency_limiter = ThreadPool::current().global_semaphore(*DOWNLOAD_CONCURRENCY_LIMITER);
+        let download_concurrency_limiter =
+            ThreadPool::current().global_semaphore(*DOWNLOAD_CHUNK_RANGE_CONCURRENCY_LIMITER);

         let process_result = move |result: TermDownloadResult<u64>,
                                    total_written: &mut u64,
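The `global_semaphore_handle!` / `ThreadPool::current().global_semaphore(...)` pair used here suggests a process-wide provider of named permit pools: the handle identifies the pool, and resolving it against the current thread pool yields the shared semaphore. A rough sketch of that provider pattern, where the names (`SemaphoreHandle`, `global_semaphore`, the registry) are illustrative assumptions rather than the real xet_threadpool API:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};

use tokio::sync::Semaphore;

/// Identifies one shared permit pool; illustrative stand-in for
/// xet_threadpool's GlobalSemaphoreHandle.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct SemaphoreHandle {
    id: u64,
    permits: usize,
}

static REGISTRY: OnceLock<Mutex<HashMap<SemaphoreHandle, Arc<Semaphore>>>> = OnceLock::new();

/// Return the shared semaphore for `handle`, creating it on first use so
/// every caller that resolves the same handle shares one permit pool.
fn global_semaphore(handle: SemaphoreHandle) -> Arc<Semaphore> {
    REGISTRY
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .expect("registry lock poisoned")
        .entry(handle)
        .or_insert_with(|| Arc::new(Semaphore::new(handle.permits)))
        .clone()
}
```

The rename to DOWNLOAD_CHUNK_RANGE_CONCURRENCY_LIMITER also makes explicit what the pool bounds: it is sized by NUM_CONCURRENT_RANGE_GETS, i.e. chunk-range GETs specifically.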

data/Cargo.toml

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ deduplication = { path = "../deduplication" }
 error_printer = { path = "../error_printer" }
 mdb_shard = { path = "../mdb_shard" }
 merklehash = { path = "../merklehash" }
-parutils = { path = "../parutils" }
 progress_tracking = { path = "../progress_tracking" }
 utils = { path = "../utils" }
 xet_threadpool = { path = "../xet_threadpool" }

data/src/data_client.rs

Lines changed: 27 additions & 87 deletions
@@ -10,17 +10,19 @@ use cas_client::{CacheConfig, FileProvider, OutputProvider, CHUNK_CACHE_SIZE_BYT
 use cas_object::CompressionScheme;
 use deduplication::DeduplicationMetrics;
 use dirs::home_dir;
-use parutils::{tokio_par_for_each, ParallelError};
 use progress_tracking::item_tracking::ItemProgressUpdater;
 use progress_tracking::TrackingProgressUpdater;
 use tracing::{info, info_span, instrument, Instrument, Span};
 use ulid::Ulid;
 use utils::auth::{AuthConfig, TokenRefresher};
 use utils::normalized_path_from_user_string;
+use xet_threadpool::utils::run_constrained_with_semaphore;
+use xet_threadpool::{global_semaphore_handle, GlobalSemaphoreHandle, ThreadPool};

 use crate::configurations::*;
-use crate::constants::{INGESTION_BLOCK_SIZE, MAX_CONCURRENT_DOWNLOADS, MAX_CONCURRENT_FILE_INGESTION};
+use crate::constants::{INGESTION_BLOCK_SIZE, MAX_CONCURRENT_DOWNLOADS};
 use crate::errors::DataProcessingError;
+use crate::file_upload_session::CONCURRENT_FILE_INGESTION_LIMITER;
 use crate::{errors, FileDownloader, FileUploadSession, XetFileInfo};

 utils::configurable_constants! {
@@ -125,30 +127,22 @@ pub async fn upload_bytes_async(
     let config = default_config(endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.clone()), None, token_info, token_refresher)?;
     Span::current().record("session_id", &config.session_id);

+    let semaphore = ThreadPool::current().global_semaphore(*CONCURRENT_FILE_INGESTION_LIMITER);
     let upload_session = FileUploadSession::new(config, progress_updater).await?;
-    let blobs_with_spans = add_spans(file_contents, || info_span!("clean_task"));
-
-    // clean the bytes
-    let files = tokio_par_for_each(blobs_with_spans, *MAX_CONCURRENT_FILE_INGESTION, |(blob, span), _| {
-        async {
-            let (xf, _metrics) = clean_bytes(upload_session.clone(), blob).await?;
-            Ok(xf)
-        }
-        .instrument(span.unwrap_or_else(|| info_span!("unexpected_span")))
-    })
-    .await
-    .map_err(|e| match e {
-        ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()),
-        ParallelError::TaskError(e) => e,
-    })?;
+    let clean_futures = file_contents.into_iter().map(|blob| {
+        let upload_session = upload_session.clone();
+        async move { clean_bytes(upload_session, blob).await.map(|(xf, _metrics)| xf) }
+            .instrument(info_span!("clean_task"))
+    });
+    let files = run_constrained_with_semaphore(clean_futures, semaphore).await?;

     // Push the CAS blocks and flush the mdb to disk
     let _metrics = upload_session.finalize().await?;

     Ok(files)
 }

-#[instrument(skip_all, name = "data_client::upload_files",
+#[instrument(skip_all, name = "data_client::upload_files",
     fields(session_id = tracing::field::Empty,
         num_files=file_paths.len(),
         new_bytes = tracing::field::Empty,
@@ -157,7 +151,7 @@ pub async fn upload_bytes_async(
         new_chunks = tracing::field::Empty,
         deduped_chunks = tracing::field::Empty,
         defrag_prevented_dedup_chunks = tracing::field::Empty
-))]
+))]
 pub async fn upload_async(
     file_paths: Vec<String>,
     endpoint: Option<String>,
@@ -201,6 +195,11 @@ pub async fn download_async(
     token_refresher: Option<Arc<dyn TokenRefresher>>,
     progress_updaters: Option<Vec<Arc<dyn TrackingProgressUpdater>>>,
 ) -> errors::Result<Vec<String>> {
+    lazy_static! {
+        static ref CONCURRENT_FILE_DOWNLOAD_LIMITER: GlobalSemaphoreHandle =
+            global_semaphore_handle!(*MAX_CONCURRENT_DOWNLOADS);
+    }
+
     if let Some(updaters) = &progress_updaters {
         if updaters.len() != file_infos.len() {
             return Err(DataProcessingError::ParameterError(
@@ -212,30 +211,19 @@ pub async fn download_async(
         default_config(endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), None, token_info, token_refresher)?;
     Span::current().record("session_id", &config.session_id);

+    let processor = Arc::new(FileDownloader::new(config).await?);
     let updaters = match progress_updaters {
         None => vec![None; file_infos.len()],
         Some(updaters) => updaters.into_iter().map(Some).collect(),
     };
-    let file_with_progress = file_infos.into_iter().zip(updaters).collect::<Vec<_>>();
-    let extended_file_info_list = add_spans(file_with_progress, || info_span!("download_file"));
-
-    let processor = &Arc::new(FileDownloader::new(config).await?);
-    let paths = tokio_par_for_each(
-        extended_file_info_list,
-        *MAX_CONCURRENT_DOWNLOADS,
-        |(((file_info, file_path), updater), span), _| {
-            async move {
-                let proc = processor.clone();
-                smudge_file(&proc, &file_info, &file_path, updater).await
-            }
-            .instrument(span.unwrap_or_else(|| info_span!("unexpected_span")))
-        },
-    )
-    .await
-    .map_err(|e| match e {
-        ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()),
-        ParallelError::TaskError(e) => e,
-    })?;
+    let smudge_file_futures = file_infos.into_iter().zip(updaters).map(|((file_info, file_path), updater)| {
+        let proc = processor.clone();
+        async move { smudge_file(&proc, &file_info, &file_path, updater).await }.instrument(info_span!("download_file"))
+    });
+
+    let semaphore = ThreadPool::current().global_semaphore(*CONCURRENT_FILE_DOWNLOAD_LIMITER);
+
+    let paths = run_constrained_with_semaphore(smudge_file_futures, semaphore).await?;

     Ok(paths)
 }
@@ -298,23 +286,12 @@ async fn smudge_file(
     Ok(file_path.to_string())
 }

-/// Adds spans to the indicated list for each element.
-///
-/// Although a span will be added for each element, we need an Option<Span> since
-/// tokio_par_for_each requires the input list be Default, which Span isn't.
-pub fn add_spans<I, F: Fn() -> Span>(v: Vec<I>, create_span: F) -> Vec<(I, Option<Span>)> {
-    let spans: Vec<Option<Span>> = v.iter().map(|_| Some(create_span())).collect();
-    v.into_iter().zip(spans).collect()
-}
-
 #[cfg(test)]
 mod tests {
     use std::env;

     use serial_test::serial;
     use tempfile::tempdir;
-    use tracing::info;
-    use tracing_test::traced_test;

     use super::*;
@@ -402,41 +379,4 @@ mod tests {
             "cache dir = {test_cache_dir:?}; does not start with {expected:?}",
         );
     }
-
-    #[tokio::test(flavor = "multi_thread")]
-    #[traced_test]
-    async fn test_add_spans() {
-        let outer_span = info_span!("outer_span");
-        async {
-            let v = vec!["a", "b", "c"];
-            let expected_len = v.len();
-            let v_plus = add_spans(v, || info_span!("task_span"));
-            assert_eq!(v_plus.len(), expected_len);
-            tokio_par_for_each(v_plus, expected_len, |(s, span), i| {
-                async move {
-                    info!("inside: {s},{i}");
-                    Ok::<(), ()>(())
-                }
-                .instrument(span.unwrap())
-            })
-            .await
-            .unwrap();
-        }
-        .instrument(outer_span)
-        .await;
-
-        assert!(logs_contain("inside: a,0"));
-        assert!(logs_contain("inside: b,1"));
-        assert!(logs_contain("inside: c,2"));
-        logs_assert(|lines: &[&str]| {
-            match lines
-                .iter()
-                .filter(|line| line.contains("task_span") && line.contains("outer_span"))
-                .count()
-            {
-                3 => Ok(()),
-                n => Err(format!("Expected 3 lines, got {n}")),
-            }
-        });
-    }
 }
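One detail worth calling out in this diff: `tokio_par_for_each` required its input `Vec`'s element type to implement `Default` (the reason `add_spans` wrapped each span in an `Option`), whereas the new runner consumes futures directly, so each future can simply own its tracing span via `Instrument`. A small self-contained sketch of that idiom, where `process` is a hypothetical stand-in for `clean_bytes`/`smudge_file`:

```rust
use tracing::{info_span, Instrument};

// Hypothetical stand-in for a real worker like clean_bytes or smudge_file.
async fn process(item: u32) -> Result<u32, String> {
    Ok(item * 2)
}

/// Build lazy, individually instrumented futures: nothing runs until the
/// constrained runner polls them, and no Option<Span> wrapper is needed.
fn instrumented_futures(
    items: Vec<u32>,
) -> impl Iterator<Item = impl std::future::Future<Output = Result<u32, String>>> {
    items
        .into_iter()
        .map(|item| process(item).instrument(info_span!("process_item", item)))
}
```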

data/src/errors.rs

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,7 @@ use thiserror::Error;
 use tokio::sync::AcquireError;
 use tracing::error;
 use utils::errors::{AuthError, SingleflightError};
+use xet_threadpool::utils::ParutilsError;

 #[derive(Error, Debug)]
 pub enum DataProcessingError {
@@ -91,3 +92,14 @@ impl From<SingleflightError<DataProcessingError>> for DataProcessingError {
         }
     }
 }
+
+impl From<ParutilsError<DataProcessingError>> for DataProcessingError {
+    fn from(value: ParutilsError<DataProcessingError>) -> Self {
+        match value {
+            ParutilsError::Join(e) => DataProcessingError::JoinError(e),
+            ParutilsError::Acquire(e) => DataProcessingError::PermitAcquisitionError(e),
+            ParutilsError::Task(e) => e,
+            e => DataProcessingError::InternalError(e.to_string()),
+        }
+    }
+}
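The catch-all arm above implies `ParutilsError` carries more variants than the three matched explicitly. A plausible shape for the enum, inferred from this `From` impl alone (the real definition lives in `xet_threadpool::utils` and may well differ, hence the sketch name):

```rust
use thiserror::Error;
use tokio::sync::AcquireError;
use tokio::task::JoinError;

// Inferred shape only; the real enum in xet_threadpool::utils likely has
// additional variants, which is why the From impl keeps a catch-all arm.
#[derive(Error, Debug)]
enum ParutilsErrorSketch<E: std::error::Error> {
    #[error("join error: {0}")]
    Join(#[from] JoinError),
    #[error("semaphore acquire error: {0}")]
    Acquire(#[from] AcquireError),
    #[error("task error: {0}")]
    Task(E),
}
```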
