1use std::{fmt, ops, sync::Arc};
4
5use actix_utils::future::{ok, ready, Ready};
6use serde::de::DeserializeOwned;
7
8use crate::{dev::Payload, error::QueryPayloadError, Error, FromRequest, HttpRequest};
9
/// Extractor that deserializes a request's URL-encoded query string into `T`.
///
/// `Query` is a thin tuple wrapper: the deserialized value lives in the public field `.0`
/// and can also be taken out with `into_inner`. Equality, ordering, cloning and `Debug`
/// are derived, so they follow the wrapped `T`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
)]
pub struct Query<T>(pub T);
62
63impl<T> Query<T> {
64 pub fn into_inner(self) -> T {
66 self.0
67 }
68}
69
70impl<T: DeserializeOwned> Query<T> {
71 pub fn from_query(query_str: &str) -> Result<Self, QueryPayloadError> {
82 serde_urlencoded::from_str::<T>(query_str)
83 .map(Self)
84 .map_err(QueryPayloadError::Deserialize)
85 }
86}
87
88impl<T> ops::Deref for Query<T> {
89 type Target = T;
90
91 fn deref(&self) -> &T {
92 &self.0
93 }
94}
95
96impl<T> ops::DerefMut for Query<T> {
97 fn deref_mut(&mut self) -> &mut T {
98 &mut self.0
99 }
100}
101
102impl<T: fmt::Display> fmt::Display for Query<T> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 self.0.fmt(f)
105 }
106}
107
108impl<T: DeserializeOwned> FromRequest for Query<T> {
110 type Error = Error;
111 type Future = Ready<Result<Self, Error>>;
112
113 #[inline]
114 fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
115 let error_handler = req
116 .app_data::<QueryConfig>()
117 .and_then(|c| c.err_handler.clone());
118
119 serde_urlencoded::from_str::<T>(req.query_string())
120 .map(|val| ok(Query(val)))
121 .unwrap_or_else(move |err| {
122 let err = QueryPayloadError::Deserialize(err);
123
124 log::debug!(
125 "Failed during Query extractor deserialization. \
126 Request path: {:?}",
127 req.path()
128 );
129
130 let err = if let Some(error_handler) = error_handler {
131 (error_handler)(err, req)
132 } else {
133 err.into()
134 };
135
136 ready(Err(err))
137 })
138 }
139}
140
141#[derive(Clone, Default)]
171pub struct QueryConfig {
172 #[allow(clippy::type_complexity)]
173 err_handler: Option<Arc<dyn Fn(QueryPayloadError, &HttpRequest) -> Error + Send + Sync>>,
174}
175
176impl QueryConfig {
177 pub fn error_handler<F>(mut self, f: F) -> Self
179 where
180 F: Fn(QueryPayloadError, &HttpRequest) -> Error + Send + Sync + 'static,
181 {
182 self.err_handler = Some(Arc::new(f));
183 self
184 }
185}
186
#[cfg(test)]
mod tests {
    use actix_http::StatusCode;
    use derive_more::Display;
    use serde::Deserialize;

    use super::*;
    use crate::{error::InternalError, test::TestRequest, HttpResponse};

    // Deserialization target with a single required `id` field.
    #[derive(Deserialize, Debug, Display)]
    struct Id {
        id: String,
    }

    // `from_query`: fails when `id` is missing, succeeds when present. Also checks the
    // `Display`/`Debug` output, mutation through `DerefMut`, and `into_inner`.
    #[actix_rt::test]
    async fn test_service_request_extract() {
        let req = TestRequest::with_uri("/name/user1/").to_srv_request();
        assert!(Query::<Id>::from_query(req.query_string()).is_err());

        let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
        let mut s = Query::<Id>::from_query(req.query_string()).unwrap();

        assert_eq!(s.id, "test");
        assert_eq!(
            format!("{}, {:?}", s, s),
            "test, Query(Id { id: \"test\" })"
        );

        // Field assignment goes through `DerefMut` to the inner `Id`.
        s.id = "test1".to_string();
        let s = s.into_inner();
        assert_eq!(s.id, "test1");
    }

    // Same checks as above, but through the `FromRequest` extractor path.
    #[actix_rt::test]
    async fn test_request_extract() {
        let req = TestRequest::with_uri("/name/user1/").to_srv_request();
        let (req, mut pl) = req.into_parts();
        assert!(Query::<Id>::from_request(&req, &mut pl).await.is_err());

        let req = TestRequest::with_uri("/name/user1/?id=test").to_srv_request();
        let (req, mut pl) = req.into_parts();

        let mut s = Query::<Id>::from_request(&req, &mut pl).await.unwrap();
        assert_eq!(s.id, "test");
        assert_eq!(
            format!("{}, {:?}", s, s),
            "test, Query(Id { id: \"test\" })"
        );

        s.id = "test1".to_string();
        let s = s.into_inner();
        assert_eq!(s.id, "test1");
    }

    // Extracting named `key=value` pairs into a tuple is expected to fail, so the
    // `unwrap` panics. Presumably serde_urlencoded cannot map named pairs onto a tuple
    // target; verify if this test changes behavior.
    #[actix_rt::test]
    #[should_panic]
    async fn test_tuple_panic() {
        let req = TestRequest::with_uri("/?one=1&two=2").to_srv_request();
        let (req, mut pl) = req.into_parts();

        Query::<(u32, u32)>::from_request(&req, &mut pl)
            .await
            .unwrap();
    }

    // A `QueryConfig` error handler registered as app data replaces the default error.
    // Here it produces a 422 response for a missing `id`.
    #[actix_rt::test]
    async fn test_custom_error_responder() {
        let req = TestRequest::with_uri("/name/user1/")
            .app_data(QueryConfig::default().error_handler(|e, _| {
                let resp = HttpResponse::UnprocessableEntity().finish();
                InternalError::from_response(e, resp).into()
            }))
            .to_srv_request();

        let (req, mut pl) = req.into_parts();
        let query = Query::<Id>::from_request(&req, &mut pl).await;

        assert!(query.is_err());
        assert_eq!(
            query
                .unwrap_err()
                .as_response_error()
                .error_response()
                .status(),
            StatusCode::UNPROCESSABLE_ENTITY
        );
    }
}