axum_extra/
protobuf.rs

1//! Protocol Buffer extractor and response.
2
3use axum::{
4    extract::{rejection::BytesRejection, FromRequest, Request},
5    response::{IntoResponse, Response},
6};
7use axum_core::__composite_rejection as composite_rejection;
8use axum_core::__define_rejection as define_rejection;
9use bytes::{Bytes, BytesMut};
10use http::StatusCode;
11use prost::Message;
12
/// A Protocol Buffer message extractor and response.
///
/// `Protobuf<T>` is a thin wrapper around a `T` that implements [`prost::Message`]. It can be
/// used both as an extractor and as a response.
///
/// # As extractor
///
/// When used as an extractor, it can decode request bodies into some type that
/// implements [`prost::Message`]. The request will be rejected (and a [`ProtobufRejection`] will
/// be returned) if:
///
/// - The body couldn't be decoded into the target Protocol Buffer message type
///   ([`ProtobufDecodeError`], which is declared with the `422 Unprocessable Entity` status).
/// - Buffering the request body fails.
///
/// See [`ProtobufRejection`] for more details.
///
/// The extractor does not expect a `Content-Type` header to be present in the request.
///
/// # Extractor example
///
/// ```rust,no_run
/// use axum::{routing::post, Router};
/// use axum_extra::protobuf::Protobuf;
///
/// #[derive(prost::Message)]
/// struct CreateUser {
///     #[prost(string, tag="1")]
///     email: String,
///     #[prost(string, tag="2")]
///     password: String,
/// }
///
/// async fn create_user(Protobuf(payload): Protobuf<CreateUser>) {
///     // payload is `CreateUser`
/// }
///
/// let app = Router::new().route("/users", post(create_user));
/// # let _: Router = app;
/// ```
///
/// # As response
///
/// When used as a response, it can encode any type that implements [`prost::Message`] to
/// a newly allocated buffer, which is then sent as the response body.
///
/// If no `Content-Type` header is set, the `Content-Type: application/octet-stream` header
/// will be used automatically.
///
/// If the message cannot be encoded, the response is a `500 Internal Server Error` whose body
/// is the encoder's error message.
///
/// # Response example
///
/// ```
/// use axum::{
///     extract::Path,
///     routing::get,
///     Router,
/// };
/// use axum_extra::protobuf::Protobuf;
///
/// #[derive(prost::Message)]
/// struct User {
///     #[prost(string, tag="1")]
///     username: String,
/// }
///
/// async fn get_user(Path(user_id): Path<String>) -> Protobuf<User> {
///     let user = find_user(user_id).await;
///     Protobuf(user)
/// }
///
/// async fn find_user(user_id: String) -> User {
///     // ...
///     # unimplemented!()
/// }
///
/// let app = Router::new().route("/users/{id}", get(get_user));
/// # let _: Router = app;
/// ```
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "protobuf")))]
#[must_use]
pub struct Protobuf<T>(pub T);
93
94impl<T, S> FromRequest<S> for Protobuf<T>
95where
96    T: Message + Default,
97    S: Send + Sync,
98{
99    type Rejection = ProtobufRejection;
100
101    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
102        let mut bytes = Bytes::from_request(req, state).await?;
103
104        match T::decode(&mut bytes) {
105            Ok(value) => Ok(Protobuf(value)),
106            Err(err) => Err(ProtobufDecodeError::from_err(err).into()),
107        }
108    }
109}
110
// Generates the dereferencing impl from `Protobuf<T>` to the wrapped `T`; this is what lets a
// handler write `input.foo` on a `Protobuf<Input>`, as the tests at the bottom of this file do.
// NOTE(review): presumably this also emits `DerefMut` — confirm against the macro definition.
axum_core::__impl_deref!(Protobuf);
112
113impl<T> From<T> for Protobuf<T> {
114    fn from(inner: T) -> Self {
115        Self(inner)
116    }
117}
118
119impl<T> IntoResponse for Protobuf<T>
120where
121    T: Message + Default,
122{
123    fn into_response(self) -> Response {
124        let mut buf = BytesMut::with_capacity(128);
125        match &self.0.encode(&mut buf) {
126            Ok(()) => buf.into_response(),
127            Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
128        }
129    }
130}
131
// Declares the `ProtobufDecodeError` rejection. The `status` and `body` attributes below are
// inputs to the macro: they select the HTTP status (`422 Unprocessable Entity`) and the
// message text associated with this rejection. `(Error)` marks it as wrapping an underlying
// error, which is supplied through `ProtobufDecodeError::from_err` in the extractor.
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to decode the body"]
    /// Rejection type for [`Protobuf`].
    ///
    /// This rejection is used if the request body couldn't be decoded into the target type.
    ///
    /// It is declared with the `422 Unprocessable Entity` status.
    pub struct ProtobufDecodeError(Error);
}
140
// Declares `ProtobufRejection`, the single error type returned by the `Protobuf` extractor.
// Each listed name is both a variant and the rejection type it wraps; the extractor relies on
// the conversions from those types (`?` on the buffering result, `.into()` on the decode error).
composite_rejection! {
    /// Rejection used for [`Protobuf`].
    ///
    /// Contains one variant for each way the [`Protobuf`] extractor
    /// can fail.
    pub enum ProtobufRejection {
        // The buffered body could not be decoded into the target message type.
        ProtobufDecodeError,
        // The request body could not be buffered into `Bytes`.
        BytesRejection,
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::post, Router};

    /// Request message used by every test: a single string field with tag 1.
    #[derive(prost::Message)]
    struct Input {
        #[prost(string, tag = "1")]
        foo: String,
    }

    /// Returns the encoded bytes of an [`Input`] whose `foo` field is `"bar"`.
    fn encoded_bar() -> Vec<u8> {
        let message = Input {
            foo: "bar".to_owned(),
        };
        message.encode_to_vec()
    }

    /// A well-formed body is decoded and its field reaches the handler.
    #[tokio::test]
    async fn decode_body() {
        async fn echo_foo(Protobuf(message): Protobuf<Input>) -> String {
            message.foo
        }

        let client = TestClient::new(Router::new().route("/", post(echo_foo)));
        let response = client.post("/").body(encoded_bar()).await;

        assert_eq!(response.text().await, "bar");
    }

    /// A body whose field 1 is a string is rejected when the handler expects an `int32` there.
    #[tokio::test]
    async fn prost_decode_error() {
        #[derive(prost::Message)]
        struct Expected {
            #[prost(int32, tag = "1")]
            test: i32,
        }

        async fn accept(_: Protobuf<Expected>) {}

        let client = TestClient::new(Router::new().route("/", post(accept)));
        let response = client.post("/").body(encoded_bar()).await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    /// A `Protobuf` response is sent as `application/octet-stream` and decodes back correctly.
    #[tokio::test]
    async fn encode_body() {
        #[derive(prost::Message)]
        struct Output {
            #[prost(string, tag = "1")]
            result: String,
        }

        #[axum::debug_handler]
        async fn handler(input: Protobuf<Input>) -> Protobuf<Output> {
            let Protobuf(request) = input;
            Protobuf(Output {
                result: request.foo,
            })
        }

        let client = TestClient::new(Router::new().route("/", post(handler)));
        let response = client.post("/").body(encoded_bar()).await;

        assert_eq!(
            response.headers()["content-type"],
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );

        let decoded = Output::decode(response.bytes().await).unwrap();
        assert_eq!(decoded.result, "bar");
    }
}