aws_smithy_runtime/client/
interceptors.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_smithy_runtime_api::box_error::BoxError;
7use aws_smithy_runtime_api::client::interceptors::context::{
8    BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut,
9    FinalizerInterceptorContextMut, FinalizerInterceptorContextRef,
10};
11use aws_smithy_runtime_api::client::interceptors::context::{
12    Error, Input, InterceptorContext, Output,
13};
14use aws_smithy_runtime_api::client::interceptors::{
15    Intercept, InterceptorError, SharedInterceptor,
16};
17use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
18use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
19use aws_smithy_types::body::SdkBody;
20use aws_smithy_types::config_bag::ConfigBag;
21use aws_smithy_types::error::display::DisplayErrorContext;
22use std::error::Error as StdError;
23use std::fmt;
24use std::marker::PhantomData;
25
/// Generates a `pub(crate)` method on `Interceptors` that runs the named hook on every
/// enabled interceptor.
///
/// - `mut $interceptor` generates a method taking `&mut InterceptorContext`; each hook
///   receives a mutable view converted from it.
/// - `ref $interceptor` generates a method taking `&InterceptorContext`; each hook receives
///   a shared view converted from it.
///
/// Every enabled interceptor runs even after an earlier one fails. Only the last error is
/// returned (wrapped with `InterceptorError::$interceptor`); each error it supersedes is
/// logged at `debug` level.
macro_rules! interceptor_impl_fn {
    (mut $interceptor:ident) => {
        pub(crate) fn $interceptor(
            self,
            ctx: &mut InterceptorContext,
            runtime_components: &RuntimeComponents,
            cfg: &mut ConfigBag,
        ) -> Result<(), InterceptorError> {
            tracing::trace!(concat!("running `", stringify!($interceptor), "` interceptors"));
            // The concrete view type is inferred from the hook's signature below.
            let mut view = ctx.into();
            let mut outcome: Result<(), (&str, BoxError)> = Ok(());
            for candidate in self.into_iter() {
                let enabled = match candidate.if_enabled(cfg) {
                    Some(enabled) => enabled,
                    None => continue,
                };
                let new_error = match enabled.$interceptor(&mut view, runtime_components, cfg) {
                    Ok(()) => continue,
                    Err(err) => err,
                };
                // Only the most recent failure is returned; log the one being replaced.
                if let Err((previous_name, previous_error)) = outcome {
                    tracing::debug!(
                        "{}::{}: {}",
                        previous_name,
                        stringify!($interceptor),
                        DisplayErrorContext(&*previous_error)
                    );
                }
                outcome = Err((enabled.name(), new_error));
            }
            outcome.map_err(|(name, err)| InterceptorError::$interceptor(name, err))
        }
    };
    (ref $interceptor:ident) => {
        pub(crate) fn $interceptor(
            self,
            ctx: &InterceptorContext,
            runtime_components: &RuntimeComponents,
            cfg: &mut ConfigBag,
        ) -> Result<(), InterceptorError> {
            tracing::trace!(concat!("running `", stringify!($interceptor), "` interceptors"));
            // The concrete view type is inferred from the hook's signature below.
            let view = ctx.into();
            let mut outcome: Result<(), (&str, BoxError)> = Ok(());
            for candidate in self.into_iter() {
                let enabled = match candidate.if_enabled(cfg) {
                    Some(enabled) => enabled,
                    None => continue,
                };
                let new_error = match enabled.$interceptor(&view, runtime_components, cfg) {
                    Ok(()) => continue,
                    Err(err) => err,
                };
                // Only the most recent failure is returned; log the one being replaced.
                if let Err((previous_name, previous_error)) = outcome {
                    tracing::debug!(
                        "{}::{}: {}",
                        previous_name,
                        stringify!($interceptor),
                        DisplayErrorContext(&*previous_error)
                    );
                }
                outcome = Err((enabled.name(), new_error));
            }
            outcome.map_err(|(name, err)| InterceptorError::$interceptor(name, err))
        }
    };
}
95
96#[derive(Debug)]
97pub(crate) struct Interceptors<I> {
98    interceptors: I,
99}
100
101impl<I> Interceptors<I>
102where
103    I: Iterator<Item = SharedInterceptor>,
104{
105    pub(crate) fn new(interceptors: I) -> Self {
106        Self { interceptors }
107    }
108
109    fn into_iter(self) -> impl Iterator<Item = ConditionallyEnabledInterceptor> {
110        self.interceptors.map(ConditionallyEnabledInterceptor)
111    }
112
113    pub(crate) fn read_before_execution(
114        self,
115        operation: bool,
116        ctx: &InterceptorContext<Input, Output, Error>,
117        cfg: &mut ConfigBag,
118    ) -> Result<(), InterceptorError> {
119        tracing::trace!(
120            "running {} `read_before_execution` interceptors",
121            if operation { "operation" } else { "client" }
122        );
123        let mut result: Result<(), (&str, BoxError)> = Ok(());
124        let ctx: BeforeSerializationInterceptorContextRef<'_> = ctx.into();
125        for interceptor in self.into_iter() {
126            if let Some(interceptor) = interceptor.if_enabled(cfg) {
127                if let Err(new_error) = interceptor.read_before_execution(&ctx, cfg) {
128                    if let Err(last_error) = result {
129                        tracing::debug!(
130                            "{}::{}: {}",
131                            last_error.0,
132                            "read_before_execution",
133                            DisplayErrorContext(&*last_error.1)
134                        );
135                    }
136                    result = Err((interceptor.name(), new_error));
137                }
138            }
139        }
140        result.map_err(|(name, err)| InterceptorError::read_before_execution(name, err))
141    }
142
143    interceptor_impl_fn!(mut modify_before_serialization);
144    interceptor_impl_fn!(ref read_before_serialization);
145    interceptor_impl_fn!(ref read_after_serialization);
146    interceptor_impl_fn!(mut modify_before_retry_loop);
147    interceptor_impl_fn!(ref read_before_attempt);
148    interceptor_impl_fn!(mut modify_before_signing);
149    interceptor_impl_fn!(ref read_before_signing);
150    interceptor_impl_fn!(ref read_after_signing);
151    interceptor_impl_fn!(mut modify_before_transmit);
152    interceptor_impl_fn!(ref read_before_transmit);
153    interceptor_impl_fn!(ref read_after_transmit);
154    interceptor_impl_fn!(mut modify_before_deserialization);
155    interceptor_impl_fn!(ref read_before_deserialization);
156    interceptor_impl_fn!(ref read_after_deserialization);
157
158    pub(crate) fn modify_before_attempt_completion(
159        self,
160        ctx: &mut InterceptorContext<Input, Output, Error>,
161        runtime_components: &RuntimeComponents,
162        cfg: &mut ConfigBag,
163    ) -> Result<(), InterceptorError> {
164        tracing::trace!("running `modify_before_attempt_completion` interceptors");
165        let mut result: Result<(), (&str, BoxError)> = Ok(());
166        let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
167        for interceptor in self.into_iter() {
168            if let Some(interceptor) = interceptor.if_enabled(cfg) {
169                if let Err(new_error) =
170                    interceptor.modify_before_attempt_completion(&mut ctx, runtime_components, cfg)
171                {
172                    if let Err(last_error) = result {
173                        tracing::debug!(
174                            "{}::{}: {}",
175                            last_error.0,
176                            "modify_before_attempt_completion",
177                            DisplayErrorContext(&*last_error.1)
178                        );
179                    }
180                    result = Err((interceptor.name(), new_error));
181                }
182            }
183        }
184        result.map_err(|(name, err)| InterceptorError::modify_before_attempt_completion(name, err))
185    }
186
187    pub(crate) fn read_after_attempt(
188        self,
189        ctx: &InterceptorContext<Input, Output, Error>,
190        runtime_components: &RuntimeComponents,
191        cfg: &mut ConfigBag,
192    ) -> Result<(), InterceptorError> {
193        tracing::trace!("running `read_after_attempt` interceptors");
194        let mut result: Result<(), (&str, BoxError)> = Ok(());
195        let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
196        for interceptor in self.into_iter() {
197            if let Some(interceptor) = interceptor.if_enabled(cfg) {
198                if let Err(new_error) =
199                    interceptor.read_after_attempt(&ctx, runtime_components, cfg)
200                {
201                    if let Err(last_error) = result {
202                        tracing::debug!(
203                            "{}::{}: {}",
204                            last_error.0,
205                            "read_after_attempt",
206                            DisplayErrorContext(&*last_error.1)
207                        );
208                    }
209                    result = Err((interceptor.name(), new_error));
210                }
211            }
212        }
213        result.map_err(|(name, err)| InterceptorError::read_after_attempt(name, err))
214    }
215
216    pub(crate) fn modify_before_completion(
217        self,
218        ctx: &mut InterceptorContext<Input, Output, Error>,
219        runtime_components: &RuntimeComponents,
220        cfg: &mut ConfigBag,
221    ) -> Result<(), InterceptorError> {
222        tracing::trace!("running `modify_before_completion` interceptors");
223        let mut result: Result<(), (&str, BoxError)> = Ok(());
224        let mut ctx: FinalizerInterceptorContextMut<'_> = ctx.into();
225        for interceptor in self.into_iter() {
226            if let Some(interceptor) = interceptor.if_enabled(cfg) {
227                if let Err(new_error) =
228                    interceptor.modify_before_completion(&mut ctx, runtime_components, cfg)
229                {
230                    if let Err(last_error) = result {
231                        tracing::debug!(
232                            "{}::{}: {}",
233                            last_error.0,
234                            "modify_before_completion",
235                            DisplayErrorContext(&*last_error.1)
236                        );
237                    }
238                    result = Err((interceptor.name(), new_error));
239                }
240            }
241        }
242        result.map_err(|(name, err)| InterceptorError::modify_before_completion(name, err))
243    }
244
245    pub(crate) fn read_after_execution(
246        self,
247        ctx: &InterceptorContext<Input, Output, Error>,
248        runtime_components: &RuntimeComponents,
249        cfg: &mut ConfigBag,
250    ) -> Result<(), InterceptorError> {
251        tracing::trace!("running `read_after_execution` interceptors");
252        let mut result: Result<(), (&str, BoxError)> = Ok(());
253        let ctx: FinalizerInterceptorContextRef<'_> = ctx.into();
254        for interceptor in self.into_iter() {
255            if let Some(interceptor) = interceptor.if_enabled(cfg) {
256                if let Err(new_error) =
257                    interceptor.read_after_execution(&ctx, runtime_components, cfg)
258                {
259                    if let Err(last_error) = result {
260                        tracing::debug!(
261                            "{}::{}: {}",
262                            last_error.0,
263                            "read_after_execution",
264                            DisplayErrorContext(&*last_error.1)
265                        );
266                    }
267                    result = Err((interceptor.name(), new_error));
268                }
269            }
270        }
271        result.map_err(|(name, err)| InterceptorError::read_after_execution(name, err))
272    }
273}
274
275/// A interceptor wrapper to conditionally enable the interceptor based on
276/// [`DisableInterceptor`](aws_smithy_runtime_api::client::interceptors::DisableInterceptor)
277struct ConditionallyEnabledInterceptor(SharedInterceptor);
278impl ConditionallyEnabledInterceptor {
279    fn if_enabled(&self, cfg: &ConfigBag) -> Option<&dyn Intercept> {
280        if self.0.enabled(cfg) {
281            Some(&self.0)
282        } else {
283            None
284        }
285    }
286}
287
288/// Interceptor that maps the request with a given function.
289pub struct MapRequestInterceptor<F, E> {
290    f: F,
291    _phantom: PhantomData<E>,
292}
293
294impl<F, E> fmt::Debug for MapRequestInterceptor<F, E> {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        write!(f, "MapRequestInterceptor")
297    }
298}
299
300impl<F, E> MapRequestInterceptor<F, E> {
301    /// Creates a new `MapRequestInterceptor`.
302    pub fn new(f: F) -> Self {
303        Self {
304            f,
305            _phantom: PhantomData,
306        }
307    }
308}
309
310impl<F, E> Intercept for MapRequestInterceptor<F, E>
311where
312    F: Fn(HttpRequest) -> Result<HttpRequest, E> + Send + Sync + 'static,
313    E: StdError + Send + Sync + 'static,
314{
315    fn name(&self) -> &'static str {
316        "MapRequestInterceptor"
317    }
318
319    fn modify_before_signing(
320        &self,
321        context: &mut BeforeTransmitInterceptorContextMut<'_>,
322        _runtime_components: &RuntimeComponents,
323        _cfg: &mut ConfigBag,
324    ) -> Result<(), BoxError> {
325        let mut request = HttpRequest::new(SdkBody::taken());
326        std::mem::swap(&mut request, context.request_mut());
327        let mut mapped = (self.f)(request)?;
328        std::mem::swap(&mut mapped, context.request_mut());
329
330        Ok(())
331    }
332}
333
334/// Interceptor that mutates the request with a given function.
335pub struct MutateRequestInterceptor<F> {
336    f: F,
337}
338
339impl<F> fmt::Debug for MutateRequestInterceptor<F> {
340    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
341        write!(f, "MutateRequestInterceptor")
342    }
343}
344
345impl<F> MutateRequestInterceptor<F> {
346    /// Creates a new `MutateRequestInterceptor`.
347    pub fn new(f: F) -> Self {
348        Self { f }
349    }
350}
351
352impl<F> Intercept for MutateRequestInterceptor<F>
353where
354    F: Fn(&mut HttpRequest) + Send + Sync + 'static,
355{
356    fn name(&self) -> &'static str {
357        "MutateRequestInterceptor"
358    }
359
360    fn modify_before_signing(
361        &self,
362        context: &mut BeforeTransmitInterceptorContextMut<'_>,
363        _runtime_components: &RuntimeComponents,
364        _cfg: &mut ConfigBag,
365    ) -> Result<(), BoxError> {
366        let request = context.request_mut();
367        (self.f)(request);
368
369        Ok(())
370    }
371}
372
#[cfg(all(test, feature = "test-util"))]
mod tests {
    use super::*;
    use aws_smithy_runtime_api::box_error::BoxError;
    use aws_smithy_runtime_api::client::interceptors::context::{
        BeforeTransmitInterceptorContextRef, Input, InterceptorContext,
    };
    use aws_smithy_runtime_api::client::interceptors::{
        disable_interceptor, Intercept, SharedInterceptor,
    };
    use aws_smithy_runtime_api::client::runtime_components::{
        RuntimeComponents, RuntimeComponentsBuilder,
    };
    use aws_smithy_types::config_bag::ConfigBag;

    /// An interceptor that only overrides `name`, leaving every hook at its default.
    #[derive(Debug)]
    struct TestInterceptor;

    impl Intercept for TestInterceptor {
        fn name(&self) -> &'static str {
            "TestInterceptor"
        }
    }

    #[test]
    fn test_disable_interceptors() {
        // Fails `read_before_transmit` with a "boom" error (despite its name, it never panics).
        #[derive(Debug)]
        struct PanicInterceptor;

        impl Intercept for PanicInterceptor {
            fn name(&self) -> &'static str {
                "PanicInterceptor"
            }

            fn read_before_transmit(
                &self,
                _context: &BeforeTransmitInterceptorContextRef<'_>,
                _rc: &RuntimeComponents,
                _cfg: &mut ConfigBag,
            ) -> Result<(), BoxError> {
                Err("boom".into())
            }
        }

        let rc = RuntimeComponentsBuilder::for_tests()
            .with_interceptor(SharedInterceptor::new(PanicInterceptor))
            .with_interceptor(SharedInterceptor::new(TestInterceptor))
            .build()
            .unwrap();

        // Counts how many of `rc`'s interceptors are enabled under `cfg`.
        let enabled_count = |cfg: &ConfigBag| {
            Interceptors::new(rc.interceptors())
                .into_iter()
                .filter(|candidate| candidate.if_enabled(cfg).is_some())
                .count()
        };
        // Runs `read_before_transmit` across all of `rc`'s interceptors.
        let run_read_before_transmit = |cfg: &mut ConfigBag| {
            Interceptors::new(rc.interceptors()).read_before_transmit(
                &InterceptorContext::new(Input::doesnt_matter()),
                &rc,
                cfg,
            )
        };

        let mut cfg = ConfigBag::base();
        assert_eq!(enabled_count(&cfg), 2);
        run_read_before_transmit(&mut cfg).expect_err("interceptor returns error");

        cfg.interceptor_state()
            .store_put(disable_interceptor::<PanicInterceptor>("test"));
        assert_eq!(enabled_count(&cfg), 1);
        // shouldn't error because interceptors won't run
        run_read_before_transmit(&mut cfg).expect("interceptor is now disabled");
    }
}