Skip to content

Commit 3865e94

Browse files
authored
run_and_extract_custom: remove use of explicit tokio_retry without utility (#460)
We had 1 case of using "raw" tokio_retry rather than the retry utility. This was due to using special custom parsing logic for chunks, rather than the built-in JSON functionality. This PR adds a run_and_extract_custom that lets a user specify the function to parse the response body.
1 parent df1145f commit 3865e94

File tree

4 files changed

+131
-101
lines changed

4 files changed

+131
-101
lines changed

cas_client/src/download_utils.rs

Lines changed: 71 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,17 @@ use futures::TryStreamExt;
1313
use http::header::RANGE;
1414
use http::StatusCode;
1515
use merklehash::MerkleHash;
16+
use reqwest::Response;
1617
use reqwest_middleware::ClientWithMiddleware;
17-
use tokio_retry::strategy::ExponentialBackoff;
1818
use tracing::{debug, error, info, trace, warn};
1919
use url::Url;
2020
use utils::singleflight::Group;
2121

2222
use crate::error::{CasClientError, Result};
23-
use crate::http_client::{Api, BASE_RETRY_DELAY_MS, BASE_RETRY_MAX_DURATION_MS, NUM_RETRIES};
23+
use crate::http_client::Api;
2424
use crate::output_provider::OutputProvider;
2525
use crate::remote_client::{get_reconstruction_with_endpoint_and_client, PREFIX_DEFAULT};
26+
use crate::retry_wrapper::{RetryWrapper, RetryableReqwestError};
2627

2728
utils::configurable_constants! {
2829
// Env (HF_XET_NUM_RANGE_IN_SEGMENT_BASE) base value for the approx number of ranges in the initial
@@ -479,28 +480,6 @@ pub(crate) async fn get_one_fetch_term_data(
479480
Ok(term_download_output)
480481
}
481482

482-
struct ChunkRangeDeserializeFromBytesStreamRetryCondition;
483-
484-
impl tokio_retry::Condition<CasClientError> for ChunkRangeDeserializeFromBytesStreamRetryCondition {
485-
fn should_retry(&mut self, err: &CasClientError) -> bool {
486-
// we only care about retrying some error yielded by trying to deserialize the stream
487-
let CasClientError::CasObjectError(CasObjectError::InternalIOError(cas_object_io_err)) = err else {
488-
return false;
489-
};
490-
let Some(inner) = cas_object_io_err.get_ref() else {
491-
return false;
492-
};
493-
let Some(inner_reqwest_err) = inner.downcast_ref::<reqwest::Error>() else {
494-
return false;
495-
};
496-
// errors that indicate reading the body failed
497-
inner_reqwest_err.is_body()
498-
|| inner_reqwest_err.is_decode()
499-
|| inner_reqwest_err.is_timeout()
500-
|| inner_reqwest_err.is_request()
501-
}
502-
}
503-
504483
/// use the provided http_client to make requests to S3/blob store using the url and url_range
505484
/// parts of a CASReconstructionFetchInfo. The url_range part is used directly in a http Range header
506485
/// value (see fn `range_header`).
@@ -511,61 +490,83 @@ async fn download_fetch_term_data(
511490
) -> Result<DownloadRangeResult> {
512491
trace!("{hash},{},{}", fetch_term.range.start, fetch_term.range.end);
513492

493+
let api_tag = "s3::get_range";
514494
let url = Url::parse(fetch_term.url.as_str())?;
515495

516-
tokio_retry::RetryIf::spawn(
517-
ExponentialBackoff::from_millis(BASE_RETRY_DELAY_MS)
518-
.max_delay(Duration::from_millis(BASE_RETRY_MAX_DURATION_MS))
519-
.take(NUM_RETRIES as usize),
520-
|| async {
521-
let response = match http_client
522-
.get(url.clone())
523-
.header(RANGE, fetch_term.url_range.range_header())
524-
.with_extension(Api("s3::get_range"))
525-
.send()
526-
.await
527-
.map_err(CasClientError::from)
528-
.log_error("error downloading range")?
529-
.error_for_status()
530-
{
531-
Ok(response) => response,
532-
Err(e) => return match e.status() {
533-
Some(StatusCode::FORBIDDEN) => {
534-
info!("error code {} for hash {hash}, will re-fetch reconstruction", StatusCode::FORBIDDEN,);
535-
Ok(DownloadRangeResult::Forbidden)
536-
},
537-
_ => Err(e.into()),
538-
}
539-
.log_error("error code"),
540-
};
496+
// helper to convert a CasObjectError to RetryableReqwestError
497+
// only retryable if the error originates from an error from the byte stream from reqwest
498+
let parse_map_err = |err: CasObjectError| {
499+
let CasObjectError::InternalIOError(cas_object_io_err) = &err else {
500+
return RetryableReqwestError::FatalError(CasClientError::CasObjectError(err));
501+
};
502+
let Some(inner) = cas_object_io_err.get_ref() else {
503+
return RetryableReqwestError::FatalError(CasClientError::CasObjectError(err));
504+
};
505+
// attempt to cast into the reqwest error wrapped by std::io::Error::other
506+
let Some(inner_reqwest_err) = inner.downcast_ref::<reqwest::Error>() else {
507+
return RetryableReqwestError::FatalError(CasClientError::CasObjectError(err));
508+
};
509+
// errors that indicate reading the body failed
510+
if inner_reqwest_err.is_body()
511+
|| inner_reqwest_err.is_decode()
512+
|| inner_reqwest_err.is_timeout()
513+
|| inner_reqwest_err.is_request()
514+
{
515+
RetryableReqwestError::RetryableError(CasClientError::CasObjectError(err))
516+
} else {
517+
RetryableReqwestError::FatalError(CasClientError::CasObjectError(err))
518+
}
519+
};
541520

542-
if let Some(content_length) = response.content_length() {
543-
let expected_len = fetch_term.url_range.length();
544-
if content_length != expected_len {
545-
error!("got back a smaller byte range ({content_length}) than requested ({expected_len}) from s3");
546-
return Err(CasClientError::InvalidRange);
547-
}
521+
let parse = move |response: Response| async move {
522+
if let Some(content_length) = response.content_length() {
523+
let expected_len = fetch_term.url_range.length();
524+
if content_length != expected_len {
525+
error!("got back a smaller byte range ({content_length}) than requested ({expected_len}) from s3");
526+
return Err(RetryableReqwestError::FatalError(CasClientError::InvalidRange));
548527
}
528+
}
549529

550-
let (data, chunk_byte_indices) = cas_object::deserialize_async::deserialize_chunks_from_stream(
551-
response.bytes_stream().map_err(std::io::Error::other),
552-
)
553-
.await?;
554-
Ok(DownloadRangeResult::Data(TermDownloadOutput {
555-
data,
556-
chunk_byte_indices,
557-
chunk_range: fetch_term.range,
558-
}))
559-
},
560-
ChunkRangeDeserializeFromBytesStreamRetryCondition,
561-
)
562-
.await
530+
let (data, chunk_byte_indices) = cas_object::deserialize_async::deserialize_chunks_from_stream(
531+
response.bytes_stream().map_err(std::io::Error::other),
532+
)
533+
.await
534+
.map_err(parse_map_err)?;
535+
Ok(DownloadRangeResult::Data(TermDownloadOutput {
536+
data,
537+
chunk_byte_indices,
538+
chunk_range: fetch_term.range,
539+
}))
540+
};
541+
542+
let result = RetryWrapper::new(api_tag)
543+
.run_and_extract_custom(
544+
move || {
545+
http_client
546+
.get(url.clone())
547+
.header(RANGE, fetch_term.url_range.range_header())
548+
.with_extension(Api(api_tag))
549+
.send()
550+
},
551+
parse,
552+
)
553+
.await;
554+
// in case the error was a 403 Forbidden status code, we raise it up as the special DownloadRangeResult::Forbidden
555+
// variant so the fetch info is refetched
556+
if result
557+
.as_ref()
558+
.is_err_and(|e| e.status().is_some_and(|status| status == StatusCode::FORBIDDEN))
559+
{
560+
return Ok(DownloadRangeResult::Forbidden);
561+
}
562+
result
563563
}
564564

565565
#[cfg(test)]
566566
mod tests {
567567
use anyhow::Result;
568568
use cas_types::{HttpRange, QueryReconstructionResponse};
569+
use http::header::RANGE;
569570
use httpmock::prelude::*;
570571
use tokio::task::JoinSet;
571572
use tokio::time::sleep;
@@ -769,8 +770,8 @@ mod tests {
769770
// download task will not return if keep hitting 403
770771
handle.abort();
771772

772-
assert!(mock_fi.hits() >= 2);
773-
assert!(mock_data.hits() >= 2);
773+
assert!(mock_fi.hits() >= 2, "assertion failed: mock_fi.hits() {} >= 2", mock_fi.hits());
774+
assert!(mock_data.hits() >= 2, "assertion failed: mock_data.hits() {} >= 2", mock_data.hits());
774775

775776
Ok(())
776777
}

cas_client/src/error.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::fmt::Debug;
22
use std::num::TryFromIntError;
33

44
use anyhow::anyhow;
5+
use http::StatusCode;
56
use merklehash::MerkleHash;
67
use thiserror::Error;
78
use tokio::sync::mpsc::error::SendError;
@@ -47,12 +48,6 @@ pub enum CasClientError {
4748
#[error("Parse Error: {0}")]
4849
ParseError(#[from] url::ParseError),
4950

50-
#[error("Server Error: {0}")]
51-
ServerConnectionError(String),
52-
53-
#[error("Client Connection Error: {0}")]
54-
ClientConnectionError(String),
55-
5651
#[error("ReqwestMiddleware Error: {0}")]
5752
ReqwestMiddlewareError(#[from] reqwest_middleware::Error),
5853

@@ -87,6 +82,16 @@ impl CasClientError {
8782
pub fn internal<T: Debug>(value: T) -> Self {
8883
CasClientError::InternalError(anyhow!("{value:?}"))
8984
}
85+
86+
// if this error originates from a received http error code returns Some() with that code
87+
// otherwise None
88+
pub fn status(&self) -> Option<StatusCode> {
89+
match self {
90+
CasClientError::ReqwestMiddlewareError(e) => e.status(),
91+
CasClientError::ReqwestError(e, _) => e.status(),
92+
_ => None,
93+
}
94+
}
9095
}
9196

9297
// Define our own result type here (this seems to be the standard).

cas_client/src/remote_client.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,16 +243,16 @@ impl RemoteClient {
243243
let client = self.authenticated_http_client.clone();
244244
let api_tag = "cas::query_dedup";
245245

246-
let response = RetryWrapper::new(api_tag)
246+
let result = RetryWrapper::new(api_tag)
247247
.with_429_no_retry()
248248
.log_errors_as_info()
249249
.run(move || client.get(url.clone()).with_extension(Api(api_tag)).send())
250250
.await;
251251

252-
if matches!(response, Err(CasClientError::ServerConnectionError(_))) {
252+
if result.as_ref().is_err_and(|e| e.status().is_some()) {
253253
return Ok(None);
254254
}
255-
Ok(Some(response?))
255+
Ok(Some(result?))
256256
}
257257
}
258258

0 commit comments

Comments
 (0)