1use aws_credential_types::provider::{
9 self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
10};
11use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
12use aws_sdk_sts::operation::assume_role::AssumeRoleError;
13use aws_sdk_sts::types::PolicyDescriptorType;
14use aws_sdk_sts::Client as StsClient;
15use aws_smithy_runtime::client::identity::IdentityCache;
16use aws_smithy_runtime_api::client::result::SdkError;
17use aws_smithy_types::error::display::DisplayErrorContext;
18use aws_types::region::Region;
19use aws_types::SdkConfig;
20use std::time::Duration;
21use tracing::Instrument;
22
/// Credentials provider that obtains credentials by sending an STS `AssumeRole` request.
///
/// Instances are created through [`AssumeRoleProvider::builder`], which returns an
/// [`AssumeRoleProviderBuilder`].
#[derive(Debug)]
pub struct AssumeRoleProvider {
    // Holds the prepared `AssumeRole` request that is re-sent on every credentials lookup.
    inner: Inner,
}
73
/// Internal state of [`AssumeRoleProvider`].
#[derive(Debug)]
struct Inner {
    // Fully configured `AssumeRole` request; it is cloned and sent each time
    // credentials are requested.
    fluent_builder: AssumeRoleFluentBuilder,
}
78
79impl AssumeRoleProvider {
80 pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
88 AssumeRoleProviderBuilder::new(role.into())
89 }
90}
91
/// Builder for [`AssumeRoleProvider`].
///
/// Only the role is mandatory; every other field is optional and is forwarded to
/// the STS `AssumeRole` request (or to the STS client configuration) when the
/// provider is built.
#[derive(Debug)]
pub struct AssumeRoleProviderBuilder {
    // Sent as the request's role ARN.
    role_arn: String,
    // Sent as the request's external ID when set.
    external_id: Option<String>,
    // Sent as the role session name; a default is generated at build time when unset.
    session_name: Option<String>,
    // Requested session length; sent as whole seconds when set.
    session_length: Option<Duration>,
    // Inline session policy, forwarded as-is when set.
    policy: Option<String>,
    // Policy ARNs, already wrapped in the STS descriptor type.
    policy_arns: Option<Vec<PolicyDescriptorType>>,
    // Region applied on top of whatever region `sdk_config` carries.
    region_override: Option<Region>,
    // Base configuration for the STS client; defaults are loaded at build time when unset.
    sdk_config: Option<SdkConfig>,
}
106
107impl AssumeRoleProviderBuilder {
108 pub fn new(role: impl Into<String>) -> Self {
116 Self {
117 role_arn: role.into(),
118 external_id: None,
119 session_name: None,
120 session_length: None,
121 policy: None,
122 policy_arns: None,
123 sdk_config: None,
124 region_override: None,
125 }
126 }
127
128 pub fn external_id(mut self, id: impl Into<String>) -> Self {
134 self.external_id = Some(id.into());
135 self
136 }
137
138 pub fn session_name(mut self, name: impl Into<String>) -> Self {
145 self.session_name = Some(name.into());
146 self
147 }
148
149 pub fn policy(mut self, policy: impl Into<String>) -> Self {
155 self.policy = Some(policy.into());
156 self
157 }
158
159 pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
165 self.policy_arns = Some(
166 policy_arns
167 .into_iter()
168 .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
169 .collect::<Vec<_>>(),
170 );
171 self
172 }
173
174 pub fn session_length(mut self, length: Duration) -> Self {
187 self.session_length = Some(length);
188 self
189 }
190
191 pub fn region(mut self, region: Region) -> Self {
196 self.region_override = Some(region);
197 self
198 }
199
200 pub fn configure(mut self, conf: &SdkConfig) -> Self {
218 self.sdk_config = Some(conf.clone());
219 self
220 }
221
222 pub async fn build(self) -> AssumeRoleProvider {
227 let mut conf = match self.sdk_config {
228 Some(conf) => conf,
229 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
230 };
231 conf = conf
233 .into_builder()
234 .identity_cache(IdentityCache::no_cache())
235 .build();
236
237 if let Some(region) = self.region_override {
239 conf = conf.into_builder().region(region).build()
240 }
241
242 let config = aws_sdk_sts::config::Builder::from(&conf);
243
244 let time_source = conf.time_source().expect("A time source must be provided.");
245
246 let session_name = self.session_name.unwrap_or_else(|| {
247 super::util::default_session_name("assume-role-provider", time_source.now())
248 });
249
250 let sts_client = StsClient::from_conf(config.build());
251 let fluent_builder = sts_client
252 .assume_role()
253 .set_role_arn(Some(self.role_arn))
254 .set_external_id(self.external_id)
255 .set_role_session_name(Some(session_name))
256 .set_policy(self.policy)
257 .set_policy_arns(self.policy_arns)
258 .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
259
260 AssumeRoleProvider {
261 inner: Inner { fluent_builder },
262 }
263 }
264
265 pub async fn build_from_provider(
267 mut self,
268 provider: impl ProvideCredentials + 'static,
269 ) -> AssumeRoleProvider {
270 let conf = match self.sdk_config {
271 Some(conf) => conf,
272 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
273 };
274 let conf = conf
275 .into_builder()
276 .credentials_provider(SharedCredentialsProvider::new(provider))
277 .build();
278 self.sdk_config = Some(conf);
279 self.build().await
280 }
281}
282
283impl Inner {
284 async fn credentials(&self) -> provider::Result {
285 tracing::debug!("retrieving assumed credentials");
286
287 let assumed = self.fluent_builder.clone().send().in_current_span().await;
288 match assumed {
289 Ok(assumed) => {
290 tracing::debug!(
291 access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
292 "obtained assumed credentials"
293 );
294 super::util::into_credentials(assumed.credentials, "AssumeRoleProvider")
295 }
296 Err(SdkError::ServiceError(ref context))
297 if matches!(
298 context.err(),
299 AssumeRoleError::RegionDisabledException(_)
300 | AssumeRoleError::MalformedPolicyDocumentException(_)
301 ) =>
302 {
303 Err(CredentialsError::invalid_configuration(
304 assumed.err().unwrap(),
305 ))
306 }
307 Err(SdkError::ServiceError(ref context)) => {
308 tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
309 Err(CredentialsError::provider_error(assumed.err().unwrap()))
310 }
311 Err(err) => Err(CredentialsError::provider_error(err)),
312 }
313 }
314}
315
316impl ProvideCredentials for AssumeRoleProvider {
317 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
318 where
319 Self: 'a,
320 {
321 future::ProvideCredentials::new(
322 self.inner
323 .credentials()
324 .instrument(tracing::debug_span!("assume_role")),
325 )
326 }
327}
328
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_runtime::client::http::test_util::{
        capture_request, ReplayEvent, StaticReplayClient,
    };
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    /// The configured session length must appear in the request, and an explicit
    /// `.region(..)` must win over the region of the supplied `SdkConfig`.
    #[tokio::test]
    async fn configures_session_length() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The result is irrelevant here (no response is canned); only the captured
        // request is inspected.
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        // Match the whole `DurationSeconds` parameter rather than the bare digits:
        // the static clock above reads 1234567770 seconds, and the default session
        // name is derived from that clock, so "1234567" alone may be matched by
        // something other than the session length.
        // NOTE(review): assumes the request body is query-encoded as `Key=value`.
        assert!(str_body.contains("DurationSeconds=1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    /// Without a region override, the region of the supplied `SdkConfig` is used,
    /// and the provider given to `build_from_provider` replaces the configured one.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            // This provider must never run: `build_from_provider` replaces it.
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    /// `build()` must take credentials, region and the FIPS/dual-stack settings
    /// from the configuration passed to `configure`.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        // Access key and region come from the environment above; the date is the
        // static clock (1234567890 seconds after the epoch).
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    /// Two consecutive lookups must each call STS and the base provider instead
    /// of reusing the first result.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        // Two canned STS responses that differ in secret key, expiration and request id.
        let http_client = StaticReplayClient::new(vec![
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder().status(200).body(SdkBody::from(
                    "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:31:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
                )).unwrap(),
            ),
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder().status(200).body(SdkBody::from(
                    "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>TESTSECRET</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:33:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n"
                )).unwrap(),
            ),
        ]);

        // The test clock starts 120 seconds before 1234567890.
        let (testing_time_source, sleep) =
            instant_time_and_sleep(UNIX_EPOCH + Duration::from_secs(1234567890 - 120));

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        // Base credentials handed out one per call, in order.
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    // Panics if called more often than there are entries.
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // Move the clock to 1234567890. The first base credentials expire one
        // second later, so they have not expired yet; the assertions below check
        // that they are not reused anyway.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        // The second lookup returned the second canned response...
        assert_ne!(creds_first, creds_second);
        // ...and both base credentials were consumed.
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }
}