1use axum::{
4 extract::{rejection::BytesRejection, FromRequest, Request},
5 response::{IntoResponse, Response},
6};
7use axum_core::__composite_rejection as composite_rejection;
8use axum_core::__define_rejection as define_rejection;
9use bytes::{Bytes, BytesMut};
10use http::StatusCode;
11use prost::Message;
12
/// A Protocol Buffers extractor and response wrapper.
///
/// When used as an extractor, the request body is buffered and decoded into `T`
/// using `prost`. When returned from a handler, the wrapped `T` is encoded into
/// the response body.
#[must_use]
#[cfg_attr(docsrs, doc(cfg(feature = "protobuf")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Protobuf<T>(pub T);
93
94impl<T, S> FromRequest<S> for Protobuf<T>
95where
96 T: Message + Default,
97 S: Send + Sync,
98{
99 type Rejection = ProtobufRejection;
100
101 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
102 let mut bytes = Bytes::from_request(req, state).await?;
103
104 match T::decode(&mut bytes) {
105 Ok(value) => Ok(Protobuf(value)),
106 Err(err) => Err(ProtobufDecodeError::from_err(err).into()),
107 }
108 }
109}
110
// Lets a `Protobuf<T>` be used like the `T` it wraps (the tests access
// `input.foo` directly on a `Protobuf<Input>`). Presumably this expands to a
// `Deref` impl (and likely `DerefMut`) targeting the inner value — confirm
// against `axum_core::__impl_deref`.
axum_core::__impl_deref!(Protobuf);
112
113impl<T> From<T> for Protobuf<T> {
114 fn from(inner: T) -> Self {
115 Self(inner)
116 }
117}
118
119impl<T> IntoResponse for Protobuf<T>
120where
121 T: Message + Default,
122{
123 fn into_response(self) -> Response {
124 let mut buf = BytesMut::with_capacity(128);
125 match &self.0.encode(&mut buf) {
126 Ok(()) => buf.into_response(),
127 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
128 }
129 }
130}
131
// Rejection produced by the `Protobuf` extractor when the buffered request body
// is not a valid protobuf encoding of the target type. It is built from the
// underlying decode error via `ProtobufDecodeError::from_err`. It responds with
// `422 Unprocessable Entity`, as asserted by the `prost_decode_error` test.
// The `#[body]` text is presumably used (possibly combined with the wrapped
// error) as the response body — confirm against `axum_core::__define_rejection`.
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to decode the body"]
    pub struct ProtobufDecodeError(Error);
}
140
// Rejection type for the `Protobuf` extractor. It covers two failures:
// - `BytesRejection`: buffering the body failed (propagated with `?` in `from_request`).
// - `ProtobufDecodeError`: the body could not be decoded (converted with `.into()`).
// Both conversions rely on `From` impls that this macro presumably generates for
// each variant. The macro presumably also generates an `IntoResponse` impl that
// delegates to the variant — confirm against `axum_core::__composite_rejection`.
composite_rejection! {
    pub enum ProtobufRejection {
        ProtobufDecodeError,
        BytesRejection,
    }
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::post, Router};

    #[tokio::test]
    async fn decode_body() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        let app = Router::new().route(
            "/",
            post(|input: Protobuf<Input>| async move { input.foo.to_owned() }),
        );

        let input = Input {
            foo: "bar".to_owned(),
        };

        let client = TestClient::new(app);
        let res = client.post("/").body(input.encode_to_vec()).await;

        let body = res.text().await;

        assert_eq!(body, "bar");
    }

    // An empty body is a valid protobuf encoding of a message whose fields all
    // hold their default values, so extraction must succeed with `T::default()`.
    #[tokio::test]
    async fn decode_empty_body_as_default() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        let app = Router::new().route(
            "/",
            // Brackets make an empty `foo` observable in the response text.
            post(|input: Protobuf<Input>| async move { format!("[{}]", input.foo) }),
        );

        let client = TestClient::new(app);
        let res = client.post("/").body(Vec::<u8>::new()).await;

        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "[]");
    }

    #[tokio::test]
    async fn prost_decode_error() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        #[derive(prost::Message)]
        struct Expected {
            #[prost(int32, tag = "1")]
            test: i32,
        }

        let app = Router::new().route("/", post(|_: Protobuf<Expected>| async {}));

        let input = Input {
            foo: "bar".to_owned(),
        };

        let client = TestClient::new(app);
        let res = client.post("/").body(input.encode_to_vec()).await;

        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn encode_body() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        #[derive(prost::Message)]
        struct Output {
            #[prost(string, tag = "1")]
            result: String,
        }

        #[axum::debug_handler]
        async fn handler(input: Protobuf<Input>) -> Protobuf<Output> {
            let output = Output {
                result: input.foo.to_owned(),
            };

            Protobuf(output)
        }

        let app = Router::new().route("/", post(handler));

        let input = Input {
            foo: "bar".to_owned(),
        };

        let client = TestClient::new(app);
        let res = client.post("/").body(input.encode_to_vec()).await;

        assert_eq!(
            res.headers()["content-type"],
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );

        let body = res.bytes().await;

        let output = Output::decode(body).unwrap();

        assert_eq!(output.result, "bar");
    }
}