@@ -13,16 +13,17 @@ use futures::TryStreamExt;
13
13
use http:: header:: RANGE ;
14
14
use http:: StatusCode ;
15
15
use merklehash:: MerkleHash ;
16
+ use reqwest:: Response ;
16
17
use reqwest_middleware:: ClientWithMiddleware ;
17
- use tokio_retry:: strategy:: ExponentialBackoff ;
18
18
use tracing:: { debug, error, info, trace, warn} ;
19
19
use url:: Url ;
20
20
use utils:: singleflight:: Group ;
21
21
22
22
use crate :: error:: { CasClientError , Result } ;
23
- use crate :: http_client:: { Api , BASE_RETRY_DELAY_MS , BASE_RETRY_MAX_DURATION_MS , NUM_RETRIES } ;
23
+ use crate :: http_client:: Api ;
24
24
use crate :: output_provider:: OutputProvider ;
25
25
use crate :: remote_client:: { get_reconstruction_with_endpoint_and_client, PREFIX_DEFAULT } ;
26
+ use crate :: retry_wrapper:: { RetryWrapper , RetryableReqwestError } ;
26
27
27
28
utils:: configurable_constants! {
28
29
// Env (HF_XET_NUM_RANGE_IN_SEGMENT_BASE) base value for the approx number of ranges in the initial
@@ -479,28 +480,6 @@ pub(crate) async fn get_one_fetch_term_data(
479
480
Ok ( term_download_output)
480
481
}
481
482
482
- struct ChunkRangeDeserializeFromBytesStreamRetryCondition ;
483
-
484
- impl tokio_retry:: Condition < CasClientError > for ChunkRangeDeserializeFromBytesStreamRetryCondition {
485
- fn should_retry ( & mut self , err : & CasClientError ) -> bool {
486
- // we only care about retrying some error yielded by trying to deserialize the stream
487
- let CasClientError :: CasObjectError ( CasObjectError :: InternalIOError ( cas_object_io_err) ) = err else {
488
- return false ;
489
- } ;
490
- let Some ( inner) = cas_object_io_err. get_ref ( ) else {
491
- return false ;
492
- } ;
493
- let Some ( inner_reqwest_err) = inner. downcast_ref :: < reqwest:: Error > ( ) else {
494
- return false ;
495
- } ;
496
- // errors that indicate reading the body failed
497
- inner_reqwest_err. is_body ( )
498
- || inner_reqwest_err. is_decode ( )
499
- || inner_reqwest_err. is_timeout ( )
500
- || inner_reqwest_err. is_request ( )
501
- }
502
- }
503
-
504
483
/// use the provided http_client to make requests to S3/blob store using the url and url_range
505
484
/// parts of a CASReconstructionFetchInfo. The url_range part is used directly in a http Range header
506
485
/// value (see fn `range_header`).
@@ -511,61 +490,83 @@ async fn download_fetch_term_data(
511
490
) -> Result < DownloadRangeResult > {
512
491
trace ! ( "{hash},{},{}" , fetch_term. range. start, fetch_term. range. end) ;
513
492
493
+ let api_tag = "s3::get_range" ;
514
494
let url = Url :: parse ( fetch_term. url . as_str ( ) ) ?;
515
495
516
- tokio_retry:: RetryIf :: spawn (
517
- ExponentialBackoff :: from_millis ( BASE_RETRY_DELAY_MS )
518
- . max_delay ( Duration :: from_millis ( BASE_RETRY_MAX_DURATION_MS ) )
519
- . take ( NUM_RETRIES as usize ) ,
520
- || async {
521
- let response = match http_client
522
- . get ( url. clone ( ) )
523
- . header ( RANGE , fetch_term. url_range . range_header ( ) )
524
- . with_extension ( Api ( "s3::get_range" ) )
525
- . send ( )
526
- . await
527
- . map_err ( CasClientError :: from)
528
- . log_error ( "error downloading range" ) ?
529
- . error_for_status ( )
530
- {
531
- Ok ( response) => response,
532
- Err ( e) => return match e. status ( ) {
533
- Some ( StatusCode :: FORBIDDEN ) => {
534
- info ! ( "error code {} for hash {hash}, will re-fetch reconstruction" , StatusCode :: FORBIDDEN , ) ;
535
- Ok ( DownloadRangeResult :: Forbidden )
536
- } ,
537
- _ => Err ( e. into ( ) ) ,
538
- }
539
- . log_error ( "error code" ) ,
540
- } ;
496
+ // helper to convert a CasObjectError to RetryableReqwestError
497
+ // only retryable if the error originates from an error from the byte stream from reqwest
498
+ let parse_map_err = |err : CasObjectError | {
499
+ let CasObjectError :: InternalIOError ( cas_object_io_err) = & err else {
500
+ return RetryableReqwestError :: FatalError ( CasClientError :: CasObjectError ( err) ) ;
501
+ } ;
502
+ let Some ( inner) = cas_object_io_err. get_ref ( ) else {
503
+ return RetryableReqwestError :: FatalError ( CasClientError :: CasObjectError ( err) ) ;
504
+ } ;
505
+ // attempt to cast into the reqwest error wrapped by std::io::Error::other
506
+ let Some ( inner_reqwest_err) = inner. downcast_ref :: < reqwest:: Error > ( ) else {
507
+ return RetryableReqwestError :: FatalError ( CasClientError :: CasObjectError ( err) ) ;
508
+ } ;
509
+ // errors that indicate reading the body failed
510
+ if inner_reqwest_err. is_body ( )
511
+ || inner_reqwest_err. is_decode ( )
512
+ || inner_reqwest_err. is_timeout ( )
513
+ || inner_reqwest_err. is_request ( )
514
+ {
515
+ RetryableReqwestError :: RetryableError ( CasClientError :: CasObjectError ( err) )
516
+ } else {
517
+ RetryableReqwestError :: FatalError ( CasClientError :: CasObjectError ( err) )
518
+ }
519
+ } ;
541
520
542
- if let Some ( content_length ) = response . content_length ( ) {
543
- let expected_len = fetch_term . url_range . length ( ) ;
544
- if content_length != expected_len {
545
- error ! ( "got back a smaller byte range ({content_length}) than requested ({ expected_len}) from s3" ) ;
546
- return Err ( CasClientError :: InvalidRange ) ;
547
- }
521
+ let parse = move | response : Response | async move {
522
+ if let Some ( content_length ) = response . content_length ( ) {
523
+ let expected_len = fetch_term . url_range . length ( ) ;
524
+ if content_length != expected_len {
525
+ error ! ( "got back a smaller byte range ({content_length}) than requested ({expected_len}) from s3" ) ;
526
+ return Err ( RetryableReqwestError :: FatalError ( CasClientError :: InvalidRange ) ) ;
548
527
}
528
+ }
549
529
550
- let ( data, chunk_byte_indices) = cas_object:: deserialize_async:: deserialize_chunks_from_stream (
551
- response. bytes_stream ( ) . map_err ( std:: io:: Error :: other) ,
552
- )
553
- . await ?;
554
- Ok ( DownloadRangeResult :: Data ( TermDownloadOutput {
555
- data,
556
- chunk_byte_indices,
557
- chunk_range : fetch_term. range ,
558
- } ) )
559
- } ,
560
- ChunkRangeDeserializeFromBytesStreamRetryCondition ,
561
- )
562
- . await
530
+ let ( data, chunk_byte_indices) = cas_object:: deserialize_async:: deserialize_chunks_from_stream (
531
+ response. bytes_stream ( ) . map_err ( std:: io:: Error :: other) ,
532
+ )
533
+ . await
534
+ . map_err ( parse_map_err) ?;
535
+ Ok ( DownloadRangeResult :: Data ( TermDownloadOutput {
536
+ data,
537
+ chunk_byte_indices,
538
+ chunk_range : fetch_term. range ,
539
+ } ) )
540
+ } ;
541
+
542
+ let result = RetryWrapper :: new ( api_tag)
543
+ . run_and_extract_custom (
544
+ move || {
545
+ http_client
546
+ . get ( url. clone ( ) )
547
+ . header ( RANGE , fetch_term. url_range . range_header ( ) )
548
+ . with_extension ( Api ( api_tag) )
549
+ . send ( )
550
+ } ,
551
+ parse,
552
+ )
553
+ . await ;
554
+ // in case the error was a 403 Forbidden status code, we raise it up as the special DownloadRangeResult::Forbidden
555
+ // variant so the fetch info is refetched
556
+ if result
557
+ . as_ref ( )
558
+ . is_err_and ( |e| e. status ( ) . is_some_and ( |status| status == StatusCode :: FORBIDDEN ) )
559
+ {
560
+ return Ok ( DownloadRangeResult :: Forbidden ) ;
561
+ }
562
+ result
563
563
}
564
564
565
565
#[ cfg( test) ]
566
566
mod tests {
567
567
use anyhow:: Result ;
568
568
use cas_types:: { HttpRange , QueryReconstructionResponse } ;
569
+ use http:: header:: RANGE ;
569
570
use httpmock:: prelude:: * ;
570
571
use tokio:: task:: JoinSet ;
571
572
use tokio:: time:: sleep;
@@ -769,8 +770,8 @@ mod tests {
769
770
// download task will not return if keep hitting 403
770
771
handle. abort ( ) ;
771
772
772
- assert ! ( mock_fi. hits( ) >= 2 ) ;
773
- assert ! ( mock_data. hits( ) >= 2 ) ;
773
+ assert ! ( mock_fi. hits( ) >= 2 , "assertion failed: mock_fi.hits() {} >= 2" , mock_fi . hits ( ) ) ;
774
+ assert ! ( mock_data. hits( ) >= 2 , "assertion failed: mock_data.hits() {} >= 2" , mock_data . hits ( ) ) ;
774
775
775
776
Ok ( ( ) )
776
777
}
0 commit comments