aws_config/meta/credentials/
chain.rs1use aws_credential_types::{
7 provider::{self, error::CredentialsError, future, ProvideCredentials},
8 Credentials,
9};
10use aws_smithy_types::error::display::DisplayErrorContext;
11use std::borrow::Cow;
12use std::fmt::Debug;
13use tracing::Instrument;
14
15pub struct CredentialsProviderChain {
36 providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
37}
38
39impl Debug for CredentialsProviderChain {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("CredentialsProviderChain")
42 .field(
43 "providers",
44 &self
45 .providers
46 .iter()
47 .map(|provider| &provider.0)
48 .collect::<Vec<&Cow<'static, str>>>(),
49 )
50 .finish()
51 }
52}
53
54impl CredentialsProviderChain {
55 pub fn first_try(
57 name: impl Into<Cow<'static, str>>,
58 provider: impl ProvideCredentials + 'static,
59 ) -> Self {
60 CredentialsProviderChain {
61 providers: vec![(name.into(), Box::new(provider))],
62 }
63 }
64
65 pub fn or_else(
67 mut self,
68 name: impl Into<Cow<'static, str>>,
69 provider: impl ProvideCredentials + 'static,
70 ) -> Self {
71 self.providers.push((name.into(), Box::new(provider)));
72 self
73 }
74
75 #[cfg(feature = "rustls")]
77 pub async fn or_default_provider(self) -> Self {
78 self.or_else(
79 "DefaultProviderChain",
80 crate::default_provider::credentials::default_provider().await,
81 )
82 }
83
84 #[cfg(feature = "rustls")]
86 pub async fn default_provider() -> Self {
87 Self::first_try(
88 "DefaultProviderChain",
89 crate::default_provider::credentials::default_provider().await,
90 )
91 }
92
93 async fn credentials(&self) -> provider::Result {
94 for (name, provider) in &self.providers {
95 let span = tracing::debug_span!("load_credentials", provider = %name);
96 match provider.provide_credentials().instrument(span).await {
97 Ok(credentials) => {
98 tracing::debug!(provider = %name, "loaded credentials");
99 return Ok(credentials);
100 }
101 Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
102 tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
103 }
104 Err(err) => {
105 tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
106 return Err(err);
107 }
108 }
109 }
110 Err(CredentialsError::not_loaded(
111 "no providers in chain provided credentials",
112 ))
113 }
114}
115
116impl ProvideCredentials for CredentialsProviderChain {
117 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
118 where
119 Self: 'a,
120 {
121 future::ProvideCredentials::new(self.credentials())
122 }
123
124 fn fallback_on_interrupt(&self) -> Option<Credentials> {
125 for (_, provider) in &self.providers {
126 match provider.fallback_on_interrupt() {
127 creds @ Some(_) => return creds,
128 None => {}
129 }
130 }
131 None
132 }
133}
134
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// Test provider that needs 200ms to return its credentials but hands the same
    /// credentials back immediately from `fallback_on_interrupt`.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Builds a two-provider chain:
    /// - `provider1` sleeps 200ms and then reports that it could not load credentials;
    /// - `provider2` is a `FallbackCredentials` wrapping `Credentials::for_tests()`.
    ///
    /// An uninterrupted load therefore takes roughly 400ms in total.
    fn slow_chain() -> CredentialsProviderChain {
        CredentialsProviderChain::first_try(
            "provider1",
            provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Err(CredentialsError::not_loaded(
                    "no providers in chain provided credentials",
                ))
            }),
        )
        .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Loads once to completion to learn the expected credentials, then loads again
    /// under a `deadline` and asserts that the load was interrupted and that
    /// `fallback_on_interrupt` yields the expected credentials.
    async fn assert_fallback_after_timeout(deadline: Duration) {
        let chain = slow_chain();
        let expected = chain.provide_credentials().await.unwrap();

        let interrupted = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(deadline),
        );
        let outcome = interrupted.await;
        if outcome.is_ok() {
            panic!("provide_credentials completed before timeout future");
        }
        match chain.fallback_on_interrupt() {
            None => panic!(
                "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
            ),
            Some(actual) => assert_eq!(actual, expected),
        }
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        // provider1 gives up at ~200ms and provider2 would finish at ~400ms, so a
        // 300ms deadline lands while provider2 is still loading.
        assert_fallback_after_timeout(Duration::from_millis(300)).await;
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        // A 100ms deadline lands before provider1 has finished its 200ms sleep.
        assert_fallback_after_timeout(Duration::from_millis(100)).await;
    }
}