1
1
use crate :: { extract:: rejection:: * , response:: IntoResponseParts } ;
2
+ use axum_core:: extract:: OptionalFromRequestParts ;
2
3
use axum_core:: {
3
4
extract:: FromRequestParts ,
4
5
response:: { IntoResponse , Response , ResponseParts } ,
5
6
} ;
6
- use http:: { request:: Parts , Request } ;
7
+ use http:: { request:: Parts , Extensions , Request } ;
7
8
use std:: {
8
9
convert:: Infallible ,
9
10
task:: { Context , Poll } ,
@@ -43,7 +44,8 @@ use tower_service::Service;
43
44
/// ```
44
45
///
45
46
/// If the extension is missing it will reject the request with a `500 Internal
46
- /// Server Error` response.
47
+ /// Server Error` response. Alternatively, you can use `Option<Extension<T>>` to
48
+ /// make the extension extractor optional.
47
49
///
48
50
/// # As response
49
51
///
@@ -69,6 +71,15 @@ use tower_service::Service;
69
71
#[ must_use]
70
72
pub struct Extension < T > ( pub T ) ;
71
73
74
+ impl < T > Extension < T >
75
+ where
76
+ T : Clone + Send + Sync + ' static ,
77
+ {
78
+ fn from_extensions ( extensions : & Extensions ) -> Option < Self > {
79
+ extensions. get :: < T > ( ) . cloned ( ) . map ( Extension )
80
+ }
81
+ }
82
+
72
83
impl < T , S > FromRequestParts < S > for Extension < T >
73
84
where
74
85
T : Clone + Send + Sync + ' static ,
@@ -77,17 +88,27 @@ where
77
88
type Rejection = ExtensionRejection ;
78
89
79
90
async fn from_request_parts ( req : & mut Parts , _state : & S ) -> Result < Self , Self :: Rejection > {
80
- let value = req
81
- . extensions
82
- . get :: < T > ( )
83
- . ok_or_else ( || {
84
- MissingExtension :: from_err ( format ! (
85
- "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`." ,
86
- std:: any:: type_name:: <T >( )
87
- ) )
88
- } ) . cloned ( ) ?;
89
-
90
- Ok ( Extension ( value) )
91
+ Ok ( Self :: from_extensions ( & req. extensions ) . ok_or_else ( || {
92
+ MissingExtension :: from_err ( format ! (
93
+ "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`." ,
94
+ std:: any:: type_name:: <T >( )
95
+ ) )
96
+ } ) ?)
97
+ }
98
+ }
99
+
100
+ impl < T , S > OptionalFromRequestParts < S > for Extension < T >
101
+ where
102
+ T : Clone + Send + Sync + ' static ,
103
+ S : Send + Sync ,
104
+ {
105
+ type Rejection = Infallible ;
106
+
107
+ async fn from_request_parts (
108
+ req : & mut Parts ,
109
+ _state : & S ,
110
+ ) -> Result < Option < Self > , Self :: Rejection > {
111
+ Ok ( Self :: from_extensions ( & req. extensions ) )
91
112
}
92
113
}
93
114
@@ -161,3 +182,62 @@ where
161
182
self . inner . call ( req)
162
183
}
163
184
}
185
+
186
+ #[ cfg( test) ]
187
+ mod tests {
188
+ use super :: * ;
189
+ use crate :: routing:: get;
190
+ use crate :: test_helpers:: TestClient ;
191
+ use crate :: Router ;
192
+ use http:: StatusCode ;
193
+
194
+ #[ derive( Clone ) ]
195
+ struct Foo ( String ) ;
196
+
197
+ #[ derive( Clone ) ]
198
+ struct Bar ( String ) ;
199
+
200
+ #[ crate :: test]
201
+ async fn extension_extractor ( ) {
202
+ async fn requires_foo ( Extension ( foo) : Extension < Foo > ) -> String {
203
+ foo. 0
204
+ }
205
+
206
+ async fn optional_foo ( extension : Option < Extension < Foo > > ) -> String {
207
+ extension. map ( |foo| foo. 0 . 0 ) . unwrap_or ( "none" . to_owned ( ) )
208
+ }
209
+
210
+ async fn requires_bar ( Extension ( bar) : Extension < Bar > ) -> String {
211
+ bar. 0
212
+ }
213
+
214
+ async fn optional_bar ( extension : Option < Extension < Bar > > ) -> String {
215
+ extension. map ( |bar| bar. 0 . 0 ) . unwrap_or ( "none" . to_owned ( ) )
216
+ }
217
+
218
+ let app = Router :: new ( )
219
+ . route ( "/requires_foo" , get ( requires_foo) )
220
+ . route ( "/optional_foo" , get ( optional_foo) )
221
+ . route ( "/requires_bar" , get ( requires_bar) )
222
+ . route ( "/optional_bar" , get ( optional_bar) )
223
+ . layer ( Extension ( Foo ( "foo" . to_owned ( ) ) ) ) ;
224
+
225
+ let client = TestClient :: new ( app) ;
226
+
227
+ let response = client. get ( "/requires_foo" ) . await ;
228
+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
229
+ assert_eq ! ( response. text( ) . await , "foo" ) ;
230
+
231
+ let response = client. get ( "/optional_foo" ) . await ;
232
+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
233
+ assert_eq ! ( response. text( ) . await , "foo" ) ;
234
+
235
+ let response = client. get ( "/requires_bar" ) . await ;
236
+ assert_eq ! ( response. status( ) , StatusCode :: INTERNAL_SERVER_ERROR ) ;
237
+ assert_eq ! ( response. text( ) . await , "Missing request extension: Extension of type `axum::extension::tests::Bar` was not found. Perhaps you forgot to add it? See `axum::Extension`." ) ;
238
+
239
+ let response = client. get ( "/optional_bar" ) . await ;
240
+ assert_eq ! ( response. status( ) , StatusCode :: OK ) ;
241
+ assert_eq ! ( response. text( ) . await , "none" ) ;
242
+ }
243
+ }
0 commit comments