1use crate::{headers::HeaderExt, Method, Request, Response, Uri};
2use rama_core::{
3 error::{BoxError, ErrorExt, OpaqueError},
4 Context, Service,
5};
6use std::future::Future;
7
8pub trait HttpClientExt<State>:
12 private::HttpClientExtSealed<State> + Sized + Send + Sync + 'static
13{
14 type ExecuteResponse;
16 type ExecuteError;
18
19 fn get(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
27
28 fn post(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
36
37 fn put(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
45
46 fn patch(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
54
55 fn delete(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
63
64 fn head(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
70
71 fn request(
84 &self,
85 method: Method,
86 url: impl IntoUrl,
87 ) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
88
89 fn execute(
95 &self,
96 ctx: Context<State>,
97 request: Request,
98 ) -> impl Future<Output = Result<Self::ExecuteResponse, Self::ExecuteError>>;
99}
100
101impl<State, S, Body> HttpClientExt<State> for S
102where
103 S: Service<State, Request, Response = Response<Body>, Error: Into<BoxError>>,
104{
105 type ExecuteResponse = Response<Body>;
106 type ExecuteError = S::Error;
107
108 fn get(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
109 self.request(Method::GET, url)
110 }
111
112 fn post(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
113 self.request(Method::POST, url)
114 }
115
116 fn put(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
117 self.request(Method::PUT, url)
118 }
119
120 fn patch(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
121 self.request(Method::PATCH, url)
122 }
123
124 fn delete(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
125 self.request(Method::DELETE, url)
126 }
127
128 fn head(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
129 self.request(Method::HEAD, url)
130 }
131
132 fn request(
133 &self,
134 method: Method,
135 url: impl IntoUrl,
136 ) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
137 let uri = match url.into_url() {
138 Ok(uri) => uri,
139 Err(err) => {
140 return RequestBuilder {
141 http_client_service: self,
142 state: RequestBuilderState::Error(err),
143 _phantom: std::marker::PhantomData,
144 }
145 }
146 };
147
148 let builder = crate::dep::http::request::Builder::new()
149 .method(method)
150 .uri(uri);
151
152 RequestBuilder {
153 http_client_service: self,
154 state: RequestBuilderState::PreBody(builder),
155 _phantom: std::marker::PhantomData,
156 }
157 }
158
159 fn execute(
160 &self,
161 ctx: Context<State>,
162 request: Request,
163 ) -> impl Future<Output = Result<Self::ExecuteResponse, Self::ExecuteError>> {
164 Service::serve(self, ctx, request)
165 }
166}
167
168pub trait IntoUrl: private::IntoUrlSealed {}
174
175impl IntoUrl for Uri {}
176impl IntoUrl for &str {}
177impl IntoUrl for String {}
178impl IntoUrl for &String {}
179
180pub trait IntoHeaderName: private::IntoHeaderNameSealed {}
186
187impl IntoHeaderName for crate::HeaderName {}
188impl IntoHeaderName for Option<crate::HeaderName> {}
189impl IntoHeaderName for &str {}
190impl IntoHeaderName for String {}
191impl IntoHeaderName for &String {}
192impl IntoHeaderName for &[u8] {}
193
194pub trait IntoHeaderValue: private::IntoHeaderValueSealed {}
200
201impl IntoHeaderValue for crate::HeaderValue {}
202impl IntoHeaderValue for &str {}
203impl IntoHeaderValue for String {}
204impl IntoHeaderValue for &String {}
205impl IntoHeaderValue for &[u8] {}
206
207mod private {
208 use rama_http_types::HeaderName;
209 use rama_net::Protocol;
210
211 use super::*;
212
213 pub trait IntoUrlSealed {
214 fn into_url(self) -> Result<Uri, OpaqueError>;
215 }
216
217 impl IntoUrlSealed for Uri {
218 fn into_url(self) -> Result<Uri, OpaqueError> {
219 let protocol: Option<Protocol> = self.scheme().map(Into::into);
220 match protocol {
221 Some(protocol) => {
222 if protocol.is_http() {
223 Ok(self)
224 } else {
225 Err(OpaqueError::from_display(format!(
226 "Unsupported protocol: {protocol}"
227 )))
228 }
229 }
230 None => Err(OpaqueError::from_display("Missing scheme in URI")),
231 }
232 }
233 }
234
235 impl IntoUrlSealed for &str {
236 fn into_url(self) -> Result<Uri, OpaqueError> {
237 match self.parse::<Uri>() {
238 Ok(uri) => uri.into_url(),
239 Err(_) => Err(OpaqueError::from_display(format!("Invalid URL: {}", self))),
240 }
241 }
242 }
243
244 impl IntoUrlSealed for String {
245 fn into_url(self) -> Result<Uri, OpaqueError> {
246 self.as_str().into_url()
247 }
248 }
249
250 impl IntoUrlSealed for &String {
251 fn into_url(self) -> Result<Uri, OpaqueError> {
252 self.as_str().into_url()
253 }
254 }
255
256 pub trait IntoHeaderNameSealed {
257 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError>;
258 }
259
260 impl IntoHeaderNameSealed for HeaderName {
261 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
262 Ok(self)
263 }
264 }
265
266 impl IntoHeaderNameSealed for Option<HeaderName> {
267 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
268 match self {
269 Some(name) => Ok(name),
270 None => Err(OpaqueError::from_display("Header name is required")),
271 }
272 }
273 }
274
275 impl IntoHeaderNameSealed for &str {
276 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
277 let name = self
278 .parse::<crate::HeaderName>()
279 .map_err(OpaqueError::from_std)?;
280 Ok(name)
281 }
282 }
283
284 impl IntoHeaderNameSealed for String {
285 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
286 self.as_str().into_header_name()
287 }
288 }
289
290 impl IntoHeaderNameSealed for &String {
291 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
292 self.as_str().into_header_name()
293 }
294 }
295
296 impl IntoHeaderNameSealed for &[u8] {
297 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
298 let name = crate::HeaderName::from_bytes(self).map_err(OpaqueError::from_std)?;
299 Ok(name)
300 }
301 }
302
303 pub trait IntoHeaderValueSealed {
304 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError>;
305 }
306
307 impl IntoHeaderValueSealed for crate::HeaderValue {
308 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
309 Ok(self)
310 }
311 }
312
313 impl IntoHeaderValueSealed for &str {
314 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
315 let value = self
316 .parse::<crate::HeaderValue>()
317 .map_err(OpaqueError::from_std)?;
318 Ok(value)
319 }
320 }
321
322 impl IntoHeaderValueSealed for String {
323 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
324 self.as_str().into_header_value()
325 }
326 }
327
328 impl IntoHeaderValueSealed for &String {
329 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
330 self.as_str().into_header_value()
331 }
332 }
333
334 impl IntoHeaderValueSealed for &[u8] {
335 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
336 let value = crate::HeaderValue::from_bytes(self).map_err(OpaqueError::from_std)?;
337 Ok(value)
338 }
339 }
340
341 pub trait HttpClientExtSealed<State> {}
342
343 impl<State, S, Body> HttpClientExtSealed<State> for S where
344 S: Service<State, Request, Response = Response<Body>, Error: Into<BoxError>>
345 {
346 }
347}
348
349pub struct RequestBuilder<'a, S, State, Response> {
353 http_client_service: &'a S,
354 state: RequestBuilderState,
355 _phantom: std::marker::PhantomData<fn(State, Response) -> ()>,
356}
357
358impl<S, State, Response> std::fmt::Debug for RequestBuilder<'_, S, State, Response>
359where
360 S: std::fmt::Debug,
361{
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.debug_struct("RequestBuilder")
364 .field("http_client_service", &self.http_client_service)
365 .field("state", &self.state)
366 .finish()
367 }
368}
369
370#[derive(Debug)]
371enum RequestBuilderState {
372 PreBody(crate::dep::http::request::Builder),
373 PostBody(crate::Request),
374 Error(OpaqueError),
375}
376
377impl<'a, S, State, Body> RequestBuilder<'a, S, State, Response<Body>>
378where
379 S: Service<State, Request, Response = Response<Body>, Error: Into<BoxError>>,
380{
381 pub fn header<K, V>(mut self, key: K, value: V) -> Self
383 where
384 K: IntoHeaderName,
385 V: IntoHeaderValue,
386 {
387 match self.state {
388 RequestBuilderState::PreBody(builder) => {
389 let key = match key.into_header_name() {
390 Ok(key) => key,
391 Err(err) => {
392 self.state = RequestBuilderState::Error(err);
393 return self;
394 }
395 };
396 let value = match value.into_header_value() {
397 Ok(value) => value,
398 Err(err) => {
399 self.state = RequestBuilderState::Error(err);
400 return self;
401 }
402 };
403 self.state = RequestBuilderState::PreBody(builder.header(key, value));
404 self
405 }
406 RequestBuilderState::PostBody(mut request) => {
407 let key = match key.into_header_name() {
408 Ok(key) => key,
409 Err(err) => {
410 self.state = RequestBuilderState::Error(err);
411 return self;
412 }
413 };
414 let value = match value.into_header_value() {
415 Ok(value) => value,
416 Err(err) => {
417 self.state = RequestBuilderState::Error(err);
418 return self;
419 }
420 };
421 request.headers_mut().append(key, value);
422 self.state = RequestBuilderState::PostBody(request);
423 self
424 }
425 RequestBuilderState::Error(err) => {
426 self.state = RequestBuilderState::Error(err);
427 self
428 }
429 }
430 }
431
432 pub fn typed_header<H>(self, header: H) -> Self
436 where
437 H: crate::headers::Header,
438 {
439 self.header(H::name().clone(), header.encode_to_value())
440 }
441
442 pub fn headers(mut self, headers: crate::HeaderMap) -> Self {
446 for (key, value) in headers.into_iter() {
447 self = self.header(key, value);
448 }
449 self
450 }
451
452 pub fn basic_auth<U, P>(self, username: U, password: P) -> Self
454 where
455 U: AsRef<str>,
456 P: AsRef<str>,
457 {
458 let header = crate::headers::Authorization::basic(username.as_ref(), password.as_ref());
459 self.typed_header(header)
460 }
461
462 pub fn bearer_auth<T>(mut self, token: T) -> Self
464 where
465 T: AsRef<str>,
466 {
467 let header = match crate::headers::Authorization::bearer(token.as_ref()) {
468 Ok(header) => header,
469 Err(err) => {
470 self.state = match self.state {
471 RequestBuilderState::Error(original_err) => {
472 RequestBuilderState::Error(original_err)
473 }
474 _ => RequestBuilderState::Error(OpaqueError::from_std(err)),
475 };
476 return self;
477 }
478 };
479
480 self.typed_header(header)
481 }
482
483 pub fn body<T>(mut self, body: T) -> Self
487 where
488 T: TryInto<crate::Body, Error: Into<BoxError>>,
489 {
490 self.state = match self.state {
491 RequestBuilderState::PreBody(builder) => match body.try_into() {
492 Ok(body) => match builder.body(body) {
493 Ok(req) => RequestBuilderState::PostBody(req),
494 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
495 },
496 Err(err) => RequestBuilderState::Error(OpaqueError::from_boxed(err.into())),
497 },
498 RequestBuilderState::PostBody(mut req) => match body.try_into() {
499 Ok(body) => {
500 *req.body_mut() = body;
501 RequestBuilderState::PostBody(req)
502 }
503 Err(err) => RequestBuilderState::Error(OpaqueError::from_boxed(err.into())),
504 },
505 RequestBuilderState::Error(err) => RequestBuilderState::Error(err),
506 };
507 self
508 }
509
510 pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
514 self.state = match self.state {
515 RequestBuilderState::PreBody(mut builder) => match serde_html_form::to_string(form) {
516 Ok(body) => {
517 let builder = match builder.headers_mut() {
518 Some(headers) => {
519 if !headers.contains_key(crate::header::CONTENT_TYPE) {
520 headers.insert(
521 crate::header::CONTENT_TYPE,
522 crate::HeaderValue::from_static(
523 "application/x-www-form-urlencoded",
524 ),
525 );
526 }
527 builder
528 }
529 None => builder.header(
530 crate::header::CONTENT_TYPE,
531 crate::HeaderValue::from_static("application/x-www-form-urlencoded"),
532 ),
533 };
534 match builder.body(body.into()) {
535 Ok(req) => RequestBuilderState::PostBody(req),
536 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
537 }
538 }
539 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
540 },
541 RequestBuilderState::PostBody(mut req) => match serde_html_form::to_string(form) {
542 Ok(body) => {
543 if !req.headers().contains_key(crate::header::CONTENT_TYPE) {
544 req.headers_mut().insert(
545 crate::header::CONTENT_TYPE,
546 crate::HeaderValue::from_static("application/x-www-form-urlencoded"),
547 );
548 }
549 *req.body_mut() = body.into();
550 RequestBuilderState::PostBody(req)
551 }
552 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
553 },
554 RequestBuilderState::Error(err) => RequestBuilderState::Error(err),
555 };
556 self
557 }
558
559 pub fn json<T: serde::Serialize + ?Sized>(mut self, json: &T) -> Self {
563 self.state = match self.state {
564 RequestBuilderState::PreBody(mut builder) => match serde_json::to_vec(json) {
565 Ok(body) => {
566 let builder = match builder.headers_mut() {
567 Some(headers) => {
568 if !headers.contains_key(crate::header::CONTENT_TYPE) {
569 headers.insert(
570 crate::header::CONTENT_TYPE,
571 crate::HeaderValue::from_static("application/json"),
572 );
573 }
574 builder
575 }
576 None => builder.header(
577 crate::header::CONTENT_TYPE,
578 crate::HeaderValue::from_static("application/json"),
579 ),
580 };
581 match builder.body(body.into()) {
582 Ok(req) => RequestBuilderState::PostBody(req),
583 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
584 }
585 }
586 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
587 },
588 RequestBuilderState::PostBody(mut req) => match serde_json::to_vec(json) {
589 Ok(body) => {
590 if !req.headers().contains_key(crate::header::CONTENT_TYPE) {
591 req.headers_mut().insert(
592 crate::header::CONTENT_TYPE,
593 crate::HeaderValue::from_static("application/json"),
594 );
595 }
596 *req.body_mut() = body.into();
597 RequestBuilderState::PostBody(req)
598 }
599 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
600 },
601 RequestBuilderState::Error(err) => RequestBuilderState::Error(err),
602 };
603 self
604 }
605
606 pub fn version(mut self, version: crate::Version) -> Self {
610 match self.state {
611 RequestBuilderState::PreBody(builder) => {
612 self.state = RequestBuilderState::PreBody(builder.version(version));
613 self
614 }
615 RequestBuilderState::PostBody(mut request) => {
616 *request.version_mut() = version;
617 self.state = RequestBuilderState::PostBody(request);
618 self
619 }
620 RequestBuilderState::Error(err) => {
621 self.state = RequestBuilderState::Error(err);
622 self
623 }
624 }
625 }
626
627 pub async fn send(self, ctx: Context<State>) -> Result<Response<Body>, OpaqueError> {
633 let request = match self.state {
634 RequestBuilderState::PreBody(builder) => builder
635 .body(crate::Body::empty())
636 .map_err(OpaqueError::from_std)?,
637 RequestBuilderState::PostBody(request) => request,
638 RequestBuilderState::Error(err) => return Err(err),
639 };
640
641 let uri = request.uri().clone();
642 match self.http_client_service.serve(ctx, request).await {
643 Ok(response) => Ok(response),
644 Err(err) => Err(OpaqueError::from_boxed(err.into()).context(uri.to_string())),
645 }
646 }
647}
648
#[cfg(test)]
mod test {
    use rama_http_types::StatusCode;

    use super::*;
    use crate::{
        layer::{
            required_header::AddRequiredRequestHeadersLayer,
            retry::{ManagedPolicy, RetryLayer},
            trace::TraceLayer,
        },
        IntoResponse,
    };
    use rama_core::{
        layer::{Layer, MapResultLayer},
        service::{service_fn, BoxService},
    };
    use rama_utils::backoff::ExponentialBackoff;
    use std::convert::Infallible;

    /// Boxed client service used throughout these tests.
    type HttpClient<S> = BoxService<S, Request, Response, BoxError>;

    /// Stand-in for a real connector. It asserts that the middleware stack added
    /// the default `User-Agent` header, then answers `200 OK`.
    async fn fake_client_fn<S, Body>(
        _ctx: Context<S>,
        request: Request<Body>,
    ) -> Result<Response, Infallible>
    where
        S: Clone + Send + Sync + 'static,
        Body: crate::dep::http_body::Body<Data: Send + 'static, Error: Send + 'static>
            + Send
            + 'static,
    {
        let expected_user_agent =
            format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION);
        let actual_user_agent = request
            .headers()
            .get(crate::header::USER_AGENT)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(actual_user_agent, expected_user_agent);

        Ok(StatusCode::OK.into_response())
    }

    /// Normalize the inner service result: box the response body and the error.
    fn map_internal_client_error<E, Body>(
        result: Result<Response<Body>, E>,
    ) -> Result<Response, BoxError>
    where
        E: Into<BoxError>,
        Body: crate::dep::http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>>
            + Send
            + Sync
            + 'static,
    {
        result
            .map(|response| response.map(crate::Body::new))
            .map_err(Into::into)
    }

    /// Assemble the client middleware stack around [`fake_client_fn`].
    fn client<S: Clone + Send + Sync + 'static>() -> HttpClient<S> {
        let layers = (
            MapResultLayer::new(map_internal_client_error),
            TraceLayer::new_for_http(),
        );

        #[cfg(feature = "compression")]
        let layers = (
            layers,
            crate::layer::decompression::DecompressionLayer::new(),
        );

        (
            layers,
            RetryLayer::new(ManagedPolicy::default().with_backoff(ExponentialBackoff::default())),
            AddRequiredRequestHeadersLayer::default(),
        )
            .layer(service_fn(fake_client_fn))
            .boxed()
    }

    #[tokio::test]
    async fn test_client_happy_path() {
        let response = client()
            .get("http://127.0.0.1:8080")
            .send(Context::default())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}