1use aws_credential_types::provider::{
9 self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
10};
11use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
12use aws_sdk_sts::operation::assume_role::AssumeRoleError;
13use aws_sdk_sts::types::PolicyDescriptorType;
14use aws_sdk_sts::Client as StsClient;
15use aws_smithy_runtime::client::identity::IdentityCache;
16use aws_smithy_runtime_api::client::result::SdkError;
17use aws_smithy_types::error::display::DisplayErrorContext;
18use aws_types::region::Region;
19use aws_types::SdkConfig;
20use std::time::Duration;
21use tracing::Instrument;
22
23#[derive(Debug)]
70pub struct AssumeRoleProvider {
71 inner: Inner,
72}
73
74#[derive(Debug)]
75struct Inner {
76 fluent_builder: AssumeRoleFluentBuilder,
77}
78
79impl AssumeRoleProvider {
80 pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
88 AssumeRoleProviderBuilder::new(role.into())
89 }
90}
91
92#[derive(Debug)]
96pub struct AssumeRoleProviderBuilder {
97 role_arn: String,
98 external_id: Option<String>,
99 session_name: Option<String>,
100 session_length: Option<Duration>,
101 policy: Option<String>,
102 policy_arns: Option<Vec<PolicyDescriptorType>>,
103 region_override: Option<Region>,
104 sdk_config: Option<SdkConfig>,
105}
106
107impl AssumeRoleProviderBuilder {
108 pub fn new(role: impl Into<String>) -> Self {
116 Self {
117 role_arn: role.into(),
118 external_id: None,
119 session_name: None,
120 session_length: None,
121 policy: None,
122 policy_arns: None,
123 sdk_config: None,
124 region_override: None,
125 }
126 }
127
128 pub fn external_id(mut self, id: impl Into<String>) -> Self {
134 self.external_id = Some(id.into());
135 self
136 }
137
138 pub fn session_name(mut self, name: impl Into<String>) -> Self {
145 self.session_name = Some(name.into());
146 self
147 }
148
149 pub fn policy(mut self, policy: impl Into<String>) -> Self {
155 self.policy = Some(policy.into());
156 self
157 }
158
159 pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
165 self.policy_arns = Some(
166 policy_arns
167 .into_iter()
168 .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
169 .collect::<Vec<_>>(),
170 );
171 self
172 }
173
174 pub fn session_length(mut self, length: Duration) -> Self {
187 self.session_length = Some(length);
188 self
189 }
190
191 pub fn region(mut self, region: Region) -> Self {
196 self.region_override = Some(region);
197 self
198 }
199
200 pub fn configure(mut self, conf: &SdkConfig) -> Self {
218 self.sdk_config = Some(conf.clone());
219 self
220 }
221
222 pub async fn build(self) -> AssumeRoleProvider {
227 let mut conf = match self.sdk_config {
228 Some(conf) => conf,
229 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
230 };
231 conf = conf
233 .into_builder()
234 .identity_cache(IdentityCache::no_cache())
235 .build();
236
237 if let Some(region) = self.region_override {
239 conf = conf.into_builder().region(region).build()
240 }
241
242 let config = aws_sdk_sts::config::Builder::from(&conf);
243
244 let time_source = conf.time_source().expect("A time source must be provided.");
245
246 let session_name = self.session_name.unwrap_or_else(|| {
247 super::util::default_session_name("assume-role-provider", time_source.now())
248 });
249
250 let sts_client = StsClient::from_conf(config.build());
251 let fluent_builder = sts_client
252 .assume_role()
253 .set_role_arn(Some(self.role_arn))
254 .set_external_id(self.external_id)
255 .set_role_session_name(Some(session_name))
256 .set_policy(self.policy)
257 .set_policy_arns(self.policy_arns)
258 .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
259
260 AssumeRoleProvider {
261 inner: Inner { fluent_builder },
262 }
263 }
264
265 pub async fn build_from_provider(
267 mut self,
268 provider: impl ProvideCredentials + 'static,
269 ) -> AssumeRoleProvider {
270 let conf = match self.sdk_config {
271 Some(conf) => conf,
272 None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
273 };
274 let conf = conf
275 .into_builder()
276 .credentials_provider(SharedCredentialsProvider::new(provider))
277 .build();
278 self.sdk_config = Some(conf);
279 self.build().await
280 }
281}
282
283impl Inner {
284 async fn credentials(&self) -> provider::Result {
285 tracing::debug!("retrieving assumed credentials");
286
287 let assumed = self.fluent_builder.clone().send().in_current_span().await;
288 match assumed {
289 Ok(assumed) => {
290 tracing::debug!(
291 access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
292 "obtained assumed credentials"
293 );
294 super::util::into_credentials(assumed.credentials, "AssumeRoleProvider")
295 }
296 Err(SdkError::ServiceError(ref context))
297 if matches!(
298 context.err(),
299 AssumeRoleError::RegionDisabledException(_)
300 | AssumeRoleError::MalformedPolicyDocumentException(_)
301 ) =>
302 {
303 Err(CredentialsError::invalid_configuration(
304 assumed.err().unwrap(),
305 ))
306 }
307 Err(SdkError::ServiceError(ref context)) => {
308 tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
309 Err(CredentialsError::provider_error(assumed.err().unwrap()))
310 }
311 Err(err) => Err(CredentialsError::provider_error(err)),
312 }
313 }
314}
315
316impl ProvideCredentials for AssumeRoleProvider {
317 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
318 where
319 Self: 'a,
320 {
321 future::ProvideCredentials::new(
322 self.inner
323 .credentials()
324 .instrument(tracing::debug_span!("assume_role")),
325 )
326 }
327}
328
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    /// Fixed clock shared by the request-capturing tests.
    fn fixed_clock() -> StaticTimeSource {
        StaticTimeSource::new(UNIX_EPOCH + Duration::from_secs(1234567890 - 120))
    }

    /// The configured session length must appear in the request body, and the region
    /// set on the builder must win over the region of the `SdkConfig`.
    #[tokio::test]
    async fn configures_session_length() {
        let (client, captured) = capture_request(None);
        let base_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(fixed_clock())
            .http_client(client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;

        // Only the outgoing request matters here; the provider result is ignored.
        let _ = dbg!(provider.provide_credentials().await);

        let sent = captured.expect_request();
        let body_text = std::str::from_utf8(sent.body().bytes().unwrap()).unwrap();
        assert!(body_text.contains("1234567"), "{}", body_text);
        assert_eq!(sent.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    /// Without a region override, the region of the `SdkConfig` selects the endpoint,
    /// and the credentials provider of the `SdkConfig` is never invoked.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (client, captured) = capture_request(None);
        let unused_provider = SharedCredentialsProvider::new(provide_credentials_fn(
            || async {
                panic!("don't call me — will be overridden");
            },
        ));
        let base_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(fixed_clock())
            .http_client(client)
            .credentials_provider(unused_provider)
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;

        let _ = dbg!(provider.provide_credentials().await);

        let sent = captured.expect_request();
        assert_eq!(sent.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    /// `build()` must honor the credentials, region, FIPS and dual-stack settings of
    /// the supplied `SdkConfig`.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let not_found = http::Response::builder()
            .status(404)
            .body(SdkBody::from(""))
            .unwrap();
        let (client, captured) = capture_request(Some(not_found));
        let environment = Env::from_slice(&[
            ("AWS_ACCESS_KEY_ID", "123-key"),
            ("AWS_SECRET_ACCESS_KEY", "456"),
            ("AWS_REGION", "us-west-17"),
        ]);
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(environment)
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;

        let _ = dbg!(provider.provide_credentials().await);

        let req = captured.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri());
    }

    /// Two consecutive `provide_credentials` calls must each hit STS and each consume
    /// one set of base credentials, i.e. nothing is cached between the calls.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let first_response = "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:31:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n";
        let second_response = "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n <AssumeRoleResult>\n <AssumedRoleUser>\n <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n </AssumedRoleUser>\n <Credentials>\n <AccessKeyId>ASIARCORRECT</AccessKeyId>\n <SecretAccessKey>TESTSECRET</SecretAccessKey>\n <SessionToken>tokencorrect</SessionToken>\n <Expiration>2009-02-13T23:33:30Z</Expiration>\n </Credentials>\n </AssumeRoleResult>\n <ResponseMetadata>\n <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n </ResponseMetadata>\n</AssumeRoleResponse>\n";

        // One canned STS response per expected `provide_credentials` call.
        let client = StaticReplayClient::new(vec![
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(first_response))
                    .unwrap(),
            ),
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(second_response))
                    .unwrap(),
            ),
        ]);

        let (testing_time_source, sleep) =
            instant_time_and_sleep(UNIX_EPOCH + Duration::from_secs(1234567890 - 120));

        let base_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();

        // Base credentials handed out one per call, in order.
        let remaining_base_credentials = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let remaining_for_assert = remaining_base_credentials.clone();

        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&base_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let queue = remaining_base_credentials.clone();
                async move {
                    let next = queue.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // Move the clock forward before asking again; the second call must still
        // produce a fresh set of credentials.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");

        assert_ne!(creds_first, creds_second);
        // Both base credentials were consumed, so the base provider ran twice.
        assert!(remaining_for_assert.lock().unwrap().is_empty());
    }
}