rama_http/layer/classify/
status_in_range_is_error.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use super::{ClassifiedResponse, ClassifyResponse, NeverClassifyEos, SharedClassifier};
use rama_http_types::StatusCode;
use std::{fmt, ops::RangeInclusive};

/// Response classifier that flags a response as a failure whenever its status code falls
/// inside a configurable, inclusive range of `u16` status values.
#[derive(Clone, Debug)]
pub struct StatusInRangeAsFailures {
    range: RangeInclusive<u16>,
}

impl StatusInRangeAsFailures {
    /// Builds a classifier that treats every status code in `range` (bounds included) as a
    /// failure.
    ///
    /// A range whose start is greater than its end contains no values, so such a classifier
    /// never reports a failure.
    ///
    /// # Panics
    ///
    /// Panics when either bound of `range` is rejected by [`StatusCode::from_u16`].
    pub fn new(range: RangeInclusive<u16>) -> Self {
        let (first, last) = (*range.start(), *range.end());

        assert!(
            StatusCode::from_u16(first).is_ok(),
            "range start isn't a valid status code"
        );
        assert!(
            StatusCode::from_u16(last).is_ok(),
            "range end isn't a valid status code"
        );

        Self { range }
    }

    /// Builds a classifier that treats every 4xx and 5xx response as a failure.
    ///
    /// Shorthand for `StatusInRangeAsFailures::new(400..=599)`.
    pub fn new_for_client_and_server_errors() -> Self {
        Self::new(400..=599)
    }

    /// Wraps this classifier in a [`SharedClassifier`], turning it into a [`MakeClassifier`].
    ///
    /// [`MakeClassifier`]: super::MakeClassifier
    pub fn into_make_classifier(self) -> SharedClassifier<Self> {
        SharedClassifier::new(self)
    }
}

impl ClassifyResponse for StatusInRangeAsFailures {
    type FailureClass = StatusInRangeFailureClass;
    type ClassifyEos = NeverClassifyEos<Self::FailureClass>;

    fn classify_response<B>(
        self,
        res: &http::Response<B>,
    ) -> ClassifiedResponse<Self::FailureClass, Self::ClassifyEos> {
        if self.range.contains(&res.status().as_u16()) {
            let class = StatusInRangeFailureClass::StatusCode(res.status());
            ClassifiedResponse::Ready(Err(class))
        } else {
            ClassifiedResponse::Ready(Ok(()))
        }
    }

    fn classify_error<E>(self, error: &E) -> Self::FailureClass
    where
        E: std::fmt::Display,
    {
        StatusInRangeFailureClass::Error(error.to_string())
    }
}

/// The failure class for [`StatusInRangeAsFailures`].
#[derive(Debug)]
pub enum StatusInRangeFailureClass {
    /// A response was classified as a failure with the corresponding status.
    StatusCode(StatusCode),
    /// A response was classified as an error with the corresponding error description.
    Error(String),
}

impl fmt::Display for StatusInRangeFailureClass {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::StatusCode(code) => write!(f, "Status code: {}", code),
            Self::Error(error) => write!(f, "Error: {}", error),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rama_http_types::Response;

    /// Builds an empty-bodied response carrying the given status code.
    fn response_with_status(status: u16) -> Response<()> {
        Response::builder().status(status).body(()).unwrap()
    }

    #[test]
    fn basic() {
        let classifier = StatusInRangeAsFailures::new(400..=599);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(200)),
            ClassifiedResponse::Ready(Ok(())),
        ));

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(400)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::BAD_REQUEST
            ))),
        ));

        assert!(matches!(
            classifier.classify_response(&response_with_status(500)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::INTERNAL_SERVER_ERROR
            ))),
        ));
    }

    #[test]
    fn range_bounds_are_inclusive() {
        let classifier = StatusInRangeAsFailures::new_for_client_and_server_errors();

        // Just outside the range on either side: not a failure.
        for status in [399, 600] {
            assert!(
                matches!(
                    classifier
                        .clone()
                        .classify_response(&response_with_status(status)),
                    ClassifiedResponse::Ready(Ok(())),
                ),
                "status {status} should not be classified as a failure"
            );
        }

        // Exactly on either bound: a failure carrying that same status.
        for status in [400, 599] {
            assert!(
                matches!(
                    classifier
                        .clone()
                        .classify_response(&response_with_status(status)),
                    ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(code)))
                        if code.as_u16() == status
                ),
                "status {status} should be classified as a failure"
            );
        }
    }

    #[test]
    fn custom_range() {
        let classifier = StatusInRangeAsFailures::new(500..=503);

        assert!(matches!(
            classifier
                .clone()
                .classify_response(&response_with_status(404)),
            ClassifiedResponse::Ready(Ok(())),
        ));
        assert!(matches!(
            classifier.classify_response(&response_with_status(503)),
            ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode(
                StatusCode::SERVICE_UNAVAILABLE
            ))),
        ));
    }

    #[test]
    fn classify_error_keeps_description() {
        let class = StatusInRangeAsFailures::new_for_client_and_server_errors()
            .classify_error(&String::from("connection reset"));

        assert!(matches!(
            &class,
            StatusInRangeFailureClass::Error(msg) if msg == "connection reset"
        ));
        assert_eq!(class.to_string(), "Error: connection reset");
    }

    #[test]
    fn display_status_code() {
        let class = StatusInRangeFailureClass::StatusCode(StatusCode::NOT_FOUND);
        assert!(class.to_string().starts_with("Status code: 404"));
    }

    #[test]
    #[should_panic(expected = "range start isn't a valid status code")]
    fn new_rejects_invalid_start() {
        let _ = StatusInRangeAsFailures::new(99..=599);
    }

    #[test]
    #[should_panic(expected = "range end isn't a valid status code")]
    fn new_rejects_invalid_end() {
        let _ = StatusInRangeAsFailures::new(400..=1000);
    }
}