1use std::{
4 future::Future,
5 marker::PhantomData,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use actix_http::encoding::Encoder;
11use actix_service::{Service, Transform};
12use actix_utils::future::{ok, Either, Ready};
13use futures_core::ready;
14use mime::Mime;
15use once_cell::sync::Lazy;
16use pin_project_lite::pin_project;
17
18use crate::{
19 body::{EitherBody, MessageBody},
20 http::{
21 header::{self, AcceptEncoding, ContentEncoding, Encoding, HeaderValue},
22 StatusCode,
23 },
24 service::{ServiceRequest, ServiceResponse},
25 Error, HttpMessage, HttpResponse,
26};
27
/// Middleware that compresses response bodies.
///
/// The codec is negotiated from the request's `Accept-Encoding` header against the encodings
/// enabled through the `compress-*` crate features (plus `identity`, which is always available).
///
/// - Without an `Accept-Encoding` header, responses are passed through with the identity encoding.
/// - When none of the client's acceptable encodings is supported, the request is answered with
///   `406 Not Acceptable`, a `Vary: Accept-Encoding` header and a body listing the supported
///   codecs.
/// - Responses whose `Content-Type` is an image or video type are generally left uncompressed,
///   since such formats are typically compressed already.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct Compress;
78
79impl<S, B> Transform<S, ServiceRequest> for Compress
80where
81 B: MessageBody,
82 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
83{
84 type Response = ServiceResponse<EitherBody<Encoder<B>>>;
85 type Error = Error;
86 type Transform = CompressMiddleware<S>;
87 type InitError = ();
88 type Future = Ready<Result<Self::Transform, Self::InitError>>;
89
90 fn new_transform(&self, service: S) -> Self::Future {
91 ok(CompressMiddleware { service })
92 }
93}
94
/// Service produced by the [`Compress`] transform.
///
/// For each request it negotiates a response encoding from the `Accept-Encoding` header and then
/// delegates to the wrapped `service`.
///
/// `Debug` is derived so the middleware can be inspected whenever the wrapped service is itself
/// `Debug`. The derive adds an `S: Debug` bound to that impl only.
#[derive(Debug)]
pub struct CompressMiddleware<S> {
    // the inner service whose responses get compressed
    service: S,
}
98
99impl<S, B> Service<ServiceRequest> for CompressMiddleware<S>
100where
101 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
102 B: MessageBody,
103{
104 type Response = ServiceResponse<EitherBody<Encoder<B>>>;
105 type Error = Error;
106 #[allow(clippy::type_complexity)]
107 type Future = Either<CompressResponse<S, B>, Ready<Result<Self::Response, Self::Error>>>;
108
109 actix_service::forward_ready!(service);
110
111 #[allow(clippy::borrow_interior_mutable_const)]
112 fn call(&self, req: ServiceRequest) -> Self::Future {
113 let accept_encoding = req.get_header::<AcceptEncoding>();
115
116 let accept_encoding = match accept_encoding {
117 None => {
119 return Either::left(CompressResponse {
120 encoding: Encoding::identity(),
121 fut: self.service.call(req),
122 _phantom: PhantomData,
123 })
124 }
125
126 Some(accept_encoding) => accept_encoding,
128 };
129
130 match accept_encoding.negotiate(SUPPORTED_ENCODINGS.iter()) {
131 None => {
132 let mut res = HttpResponse::with_body(
133 StatusCode::NOT_ACCEPTABLE,
134 SUPPORTED_ENCODINGS_STRING.as_str(),
135 );
136
137 res.headers_mut()
138 .insert(header::VARY, HeaderValue::from_static("Accept-Encoding"));
139
140 Either::right(ok(req
141 .into_response(res)
142 .map_into_boxed_body()
143 .map_into_right_body()))
144 }
145
146 Some(encoding) => Either::left(CompressResponse {
147 fut: self.service.call(req),
148 encoding,
149 _phantom: PhantomData,
150 }),
151 }
152 }
153}
154
pin_project! {
    /// Future returned by [`CompressMiddleware`].
    ///
    /// It resolves the wrapped service's response and wraps the response body in an [`Encoder`]
    /// for the negotiated encoding.
    pub struct CompressResponse<S, B>
    where
        S: Service<ServiceRequest>,
    {
        // the inner service's response future (structurally pinned)
        #[pin]
        fut: S::Future,
        // encoding chosen by `CompressMiddleware::call`; identity when no Accept-Encoding was sent
        encoding: Encoding,
        // ties the `B` body type parameter to this future without storing a `B`
        _phantom: PhantomData<B>,
    }
}
166
167impl<S, B> Future for CompressResponse<S, B>
168where
169 B: MessageBody,
170 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
171{
172 type Output = Result<ServiceResponse<EitherBody<Encoder<B>>>, Error>;
173
174 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
175 let this = self.as_mut().project();
176
177 match ready!(this.fut.poll(cx)) {
178 Ok(resp) => {
179 let enc = match this.encoding {
180 Encoding::Known(enc) => *enc,
181 Encoding::Unknown(enc) => {
182 unimplemented!("encoding '{enc}' should not be here");
183 }
184 };
185
186 Poll::Ready(Ok(resp.map_body(move |head, body| {
187 let content_type = head.headers.get(header::CONTENT_TYPE);
188
189 fn default_compress_predicate(content_type: Option<&HeaderValue>) -> bool {
190 match content_type {
191 None => true,
192 Some(hdr) => {
193 match hdr.to_str().ok().and_then(|hdr| hdr.parse::<Mime>().ok()) {
194 Some(mime) if mime.type_().as_str() == "image" => false,
195 Some(mime) if mime.type_().as_str() == "video" => false,
196 _ => true,
197 }
198 }
199 }
200 }
201
202 let enc = if default_compress_predicate(content_type) {
203 enc
204 } else {
205 ContentEncoding::Identity
206 };
207
208 EitherBody::left(Encoder::response(enc, head, body))
209 })))
210 }
211
212 Err(err) => Poll::Ready(Err(err)),
213 }
214 }
215}
216
217static SUPPORTED_ENCODINGS_STRING: Lazy<String> = Lazy::new(|| {
218 #[allow(unused_mut)] let mut encoding: Vec<&str> = vec![];
220
221 #[cfg(feature = "compress-brotli")]
222 {
223 encoding.push("br");
224 }
225
226 #[cfg(feature = "compress-gzip")]
227 {
228 encoding.push("gzip");
229 encoding.push("deflate");
230 }
231
232 #[cfg(feature = "compress-zstd")]
233 {
234 encoding.push("zstd");
235 }
236
237 assert!(
238 !encoding.is_empty(),
239 "encoding can not be empty unless __compress feature has been explicitly enabled by itself"
240 );
241
242 encoding.join(", ")
243});
244
/// Encodings this middleware can negotiate, passed in this order to
/// `AcceptEncoding::negotiate` by `CompressMiddleware::call`.
///
/// `identity` is always present. The compressed codecs are included only when the matching
/// `compress-*` crate feature is enabled (`compress-gzip` contributes both gzip and deflate).
static SUPPORTED_ENCODINGS: &[Encoding] = &[
    Encoding::identity(),
    #[cfg(feature = "compress-brotli")]
    {
        Encoding::brotli()
    },
    #[cfg(feature = "compress-gzip")]
    {
        Encoding::gzip()
    },
    #[cfg(feature = "compress-gzip")]
    {
        Encoding::deflate()
    },
    #[cfg(feature = "compress-zstd")]
    {
        Encoding::zstd()
    },
];
264
// Integration tests for the gzip path: double-compression prevention, `Vary` handling and the
// content-type compression predicate. Only built when the `compress-gzip` feature is enabled.
#[cfg(feature = "compress-gzip")]
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use static_assertions::assert_impl_all;

    use super::*;
    use crate::{http::header::ContentType, middleware::DefaultHeaders, test, web, App};

    // Each payload part is repeated 100 times to build a non-trivial body.
    const HTML_DATA_PART: &str = "<html><h1>hello world</h1></html";
    const HTML_DATA: &str = const_str::repeat!(HTML_DATA_PART, 100);

    const TEXT_DATA_PART: &str = "hello world ";
    const TEXT_DATA: &str = const_str::repeat!(TEXT_DATA_PART, 100);

    // Compile-time check that the middleware factory is `Send + Sync`.
    assert_impl_all!(Compress: Send, Sync);

    /// Decompresses a gzip payload; panics if `bytes` is not valid gzip data.
    pub fn gzip_decode(bytes: impl AsRef<[u8]>) -> Vec<u8> {
        use std::io::Read as _;
        let mut decoder = flate2::read::GzDecoder::new(bytes.as_ref());
        let mut buf = Vec::new();
        decoder.read_to_end(&mut buf).unwrap();
        buf
    }

    /// Asserts a 2xx status and a `Content-Type` header containing `ct`.
    #[track_caller]
    fn assert_successful_res_with_content_type<B>(res: &ServiceResponse<B>, ct: &str) {
        assert!(res.status().is_success());
        assert!(
            res.headers()
                .get(header::CONTENT_TYPE)
                .expect("content-type header should be present")
                .to_str()
                .expect("content-type header should be utf-8")
                .contains(ct),
            "response's content-type did not match {}",
            ct
        );
    }

    /// Like `assert_successful_res_with_content_type`, and also requires `Content-Encoding: gzip`.
    #[track_caller]
    fn assert_successful_gzip_res_with_content_type<B>(res: &ServiceResponse<B>, ct: &str) {
        assert_successful_res_with_content_type(res, ct);
        assert_eq!(
            res.headers()
                .get(header::CONTENT_ENCODING)
                .expect("response should be gzip compressed"),
            "gzip",
        );
    }

    /// Like `assert_successful_res_with_content_type`, and also requires no `Content-Encoding`.
    #[track_caller]
    fn assert_successful_identity_res_with_content_type<B>(res: &ServiceResponse<B>, ct: &str) {
        assert_successful_res_with_content_type(res, ct);
        assert!(
            res.headers().get(header::CONTENT_ENCODING).is_none(),
            "response should not be compressed",
        );
    }

    // `/double` is wrapped by `Compress` twice (app-level and resource-level). Its body must
    // still decode with a single gzip pass, i.e. it must not be compressed twice.
    #[actix_rt::test]
    async fn prevents_double_compressing() {
        let app = test::init_service({
            App::new()
                .wrap(Compress::default())
                .route(
                    "/single",
                    web::get().to(move || HttpResponse::Ok().body(TEXT_DATA)),
                )
                .service(
                    web::resource("/double")
                        .wrap(Compress::default())
                        .wrap(DefaultHeaders::new().add(("x-double", "true")))
                        .route(web::get().to(move || HttpResponse::Ok().body(TEXT_DATA))),
                )
        })
        .await;

        let req = test::TestRequest::default()
            .uri("/single")
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.headers().get("x-double"), None);
        assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip");
        let bytes = test::read_body(res).await;
        assert_eq!(gzip_decode(bytes), TEXT_DATA.as_bytes());

        let req = test::TestRequest::default()
            .uri("/double")
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.headers().get("x-double").unwrap(), "true");
        assert_eq!(res.headers().get(header::CONTENT_ENCODING).unwrap(), "gzip");
        let bytes = test::read_body(res).await;
        assert_eq!(gzip_decode(bytes), TEXT_DATA.as_bytes());
    }

    // A `Vary` value set by the handler must survive alongside `accept-encoding`.
    #[actix_rt::test]
    async fn retains_previously_set_vary_header() {
        let app = test::init_service({
            App::new()
                .wrap(Compress::default())
                .default_service(web::to(move || {
                    HttpResponse::Ok()
                        .insert_header((header::VARY, "x-test"))
                        .body(TEXT_DATA)
                }))
        })
        .await;

        let req = test::TestRequest::default()
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        // NOTE(review): presumably `HeaderValue` trips clippy's interior-mutability check for keys
        #[allow(clippy::mutable_key_type)]
        let vary_headers = res.headers().get_all(header::VARY).collect::<HashSet<_>>();
        assert!(vary_headers.contains(&HeaderValue::from_static("x-test")));
        assert!(vary_headers.contains(&HeaderValue::from_static("accept-encoding")));
    }

    // Routes with a compressible (HTML) and a non-compressible (JPEG) content type.
    fn configure_predicate_test(cfg: &mut web::ServiceConfig) {
        cfg.route(
            "/html",
            web::get().to(|| {
                HttpResponse::Ok()
                    .content_type(ContentType::html())
                    .body(HTML_DATA)
            }),
        )
        .route(
            "/image",
            web::get().to(|| {
                HttpResponse::Ok()
                    .content_type(ContentType::jpeg())
                    .body(TEXT_DATA)
            }),
        );
    }

    // HTML is gzip-compressed; JPEG is passed through unencoded despite `Accept-Encoding: gzip`.
    #[actix_rt::test]
    async fn prevents_compression_jpeg() {
        let app = test::init_service(
            App::new()
                .wrap(Compress::default())
                .configure(configure_predicate_test),
        )
        .await;

        let req =
            test::TestRequest::with_uri("/html").insert_header((header::ACCEPT_ENCODING, "gzip"));
        let res = test::call_service(&app, req.to_request()).await;
        assert_successful_gzip_res_with_content_type(&res, "text/html");
        assert_ne!(test::read_body(res).await, HTML_DATA.as_bytes());

        let req =
            test::TestRequest::with_uri("/image").insert_header((header::ACCEPT_ENCODING, "gzip"));
        let res = test::call_service(&app, req.to_request()).await;
        assert_successful_identity_res_with_content_type(&res, "image/jpeg");
        assert_eq!(test::read_body(res).await, TEXT_DATA.as_bytes());
    }

    // An empty body gets no `Content-Encoding` even when gzip is accepted.
    #[actix_rt::test]
    async fn prevents_compression_empty() {
        let app = test::init_service({
            App::new()
                .wrap(Compress::default())
                .default_service(web::to(move || HttpResponse::Ok().finish()))
        })
        .await;

        let req = test::TestRequest::default()
            .insert_header((header::ACCEPT_ENCODING, "gzip"))
            .to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res.headers().contains_key(header::CONTENT_ENCODING));
        assert!(test::read_body(res).await.is_empty());
    }
}
451
// Brotli-specific tests; only built when the `compress-brotli` feature is enabled.
#[cfg(feature = "compress-brotli")]
#[cfg(test)]
mod tests_brotli {
    use super::*;
    use crate::{test, web, App};

    /// An empty body stays unencoded even when the client asks for brotli.
    #[actix_rt::test]
    async fn prevents_compression_empty() {
        let srv = test::init_service(
            App::new()
                .wrap(Compress::default())
                .default_service(web::to(|| HttpResponse::Ok().finish())),
        )
        .await;

        let req = test::TestRequest::default()
            .insert_header((header::ACCEPT_ENCODING, "br"))
            .to_request();
        let res = test::call_service(&srv, req).await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::CONTENT_ENCODING).is_none());

        let body = test::read_body(res).await;
        assert!(body.is_empty());
    }
}