aws_config/sts/
assume_role.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Assume credentials for a role through the AWS Security Token Service (STS).

8use aws_credential_types::provider::{
9    self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
10};
11use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
12use aws_sdk_sts::operation::assume_role::AssumeRoleError;
13use aws_sdk_sts::types::PolicyDescriptorType;
14use aws_sdk_sts::Client as StsClient;
15use aws_smithy_runtime::client::identity::IdentityCache;
16use aws_smithy_runtime_api::client::result::SdkError;
17use aws_smithy_types::error::display::DisplayErrorContext;
18use aws_types::region::Region;
19use aws_types::SdkConfig;
20use std::time::Duration;
21use tracing::Instrument;
22
23/// Credentials provider that uses credentials provided by another provider to assume a role
24/// through the AWS Security Token Service (STS).
25///
26/// When asked to provide credentials, this provider will first invoke the inner credentials
27/// provider to get AWS credentials for STS. Then, it will call STS to get assumed credentials for
28/// the desired role.
29///
30/// # Examples
31/// Create an AssumeRoleProvider explicitly set to us-east-2 that utilizes the default credentials chain.
32/// ```no_run
33/// use aws_config::sts::AssumeRoleProvider;
34/// use aws_types::region::Region;
35/// # async fn docs() {
36/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
37///   .region(Region::from_static("us-east-2"))
38///   .session_name("testAR")
39///   .build().await;
40/// }
41/// ```
42///
43/// Create an AssumeRoleProvider from an explicitly configured base configuration.
44/// ```no_run
45/// use aws_config::sts::AssumeRoleProvider;
46/// use aws_types::region::Region;
47/// # async fn docs() {
48/// let conf = aws_config::from_env().use_fips(true).load().await;
49/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
50///   .configure(&conf)
51///   .session_name("testAR")
52///   .build().await;
53/// }
54/// ```
55///
56/// Create an AssumeroleProvider that sources credentials from a provider credential provider:
57/// ```no_run
58/// use aws_config::sts::AssumeRoleProvider;
59/// use aws_types::region::Region;
60/// use aws_config::environment::EnvironmentVariableCredentialsProvider;
61/// # async fn docs() {
62/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
63///   .session_name("test-assume-role-session")
64///   // only consider environment variables, explicitly.
65///   .build_from_provider(EnvironmentVariableCredentialsProvider::new()).await;
66/// }
67/// ```
68///
69#[derive(Debug)]
70pub struct AssumeRoleProvider {
71    inner: Inner,
72}
73
74#[derive(Debug)]
75struct Inner {
76    fluent_builder: AssumeRoleFluentBuilder,
77}
78
79impl AssumeRoleProvider {
80    /// Build a new role-assuming provider for the given role.
81    ///
82    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
83    ///
84    /// ```text
85    /// arn:aws:iam::123456789012:role/example
86    /// ```
87    pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
88        AssumeRoleProviderBuilder::new(role.into())
89    }
90}
91
92/// A builder for [`AssumeRoleProvider`].
93///
94/// Construct one through [`AssumeRoleProvider::builder`].
95#[derive(Debug)]
96pub struct AssumeRoleProviderBuilder {
97    role_arn: String,
98    external_id: Option<String>,
99    session_name: Option<String>,
100    session_length: Option<Duration>,
101    policy: Option<String>,
102    policy_arns: Option<Vec<PolicyDescriptorType>>,
103    region_override: Option<Region>,
104    sdk_config: Option<SdkConfig>,
105}
106
107impl AssumeRoleProviderBuilder {
108    /// Start a new assume role builder for the given role.
109    ///
110    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
111    ///
112    /// ```text
113    /// arn:aws:iam::123456789012:role/example
114    /// ```
115    pub fn new(role: impl Into<String>) -> Self {
116        Self {
117            role_arn: role.into(),
118            external_id: None,
119            session_name: None,
120            session_length: None,
121            policy: None,
122            policy_arns: None,
123            sdk_config: None,
124            region_override: None,
125        }
126    }
127
128    /// Set a unique identifier that might be required when you assume a role in another account.
129    ///
130    /// If the administrator of the account to which the role belongs provided you with an external
131    /// ID, then provide that value in this parameter. The value can be any string, such as a
132    /// passphrase or account number.
133    pub fn external_id(mut self, id: impl Into<String>) -> Self {
134        self.external_id = Some(id.into());
135        self
136    }
137
138    /// Set an identifier for the assumed role session.
139    ///
140    /// Use the role session name to uniquely identify a session when the same role is assumed by
141    /// different principals or for different reasons. In cross-account scenarios, the role session
142    /// name is visible to, and can be logged by the account that owns the role. The role session
143    /// name is also used in the ARN of the assumed role principal.
144    pub fn session_name(mut self, name: impl Into<String>) -> Self {
145        self.session_name = Some(name.into());
146        self
147    }
148
149    /// Set an IAM policy in JSON format that you want to use as an inline session policy.
150    ///
151    /// This parameter is optional
152    /// For more information, see
153    /// [policy](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
154    pub fn policy(mut self, policy: impl Into<String>) -> Self {
155        self.policy = Some(policy.into());
156        self
157    }
158
159    /// Set the Amazon Resource Names (ARNs) of the IAM managed policies that you want to use as managed session policies.
160    ///
161    /// This parameter is optional.
162    /// For more information, see
163    /// [policy_arns](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
164    pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
165        self.policy_arns = Some(
166            policy_arns
167                .into_iter()
168                .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
169                .collect::<Vec<_>>(),
170        );
171        self
172    }
173
174    /// Set the expiration time of the role session.
175    ///
176    /// When unset, this value defaults to 1 hour.
177    ///
178    /// The value specified can range from 900 seconds (15 minutes) up to the maximum session duration
179    /// set for the role. The maximum session duration setting can have a value from 1 hour to 12 hours.
180    /// If you specify a value higher than this setting or the administrator setting (whichever is lower),
181    /// **you will be unable to assume the role**. For example, if you specify a session duration of 12 hours,
182    /// but your administrator set the maximum session duration to 6 hours, you cannot assume the role.
183    ///
184    /// For more information, see
185    /// [duration_seconds](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::duration_seconds)
186    pub fn session_length(mut self, length: Duration) -> Self {
187        self.session_length = Some(length);
188        self
189    }
190
191    /// Set the region to assume the role in.
192    ///
193    /// This dictates which STS endpoint the AssumeRole action is invoked on. This will override
194    /// a region set from `.configure(...)`
195    pub fn region(mut self, region: Region) -> Self {
196        self.region_override = Some(region);
197        self
198    }
199
200    /// Sets the configuration used for this provider
201    ///
202    /// This enables overriding the connection used to communicate with STS in addition to other internal
203    /// fields like the time source and sleep implementation used for caching.
204    ///
205    /// If this field is not provided, configuration from [`aws_config::load_from_env().await`] is used.
206    ///
207    /// # Examples
208    /// ```rust
209    /// # async fn docs() {
210    /// use aws_types::region::Region;
211    /// use aws_config::sts::AssumeRoleProvider;
212    /// let config = aws_config::from_env().region(Region::from_static("us-west-2")).load().await;
213    /// let assume_role_provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/example")
214    ///   .configure(&config)
215    ///   .build();
216    /// }
217    pub fn configure(mut self, conf: &SdkConfig) -> Self {
218        self.sdk_config = Some(conf.clone());
219        self
220    }
221
222    /// Build a credentials provider for this role.
223    ///
224    /// Base credentials will be used from the [`SdkConfig`] set via [`Self::configure`] or loaded
225    /// from [`aws_config::from_env`](crate::from_env) if `configure` was never called.
226    pub async fn build(self) -> AssumeRoleProvider {
227        let mut conf = match self.sdk_config {
228            Some(conf) => conf,
229            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
230        };
231        // ignore a identity cache set from SdkConfig
232        conf = conf
233            .into_builder()
234            .identity_cache(IdentityCache::no_cache())
235            .build();
236
237        // set a region override if one exists
238        if let Some(region) = self.region_override {
239            conf = conf.into_builder().region(region).build()
240        }
241
242        let config = aws_sdk_sts::config::Builder::from(&conf);
243
244        let time_source = conf.time_source().expect("A time source must be provided.");
245
246        let session_name = self.session_name.unwrap_or_else(|| {
247            super::util::default_session_name("assume-role-provider", time_source.now())
248        });
249
250        let sts_client = StsClient::from_conf(config.build());
251        let fluent_builder = sts_client
252            .assume_role()
253            .set_role_arn(Some(self.role_arn))
254            .set_external_id(self.external_id)
255            .set_role_session_name(Some(session_name))
256            .set_policy(self.policy)
257            .set_policy_arns(self.policy_arns)
258            .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
259
260        AssumeRoleProvider {
261            inner: Inner { fluent_builder },
262        }
263    }
264
265    /// Build a credentials provider for this role authorized by the given `provider`.
266    pub async fn build_from_provider(
267        mut self,
268        provider: impl ProvideCredentials + 'static,
269    ) -> AssumeRoleProvider {
270        let conf = match self.sdk_config {
271            Some(conf) => conf,
272            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
273        };
274        let conf = conf
275            .into_builder()
276            .credentials_provider(SharedCredentialsProvider::new(provider))
277            .build();
278        self.sdk_config = Some(conf);
279        self.build().await
280    }
281}
282
283impl Inner {
284    async fn credentials(&self) -> provider::Result {
285        tracing::debug!("retrieving assumed credentials");
286
287        let assumed = self.fluent_builder.clone().send().in_current_span().await;
288        match assumed {
289            Ok(assumed) => {
290                tracing::debug!(
291                    access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
292                    "obtained assumed credentials"
293                );
294                super::util::into_credentials(assumed.credentials, "AssumeRoleProvider")
295            }
296            Err(SdkError::ServiceError(ref context))
297                if matches!(
298                    context.err(),
299                    AssumeRoleError::RegionDisabledException(_)
300                        | AssumeRoleError::MalformedPolicyDocumentException(_)
301                ) =>
302            {
303                Err(CredentialsError::invalid_configuration(
304                    assumed.err().unwrap(),
305                ))
306            }
307            Err(SdkError::ServiceError(ref context)) => {
308                tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
309                Err(CredentialsError::provider_error(assumed.err().unwrap()))
310            }
311            Err(err) => Err(CredentialsError::provider_error(err)),
312        }
313    }
314}
315
316impl ProvideCredentials for AssumeRoleProvider {
317    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
318    where
319        Self: 'a,
320    {
321        future::ProvideCredentials::new(
322            self.inner
323                .credentials()
324                .instrument(tracing::debug_span!("assume_role")),
325        )
326    }
327}
328
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    /// `session_length` must appear in the AssumeRole request body. An explicit `.region(...)`
    /// must override the region carried by the `SdkConfig`.
    #[tokio::test]
    async fn configures_session_length() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        // The outcome is not asserted; only the captured outgoing request is inspected.
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    /// Without a `.region(...)` override, the `SdkConfig` region selects the STS endpoint. The
    /// provider passed to `build_from_provider` replaces the `SdkConfig` credentials provider,
    /// which panics if it is ever called.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    /// Test that `build()`, where no explicit base provider is passed, uses the loaded
    /// `SdkConfig`. The request must be signed with the environment credentials and region, and
    /// must honor the FIPS and dual-stack settings.
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        // A 404 is enough: only the outgoing request is inspected, not the credentials result.
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        // ensure that FIPS & DualStack are also respected
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    /// Two consecutive `provide_credentials` calls must each invoke the base provider and STS.
    /// Neither the base nor the assumed credentials may be reused.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        // Two canned AssumeRole responses that differ in secret key, expiration and request id.
        let http_client = StaticReplayClient::new(vec![
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:31:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
            ReplayEvent::new(http::Request::new(SdkBody::from("request body")),
            http::Response::builder().status(200).body(SdkBody::from(
                "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>TESTSECRET</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:33:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n"
            )).unwrap()),
        ]);

        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120), // 1234567890 since UNIX_EPOCH is 2009-02-13T23:31:30Z
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        // Base credentials are handed out in order, one per call to the base provider.
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // After advancing the clock by 120 seconds, the first base credentials (expiring at
        // 1234567890 + 1) are still unexpired, so a caching layer could have reused them.
        // `build` configures the STS client with `IdentityCache::no_cache()`, and the provider
        // keeps no cache of its own. Both the base provider and STS are therefore called again.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        // Both base credentials were consumed: the base provider ran once per request.
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }
}
511}