rama_http/layer/compression/
predicate.rs1use crate::dep::http_body::Body;
10use rama_http_types::{HeaderMap, StatusCode, Version, dep::http::Extensions, header};
11use std::{fmt, sync::Arc};
12
13pub trait Predicate: Clone {
15 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
17 where
18 B: Body;
19
20 fn and<Other>(self, other: Other) -> And<Self, Other>
24 where
25 Self: Sized,
26 Other: Predicate,
27 {
28 And {
29 lhs: self,
30 rhs: other,
31 }
32 }
33}
34
35impl<F> Predicate for F
36where
37 F: Fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool + Clone,
38{
39 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
40 where
41 B: Body,
42 {
43 let status = response.status();
44 let version = response.version();
45 let headers = response.headers();
46 let extensions = response.extensions();
47 self(status, version, headers, extensions)
48 }
49}
50
51impl<T> Predicate for Option<T>
52where
53 T: Predicate,
54{
55 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
56 where
57 B: Body,
58 {
59 self.as_ref()
60 .map(|inner| inner.should_compress(response))
61 .unwrap_or(true)
62 }
63}
64
/// Predicate combinator built by [`Predicate::and`]: compression is allowed only if
/// both `lhs` and `rhs` allow it.
#[derive(Clone, Copy, Debug, Default)]
pub struct And<Lhs, Rhs> {
    lhs: Lhs,
    rhs: Rhs,
}
73
74impl<Lhs, Rhs> Predicate for And<Lhs, Rhs>
75where
76 Lhs: Predicate,
77 Rhs: Predicate,
78{
79 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
80 where
81 B: Body,
82 {
83 self.lhs.should_compress(response) && self.rhs.should_compress(response)
84 }
85}
86
87#[derive(Debug, Clone)]
117pub struct DefaultPredicate(
118 And<And<And<SizeAbove, NotForContentType>, NotForContentType>, NotForContentType>,
119);
120
121impl DefaultPredicate {
122 pub fn new() -> Self {
124 let inner = SizeAbove::new(SizeAbove::DEFAULT_MIN_SIZE)
125 .and(NotForContentType::GRPC)
126 .and(NotForContentType::IMAGES)
127 .and(NotForContentType::SSE);
128 Self(inner)
129 }
130}
131
132impl Default for DefaultPredicate {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl Predicate for DefaultPredicate {
139 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
140 where
141 B: Body,
142 {
143 self.0.should_compress(response)
144 }
145}
146
/// Predicate that only allows compression of responses of at least a minimum size.
///
/// The minimum is stored as a `u16`, so it can be at most `u16::MAX` bytes.
#[derive(Clone, Copy, Debug)]
pub struct SizeAbove(u16);

impl SizeAbove {
    /// Minimum size, in bytes, used by `SizeAbove::default()`.
    pub(crate) const DEFAULT_MIN_SIZE: u16 = 32;

    /// Creates a predicate that allows compression only for responses of at least
    /// `min_size_bytes` bytes. Responses whose size is unknown are always allowed.
    pub const fn new(min_size_bytes: u16) -> Self {
        SizeAbove(min_size_bytes)
    }
}

impl Default for SizeAbove {
    fn default() -> Self {
        Self::new(Self::DEFAULT_MIN_SIZE)
    }
}
169
170impl Predicate for SizeAbove {
171 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
172 where
173 B: Body,
174 {
175 let content_size = response.body().size_hint().exact().or_else(|| {
176 response
177 .headers()
178 .get(header::CONTENT_LENGTH)
179 .and_then(|h| h.to_str().ok())
180 .and_then(|val| val.parse().ok())
181 });
182
183 match content_size {
184 Some(size) => size >= (self.0 as u64),
185 _ => true,
186 }
187 }
188}
189
190#[derive(Clone, Debug)]
192pub struct NotForContentType {
193 content_type: Str,
194 exception: Option<Str>,
195}
196
197impl NotForContentType {
198 pub const GRPC: Self = Self::const_new("application/grpc");
200
201 pub const IMAGES: Self = Self {
203 content_type: Str::Static("image/"),
204 exception: Some(Str::Static("image/svg+xml")),
205 };
206
207 pub const SSE: Self = Self::const_new("text/event-stream");
209
210 pub fn new(content_type: &str) -> Self {
212 Self {
213 content_type: Str::Shared(content_type.into()),
214 exception: None,
215 }
216 }
217
218 pub const fn const_new(content_type: &'static str) -> Self {
220 Self {
221 content_type: Str::Static(content_type),
222 exception: None,
223 }
224 }
225}
226
227impl Predicate for NotForContentType {
228 fn should_compress<B>(&self, response: &rama_http_types::Response<B>) -> bool
229 where
230 B: Body,
231 {
232 if let Some(except) = &self.exception {
233 if content_type(response) == except.as_str() {
234 return true;
235 }
236 }
237
238 !content_type(response).starts_with(self.content_type.as_str())
239 }
240}
241
/// A string that is either a `'static` literal or a reference-counted shared string.
/// Constant predicates therefore need no allocation, and runtime-built ones are cheap to
/// clone.
#[derive(Clone)]
enum Str {
    Static(&'static str),
    Shared(Arc<str>),
}

impl Str {
    /// Borrows the underlying string slice, whatever the variant.
    fn as_str(&self) -> &str {
        match self {
            Self::Static(text) => *text,
            Self::Shared(text) => &**text,
        }
    }
}

impl fmt::Debug for Str {
    /// Formats exactly like the `Debug` output of the underlying `str` (quoted and
    /// escaped), identically for both variants.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), f)
    }
}
265
266fn content_type<B>(response: &rama_http_types::Response<B>) -> &str {
267 response
268 .headers()
269 .get(header::CONTENT_TYPE)
270 .and_then(|h| h.to_str().ok())
271 .unwrap_or_default()
272}