aws_config/meta/credentials/
chain.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_credential_types::{
7    provider::{self, error::CredentialsError, future, ProvideCredentials},
8    Credentials,
9};
10use aws_smithy_types::error::display::DisplayErrorContext;
11use std::borrow::Cow;
12use std::fmt::Debug;
13use tracing::Instrument;
14
15/// Credentials provider that checks a series of inner providers
16///
17/// Each provider will be evaluated in order:
18/// * If a provider returns valid [`Credentials`] they will be returned immediately.
19///   No other credential providers will be used.
20/// * Otherwise, if a provider returns [`CredentialsError::CredentialsNotLoaded`], the next provider will be checked.
21/// * Finally, if a provider returns any other error condition, an error will be returned immediately.
22///
23/// # Examples
24///
25/// ```no_run
26/// # fn example() {
27/// use aws_config::meta::credentials::CredentialsProviderChain;
28/// use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider;
29/// use aws_config::profile::ProfileFileCredentialsProvider;
30///
31/// let provider = CredentialsProviderChain::first_try("Environment", EnvironmentVariableCredentialsProvider::new())
32///     .or_else("Profile", ProfileFileCredentialsProvider::builder().build());
33/// # }
34/// ```
35pub struct CredentialsProviderChain {
36    providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
37}
38
39impl Debug for CredentialsProviderChain {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("CredentialsProviderChain")
42            .field(
43                "providers",
44                &self
45                    .providers
46                    .iter()
47                    .map(|provider| &provider.0)
48                    .collect::<Vec<&Cow<'static, str>>>(),
49            )
50            .finish()
51    }
52}
53
54impl CredentialsProviderChain {
55    /// Create a `CredentialsProviderChain` that begins by evaluating this provider
56    pub fn first_try(
57        name: impl Into<Cow<'static, str>>,
58        provider: impl ProvideCredentials + 'static,
59    ) -> Self {
60        CredentialsProviderChain {
61            providers: vec![(name.into(), Box::new(provider))],
62        }
63    }
64
65    /// Add a fallback provider to the credentials provider chain
66    pub fn or_else(
67        mut self,
68        name: impl Into<Cow<'static, str>>,
69        provider: impl ProvideCredentials + 'static,
70    ) -> Self {
71        self.providers.push((name.into(), Box::new(provider)));
72        self
73    }
74
75    /// Add a fallback to the default provider chain
76    #[cfg(feature = "rustls")]
77    pub async fn or_default_provider(self) -> Self {
78        self.or_else(
79            "DefaultProviderChain",
80            crate::default_provider::credentials::default_provider().await,
81        )
82    }
83
84    /// Creates a credential provider chain that starts with the default provider
85    #[cfg(feature = "rustls")]
86    pub async fn default_provider() -> Self {
87        Self::first_try(
88            "DefaultProviderChain",
89            crate::default_provider::credentials::default_provider().await,
90        )
91    }
92
93    async fn credentials(&self) -> provider::Result {
94        for (name, provider) in &self.providers {
95            let span = tracing::debug_span!("load_credentials", provider = %name);
96            match provider.provide_credentials().instrument(span).await {
97                Ok(credentials) => {
98                    tracing::debug!(provider = %name, "loaded credentials");
99                    return Ok(credentials);
100                }
101                Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
102                    tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
103                }
104                Err(err) => {
105                    tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
106                    return Err(err);
107                }
108            }
109        }
110        Err(CredentialsError::not_loaded(
111            "no providers in chain provided credentials",
112        ))
113    }
114}
115
116impl ProvideCredentials for CredentialsProviderChain {
117    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
118    where
119        Self: 'a,
120    {
121        future::ProvideCredentials::new(self.credentials())
122    }
123
124    fn fallback_on_interrupt(&self) -> Option<Credentials> {
125        for (_, provider) in &self.providers {
126            match provider.fallback_on_interrupt() {
127                creds @ Some(_) => return creds,
128                None => {}
129            }
130        }
131        None
132    }
133}
134
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// Test provider that takes 200ms to load its credentials but can hand out the very same
    /// credentials immediately through `fallback_on_interrupt`.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Builds the chain shared by both tests: "provider1" spends 200ms before reporting that
    /// it has no credentials, and "provider2" is a [`FallbackCredentials`].
    fn slow_two_provider_chain() -> CredentialsProviderChain {
        let never_loads = provide_credentials_fn(|| async {
            tokio::time::sleep(Duration::from_millis(200)).await;
            Err(CredentialsError::not_loaded(
                "no providers in chain provided credentials",
            ))
        });

        CredentialsProviderChain::first_try("provider1", never_loads)
            .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Loads credentials once without interruption, then interrupts a second load after
    /// `timeout_after` and checks that `fallback_on_interrupt` yields the same credentials.
    async fn assert_fallback_matches_after_timeout(timeout_after: Duration) {
        let chain = slow_two_provider_chain();

        // Let the first call to `provide_credentials` succeed.
        let expected = chain.provide_credentials().await.unwrap();

        // Let the second call fail with an external timeout.
        let interrupted = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(timeout_after),
        )
        .await;
        assert!(
            interrupted.is_err(),
            "provide_credentials completed before timeout future"
        );

        let actual = chain.fallback_on_interrupt().unwrap_or_else(|| {
            panic!(
                "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
            )
        });
        assert_eq!(actual, expected);
    }

    // A 300ms timeout fires after provider1's 200ms have elapsed, i.e. while provider2 is
    // still inside its own 200ms load.
    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        assert_fallback_matches_after_timeout(Duration::from_millis(300)).await;
    }

    // A 100ms timeout fires while provider1 is still inside its 200ms delay, before
    // provider2 has been consulted at all.
    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        assert_fallback_matches_after_timeout(Duration::from_millis(100)).await;
    }
}