reqwest_middleware/
client.rs

1use http::Extensions;
2use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
3use reqwest::{Body, Client, IntoUrl, Method, Request, Response};
4use serde::Serialize;
5use std::convert::TryFrom;
6use std::fmt::{self, Display};
7use std::sync::Arc;
8
9#[cfg(feature = "multipart")]
10use reqwest::multipart;
11
12use crate::error::Result;
13use crate::middleware::{Middleware, Next};
14use crate::RequestInitialiser;
15
16/// A `ClientBuilder` is used to build a [`ClientWithMiddleware`].
17///
18/// [`ClientWithMiddleware`]: crate::ClientWithMiddleware
19pub struct ClientBuilder {
20    client: Client,
21    middleware_stack: Vec<Arc<dyn Middleware>>,
22    initialiser_stack: Vec<Arc<dyn RequestInitialiser>>,
23}
24
25impl ClientBuilder {
26    pub fn new(client: Client) -> Self {
27        ClientBuilder {
28            client,
29            middleware_stack: Vec::new(),
30            initialiser_stack: Vec::new(),
31        }
32    }
33
34    /// This method allows creating a ClientBuilder
35    /// from an existing ClientWithMiddleware instance
36    pub fn from_client(client_with_middleware: ClientWithMiddleware) -> Self {
37        Self {
38            client: client_with_middleware.inner,
39            middleware_stack: client_with_middleware.middleware_stack.into_vec(),
40            initialiser_stack: client_with_middleware.initialiser_stack.into_vec(),
41        }
42    }
43
44    /// Convenience method to attach middleware.
45    ///
46    /// If you need to keep a reference to the middleware after attaching, use [`with_arc`].
47    ///
48    /// [`with_arc`]: Self::with_arc
49    pub fn with<M>(self, middleware: M) -> Self
50    where
51        M: Middleware,
52    {
53        self.with_arc(Arc::new(middleware))
54    }
55
56    /// Add middleware to the chain. [`with`] is more ergonomic if you don't need the `Arc`.
57    ///
58    /// [`with`]: Self::with
59    pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
60        self.middleware_stack.push(middleware);
61        self
62    }
63
64    /// Convenience method to attach a request initialiser.
65    ///
66    /// If you need to keep a reference to the initialiser after attaching, use [`with_arc_init`].
67    ///
68    /// [`with_arc_init`]: Self::with_arc_init
69    pub fn with_init<I>(self, initialiser: I) -> Self
70    where
71        I: RequestInitialiser,
72    {
73        self.with_arc_init(Arc::new(initialiser))
74    }
75
76    /// Add a request initialiser to the chain. [`with_init`] is more ergonomic if you don't need the `Arc`.
77    ///
78    /// [`with_init`]: Self::with_init
79    pub fn with_arc_init(mut self, initialiser: Arc<dyn RequestInitialiser>) -> Self {
80        self.initialiser_stack.push(initialiser);
81        self
82    }
83
84    /// Returns a `ClientWithMiddleware` using this builder configuration.
85    pub fn build(self) -> ClientWithMiddleware {
86        ClientWithMiddleware {
87            inner: self.client,
88            middleware_stack: self.middleware_stack.into_boxed_slice(),
89            initialiser_stack: self.initialiser_stack.into_boxed_slice(),
90        }
91    }
92}
93
94/// `ClientWithMiddleware` is a wrapper around [`reqwest::Client`] which runs middleware on every
95/// request.
96#[derive(Clone, Default)]
97pub struct ClientWithMiddleware {
98    inner: reqwest::Client,
99    middleware_stack: Box<[Arc<dyn Middleware>]>,
100    initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
101}
102
103impl ClientWithMiddleware {
104    /// See [`ClientBuilder`] for a more ergonomic way to build `ClientWithMiddleware` instances.
105    pub fn new<T>(client: Client, middleware_stack: T) -> Self
106    where
107        T: Into<Box<[Arc<dyn Middleware>]>>,
108    {
109        ClientWithMiddleware {
110            inner: client,
111            middleware_stack: middleware_stack.into(),
112            // TODO(conradludgate) - allow downstream code to control this manually if desired
113            initialiser_stack: Box::new([]),
114        }
115    }
116
117    /// Convenience method to make a `GET` request to a URL.
118    ///
119    /// # Errors
120    ///
121    /// This method fails whenever the supplied `Url` cannot be parsed.
122    pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
123        self.request(Method::GET, url)
124    }
125
126    /// Convenience method to make a `POST` request to a URL.
127    ///
128    /// # Errors
129    ///
130    /// This method fails whenever the supplied `Url` cannot be parsed.
131    pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
132        self.request(Method::POST, url)
133    }
134
135    /// Convenience method to make a `PUT` request to a URL.
136    ///
137    /// # Errors
138    ///
139    /// This method fails whenever the supplied `Url` cannot be parsed.
140    pub fn put<U: IntoUrl>(&self, url: U) -> RequestBuilder {
141        self.request(Method::PUT, url)
142    }
143
144    /// Convenience method to make a `PATCH` request to a URL.
145    ///
146    /// # Errors
147    ///
148    /// This method fails whenever the supplied `Url` cannot be parsed.
149    pub fn patch<U: IntoUrl>(&self, url: U) -> RequestBuilder {
150        self.request(Method::PATCH, url)
151    }
152
153    /// Convenience method to make a `DELETE` request to a URL.
154    ///
155    /// # Errors
156    ///
157    /// This method fails whenever the supplied `Url` cannot be parsed.
158    pub fn delete<U: IntoUrl>(&self, url: U) -> RequestBuilder {
159        self.request(Method::DELETE, url)
160    }
161
162    /// Convenience method to make a `HEAD` request to a URL.
163    ///
164    /// # Errors
165    ///
166    /// This method fails whenever the supplied `Url` cannot be parsed.
167    pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
168        self.request(Method::HEAD, url)
169    }
170
171    /// Start building a `Request` with the `Method` and `Url`.
172    ///
173    /// Returns a `RequestBuilder`, which will allow setting headers and
174    /// the request body before sending.
175    ///
176    /// # Errors
177    ///
178    /// This method fails whenever the supplied `Url` cannot be parsed.
179    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
180        let req = RequestBuilder {
181            inner: self.inner.request(method, url),
182            extensions: Extensions::new(),
183            middleware_stack: self.middleware_stack.clone(),
184            initialiser_stack: self.initialiser_stack.clone(),
185        };
186        self.initialiser_stack
187            .iter()
188            .fold(req, |req, i| i.init(req))
189    }
190
191    /// Executes a `Request`.
192    ///
193    /// A `Request` can be built manually with `Request::new()` or obtained
194    /// from a RequestBuilder with `RequestBuilder::build()`.
195    ///
196    /// You should prefer to use the `RequestBuilder` and
197    /// `RequestBuilder::send()`.
198    ///
199    /// # Errors
200    ///
201    /// This method fails if there was an error while sending request,
202    /// redirect loop was detected or redirect limit was exhausted.
203    pub async fn execute(&self, req: Request) -> Result<Response> {
204        let mut ext = Extensions::new();
205        self.execute_with_extensions(req, &mut ext).await
206    }
207
208    /// Executes a `Request` with initial [`Extensions`].
209    ///
210    /// A `Request` can be built manually with `Request::new()` or obtained
211    /// from a RequestBuilder with `RequestBuilder::build()`.
212    ///
213    /// You should prefer to use the `RequestBuilder` and
214    /// `RequestBuilder::send()`.
215    ///
216    /// # Errors
217    ///
218    /// This method fails if there was an error while sending request,
219    /// redirect loop was detected or redirect limit was exhausted.
220    pub async fn execute_with_extensions(
221        &self,
222        req: Request,
223        ext: &mut Extensions,
224    ) -> Result<Response> {
225        let next = Next::new(&self.inner, &self.middleware_stack);
226        next.run(req, ext).await
227    }
228}
229
230/// Create a `ClientWithMiddleware` without any middleware.
231impl From<Client> for ClientWithMiddleware {
232    fn from(client: Client) -> Self {
233        ClientWithMiddleware {
234            inner: client,
235            middleware_stack: Box::new([]),
236            initialiser_stack: Box::new([]),
237        }
238    }
239}
240
241impl fmt::Debug for ClientWithMiddleware {
242    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
243        // skipping middleware_stack field for now
244        f.debug_struct("ClientWithMiddleware")
245            .field("inner", &self.inner)
246            .finish_non_exhaustive()
247    }
248}
249
250#[cfg(not(target_arch = "wasm32"))]
251mod service {
252    use std::{
253        future::Future,
254        pin::Pin,
255        task::{Context, Poll},
256    };
257
258    use crate::Result;
259    use http::Extensions;
260    use reqwest::{Request, Response};
261
262    use crate::{middleware::BoxFuture, ClientWithMiddleware, Next};
263
264    // this is meant to be semi-private, same as reqwest's pending
265    pub struct Pending {
266        inner: BoxFuture<'static, Result<Response>>,
267    }
268
269    impl Unpin for Pending {}
270
271    impl Future for Pending {
272        type Output = Result<Response>;
273
274        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
275            self.inner.as_mut().poll(cx)
276        }
277    }
278
279    impl tower_service::Service<Request> for ClientWithMiddleware {
280        type Response = Response;
281        type Error = crate::Error;
282        type Future = Pending;
283
284        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
285            self.inner.poll_ready(cx).map_err(crate::Error::Reqwest)
286        }
287
288        fn call(&mut self, req: Request) -> Self::Future {
289            let inner = self.inner.clone();
290            let middlewares = self.middleware_stack.clone();
291            let mut extensions = Extensions::new();
292            Pending {
293                inner: Box::pin(async move {
294                    let next = Next::new(&inner, &middlewares);
295                    next.run(req, &mut extensions).await
296                }),
297            }
298        }
299    }
300
301    impl tower_service::Service<Request> for &'_ ClientWithMiddleware {
302        type Response = Response;
303        type Error = crate::Error;
304        type Future = Pending;
305
306        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
307            (&self.inner).poll_ready(cx).map_err(crate::Error::Reqwest)
308        }
309
310        fn call(&mut self, req: Request) -> Self::Future {
311            let inner = self.inner.clone();
312            let middlewares = self.middleware_stack.clone();
313            let mut extensions = Extensions::new();
314            Pending {
315                inner: Box::pin(async move {
316                    let next = Next::new(&inner, &middlewares);
317                    next.run(req, &mut extensions).await
318                }),
319            }
320        }
321    }
322}
323
324/// This is a wrapper around [`reqwest::RequestBuilder`] exposing the same API.
325#[must_use = "RequestBuilder does nothing until you 'send' it"]
326pub struct RequestBuilder {
327    inner: reqwest::RequestBuilder,
328    middleware_stack: Box<[Arc<dyn Middleware>]>,
329    initialiser_stack: Box<[Arc<dyn RequestInitialiser>]>,
330    extensions: Extensions,
331}
332
333impl RequestBuilder {
334    /// Assemble a builder starting from an existing `Client` and a `Request`.
335    pub fn from_parts(client: ClientWithMiddleware, request: Request) -> RequestBuilder {
336        let inner = reqwest::RequestBuilder::from_parts(client.inner, request);
337        RequestBuilder {
338            inner,
339            middleware_stack: client.middleware_stack,
340            initialiser_stack: client.initialiser_stack,
341            extensions: Extensions::new(),
342        }
343    }
344
345    /// Add a `Header` to this Request.
346    pub fn header<K, V>(self, key: K, value: V) -> Self
347    where
348        HeaderName: TryFrom<K>,
349        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
350        HeaderValue: TryFrom<V>,
351        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
352    {
353        RequestBuilder {
354            inner: self.inner.header(key, value),
355            ..self
356        }
357    }
358
359    /// Add a set of Headers to the existing ones on this Request.
360    ///
361    /// The headers will be merged in to any already set.
362    pub fn headers(self, headers: HeaderMap) -> Self {
363        RequestBuilder {
364            inner: self.inner.headers(headers),
365            ..self
366        }
367    }
368
369    #[cfg(not(target_arch = "wasm32"))]
370    pub fn version(self, version: reqwest::Version) -> Self {
371        RequestBuilder {
372            inner: self.inner.version(version),
373            ..self
374        }
375    }
376
377    /// Enable HTTP basic authentication.
378    ///
379    /// ```rust
380    /// # use anyhow::Error;
381    ///
382    /// # async fn run() -> Result<(), Error> {
383    /// let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
384    /// let resp = client.delete("http://httpbin.org/delete")
385    ///     .basic_auth("admin", Some("good password"))
386    ///     .send()
387    ///     .await?;
388    /// # Ok(())
389    /// # }
390    /// ```
391    pub fn basic_auth<U, P>(self, username: U, password: Option<P>) -> Self
392    where
393        U: Display,
394        P: Display,
395    {
396        RequestBuilder {
397            inner: self.inner.basic_auth(username, password),
398            ..self
399        }
400    }
401
402    /// Enable HTTP bearer authentication.
403    pub fn bearer_auth<T>(self, token: T) -> Self
404    where
405        T: Display,
406    {
407        RequestBuilder {
408            inner: self.inner.bearer_auth(token),
409            ..self
410        }
411    }
412
413    /// Set the request body.
414    pub fn body<T: Into<Body>>(self, body: T) -> Self {
415        RequestBuilder {
416            inner: self.inner.body(body),
417            ..self
418        }
419    }
420
421    /// Enables a request timeout.
422    ///
423    /// The timeout is applied from when the request starts connecting until the
424    /// response body has finished. It affects only this request and overrides
425    /// the timeout configured using `ClientBuilder::timeout()`.
426    #[cfg(not(target_arch = "wasm32"))]
427    pub fn timeout(self, timeout: std::time::Duration) -> Self {
428        RequestBuilder {
429            inner: self.inner.timeout(timeout),
430            ..self
431        }
432    }
433
434    #[cfg(feature = "multipart")]
435    #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
436    pub fn multipart(self, multipart: multipart::Form) -> Self {
437        RequestBuilder {
438            inner: self.inner.multipart(multipart),
439            ..self
440        }
441    }
442
443    /// Modify the query string of the URL.
444    ///
445    /// Modifies the URL of this request, adding the parameters provided.
446    /// This method appends and does not overwrite. This means that it can
447    /// be called multiple times and that existing query parameters are not
448    /// overwritten if the same key is used. The key will simply show up
449    /// twice in the query string.
450    /// Calling `.query(&[("foo", "a"), ("foo", "b")])` gives `"foo=a&foo=b"`.
451    ///
452    /// # Note
453    /// This method does not support serializing a single key-value
454    /// pair. Instead of using `.query(("key", "val"))`, use a sequence, such
455    /// as `.query(&[("key", "val")])`. It's also possible to serialize structs
456    /// and maps into a key-value pair.
457    ///
458    /// # Errors
459    /// This method will fail if the object you provide cannot be serialized
460    /// into a query string.
461    pub fn query<T: Serialize + ?Sized>(self, query: &T) -> Self {
462        RequestBuilder {
463            inner: self.inner.query(query),
464            ..self
465        }
466    }
467
468    /// Send a form body.
469    ///
470    /// Sets the body to the url encoded serialization of the passed value,
471    /// and also sets the `Content-Type: application/x-www-form-urlencoded`
472    /// header.
473    ///
474    /// ```rust
475    /// # use anyhow::Error;
476    /// # use std::collections::HashMap;
477    /// #
478    /// # async fn run() -> Result<(), Error> {
479    /// let mut params = HashMap::new();
480    /// params.insert("lang", "rust");
481    ///
482    /// let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
483    /// let res = client.post("http://httpbin.org")
484    ///     .form(&params)
485    ///     .send()
486    ///     .await?;
487    /// # Ok(())
488    /// # }
489    /// ```
490    ///
491    /// # Errors
492    ///
493    /// This method fails if the passed value cannot be serialized into
494    /// url encoded format
495    pub fn form<T: Serialize + ?Sized>(self, form: &T) -> Self {
496        RequestBuilder {
497            inner: self.inner.form(form),
498            ..self
499        }
500    }
501
502    /// Send a JSON body.
503    ///
504    /// # Optional
505    ///
506    /// This requires the optional `json` feature enabled.
507    ///
508    /// # Errors
509    ///
510    /// Serialization can fail if `T`'s implementation of `Serialize` decides to
511    /// fail, or if `T` contains a map with non-string keys.
512    #[cfg(feature = "json")]
513    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
514    pub fn json<T: Serialize + ?Sized>(self, json: &T) -> Self {
515        RequestBuilder {
516            inner: self.inner.json(json),
517            ..self
518        }
519    }
520
521    /// Disable CORS on fetching the request.
522    ///
523    /// # WASM
524    ///
525    /// This option is only effective with WebAssembly target.
526    ///
527    /// The [request mode][mdn] will be set to 'no-cors'.
528    ///
529    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/API/Request/mode
530    pub fn fetch_mode_no_cors(self) -> Self {
531        RequestBuilder {
532            inner: self.inner.fetch_mode_no_cors(),
533            ..self
534        }
535    }
536
537    /// Build a `Request`, which can be inspected, modified and executed with
538    /// `ClientWithMiddleware::execute()`.
539    pub fn build(self) -> reqwest::Result<Request> {
540        self.inner.build()
541    }
542
543    /// Build a `Request`, which can be inspected, modified and executed with
544    /// `ClientWithMiddleware::execute()`.
545    ///
546    /// This is similar to [`RequestBuilder::build()`], but also returns the
547    /// embedded `Client`.
548    pub fn build_split(self) -> (ClientWithMiddleware, reqwest::Result<Request>) {
549        let Self {
550            inner,
551            middleware_stack,
552            initialiser_stack,
553            ..
554        } = self;
555        let (inner, req) = inner.build_split();
556        let client = ClientWithMiddleware {
557            inner,
558            middleware_stack,
559            initialiser_stack,
560        };
561        (client, req)
562    }
563
564    /// Inserts the extension into this request builder
565    pub fn with_extension<T: Send + Sync + Clone + 'static>(mut self, extension: T) -> Self {
566        self.extensions.insert(extension);
567        self
568    }
569
570    /// Returns a mutable reference to the internal set of extensions for this request
571    pub fn extensions(&mut self) -> &mut Extensions {
572        &mut self.extensions
573    }
574
575    /// Constructs the Request and sends it to the target URL, returning a
576    /// future Response.
577    ///
578    /// # Errors
579    ///
580    /// This method fails if there was an error while sending request,
581    /// redirect loop was detected or redirect limit was exhausted.
582    ///
583    /// # Example
584    ///
585    /// ```no_run
586    /// # use anyhow::Error;
587    /// #
588    /// # async fn run() -> Result<(), Error> {
589    /// let response = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new())
590    ///     .get("https://hyper.rs")
591    ///     .send()
592    ///     .await?;
593    /// # Ok(())
594    /// # }
595    /// ```
596    pub async fn send(mut self) -> Result<Response> {
597        let mut extensions = std::mem::take(self.extensions());
598        let (client, req) = self.build_split();
599        client.execute_with_extensions(req?, &mut extensions).await
600    }
601
602    /// Attempt to clone the RequestBuilder.
603    ///
604    /// `None` is returned if the RequestBuilder can not be cloned,
605    /// i.e. if the request body is a stream.
606    ///
607    /// # Examples
608    ///
609    /// ```
610    /// # use reqwest::Error;
611    /// #
612    /// # fn run() -> Result<(), Error> {
613    /// let client = reqwest_middleware::ClientWithMiddleware::from(reqwest::Client::new());
614    /// let builder = client.post("http://httpbin.org/post")
615    ///     .body("from a &str!");
616    /// let clone = builder.try_clone();
617    /// assert!(clone.is_some());
618    /// # Ok(())
619    /// # }
620    /// ```
621    pub fn try_clone(&self) -> Option<Self> {
622        self.inner.try_clone().map(|inner| RequestBuilder {
623            inner,
624            middleware_stack: self.middleware_stack.clone(),
625            initialiser_stack: self.initialiser_stack.clone(),
626            extensions: self.extensions.clone(),
627        })
628    }
629}
630
631impl fmt::Debug for RequestBuilder {
632    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
633        // skipping middleware_stack field for now
634        f.debug_struct("RequestBuilder")
635            .field("inner", &self.inner)
636            .finish_non_exhaustive()
637    }
638}