aws_config/sts/assume_role.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Assume credentials for a role through the AWS Security Token Service (STS).
7
8use aws_credential_types::provider::{
9    self, error::CredentialsError, future, ProvideCredentials, SharedCredentialsProvider,
10};
11use aws_sdk_sts::operation::assume_role::builders::AssumeRoleFluentBuilder;
12use aws_sdk_sts::operation::assume_role::AssumeRoleError;
13use aws_sdk_sts::types::PolicyDescriptorType;
14use aws_sdk_sts::Client as StsClient;
15use aws_smithy_runtime::client::identity::IdentityCache;
16use aws_smithy_runtime_api::client::result::SdkError;
17use aws_smithy_types::error::display::DisplayErrorContext;
18use aws_types::region::Region;
19use aws_types::SdkConfig;
20use std::time::Duration;
21use tracing::Instrument;
22
23/// Credentials provider that uses credentials provided by another provider to assume a role
24/// through the AWS Security Token Service (STS).
25///
26/// When asked to provide credentials, this provider will first invoke the inner credentials
27/// provider to get AWS credentials for STS. Then, it will call STS to get assumed credentials for
28/// the desired role.
29///
30/// # Examples
31/// Create an AssumeRoleProvider explicitly set to us-east-2 that utilizes the default credentials chain.
32/// ```no_run
33/// use aws_config::sts::AssumeRoleProvider;
34/// use aws_types::region::Region;
35/// # async fn docs() {
36/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
37///   .region(Region::from_static("us-east-2"))
38///   .session_name("testAR")
39///   .build().await;
40/// }
41/// ```
42///
43/// Create an AssumeRoleProvider from an explicitly configured base configuration.
44/// ```no_run
45/// use aws_config::sts::AssumeRoleProvider;
46/// use aws_types::region::Region;
47/// # async fn docs() {
48/// let conf = aws_config::from_env().use_fips(true).load().await;
49/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
50///   .configure(&conf)
51///   .session_name("testAR")
52///   .build().await;
53/// }
54/// ```
55///
56/// Create an AssumeroleProvider that sources credentials from a provider credential provider:
57/// ```no_run
58/// use aws_config::sts::AssumeRoleProvider;
59/// use aws_types::region::Region;
60/// use aws_config::environment::EnvironmentVariableCredentialsProvider;
61/// # async fn docs() {
62/// let provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/demo")
63///   .session_name("test-assume-role-session")
64///   // only consider environment variables, explicitly.
65///   .build_from_provider(EnvironmentVariableCredentialsProvider::new()).await;
66/// }
67/// ```
68///
69#[derive(Debug)]
70pub struct AssumeRoleProvider {
71    inner: Inner,
72}
73
74#[derive(Debug)]
75struct Inner {
76    fluent_builder: AssumeRoleFluentBuilder,
77}
78
79impl AssumeRoleProvider {
80    /// Build a new role-assuming provider for the given role.
81    ///
82    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
83    ///
84    /// ```text
85    /// arn:aws:iam::123456789012:role/example
86    /// ```
87    pub fn builder(role: impl Into<String>) -> AssumeRoleProviderBuilder {
88        AssumeRoleProviderBuilder::new(role.into())
89    }
90}
91
92/// A builder for [`AssumeRoleProvider`].
93///
94/// Construct one through [`AssumeRoleProvider::builder`].
95#[derive(Debug)]
96pub struct AssumeRoleProviderBuilder {
97    role_arn: String,
98    external_id: Option<String>,
99    session_name: Option<String>,
100    session_length: Option<Duration>,
101    policy: Option<String>,
102    policy_arns: Option<Vec<PolicyDescriptorType>>,
103    region_override: Option<Region>,
104    sdk_config: Option<SdkConfig>,
105}
106
107impl AssumeRoleProviderBuilder {
108    /// Start a new assume role builder for the given role.
109    ///
110    /// The `role` argument should take the form an Amazon Resource Name (ARN) like
111    ///
112    /// ```text
113    /// arn:aws:iam::123456789012:role/example
114    /// ```
115    pub fn new(role: impl Into<String>) -> Self {
116        Self {
117            role_arn: role.into(),
118            external_id: None,
119            session_name: None,
120            session_length: None,
121            policy: None,
122            policy_arns: None,
123            sdk_config: None,
124            region_override: None,
125        }
126    }
127
128    /// Set a unique identifier that might be required when you assume a role in another account.
129    ///
130    /// If the administrator of the account to which the role belongs provided you with an external
131    /// ID, then provide that value in this parameter. The value can be any string, such as a
132    /// passphrase or account number.
133    pub fn external_id(mut self, id: impl Into<String>) -> Self {
134        self.external_id = Some(id.into());
135        self
136    }
137
138    /// Set an identifier for the assumed role session.
139    ///
140    /// Use the role session name to uniquely identify a session when the same role is assumed by
141    /// different principals or for different reasons. In cross-account scenarios, the role session
142    /// name is visible to, and can be logged by the account that owns the role. The role session
143    /// name is also used in the ARN of the assumed role principal.
144    pub fn session_name(mut self, name: impl Into<String>) -> Self {
145        self.session_name = Some(name.into());
146        self
147    }
148
149    /// Set an IAM policy in JSON format that you want to use as an inline session policy.
150    ///
151    /// This parameter is optional
152    /// For more information, see
153    /// [policy](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
154    pub fn policy(mut self, policy: impl Into<String>) -> Self {
155        self.policy = Some(policy.into());
156        self
157    }
158
159    /// Set the Amazon Resource Names (ARNs) of the IAM managed policies that you want to use as managed session policies.
160    ///
161    /// This parameter is optional.
162    /// For more information, see
163    /// [policy_arns](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::policy_arns)
164    pub fn policy_arns(mut self, policy_arns: Vec<String>) -> Self {
165        self.policy_arns = Some(
166            policy_arns
167                .into_iter()
168                .map(|arn| PolicyDescriptorType::builder().arn(arn).build())
169                .collect::<Vec<_>>(),
170        );
171        self
172    }
173
174    /// Set the expiration time of the role session.
175    ///
176    /// When unset, this value defaults to 1 hour.
177    ///
178    /// The value specified can range from 900 seconds (15 minutes) up to the maximum session duration
179    /// set for the role. The maximum session duration setting can have a value from 1 hour to 12 hours.
180    /// If you specify a value higher than this setting or the administrator setting (whichever is lower),
181    /// **you will be unable to assume the role**. For example, if you specify a session duration of 12 hours,
182    /// but your administrator set the maximum session duration to 6 hours, you cannot assume the role.
183    ///
184    /// For more information, see
185    /// [duration_seconds](aws_sdk_sts::operation::assume_role::builders::AssumeRoleInputBuilder::duration_seconds)
186    pub fn session_length(mut self, length: Duration) -> Self {
187        self.session_length = Some(length);
188        self
189    }
190
191    /// Set the region to assume the role in.
192    ///
193    /// This dictates which STS endpoint the AssumeRole action is invoked on. This will override
194    /// a region set from `.configure(...)`
195    pub fn region(mut self, region: Region) -> Self {
196        self.region_override = Some(region);
197        self
198    }
199
200    /// Sets the configuration used for this provider
201    ///
202    /// This enables overriding the connection used to communicate with STS in addition to other internal
203    /// fields like the time source and sleep implementation used for caching.
204    ///
205    /// If this field is not provided, configuration from [`aws_config::load_from_env().await`] is used.
206    ///
207    /// # Examples
208    /// ```rust
209    /// # async fn docs() {
210    /// use aws_types::region::Region;
211    /// use aws_config::sts::AssumeRoleProvider;
212    /// let config = aws_config::from_env().region(Region::from_static("us-west-2")).load().await;
213    /// let assume_role_provider = AssumeRoleProvider::builder("arn:aws:iam::123456789012:role/example")
214    ///   .configure(&config)
215    ///   .build();
216    /// }
217    pub fn configure(mut self, conf: &SdkConfig) -> Self {
218        self.sdk_config = Some(conf.clone());
219        self
220    }
221
222    /// Build a credentials provider for this role.
223    ///
224    /// Base credentials will be used from the [`SdkConfig`] set via [`Self::configure`] or loaded
225    /// from [`aws_config::from_env`](crate::from_env) if `configure` was never called.
226    pub async fn build(self) -> AssumeRoleProvider {
227        let mut conf = match self.sdk_config {
228            Some(conf) => conf,
229            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
230        };
231        // ignore a identity cache set from SdkConfig
232        conf = conf
233            .into_builder()
234            .identity_cache(IdentityCache::no_cache())
235            .build();
236
237        // set a region override if one exists
238        if let Some(region) = self.region_override {
239            conf = conf.into_builder().region(region).build()
240        }
241
242        let config = aws_sdk_sts::config::Builder::from(&conf);
243
244        let time_source = conf.time_source().expect("A time source must be provided.");
245
246        let session_name = self.session_name.unwrap_or_else(|| {
247            super::util::default_session_name("assume-role-provider", time_source.now())
248        });
249
250        let sts_client = StsClient::from_conf(config.build());
251        let fluent_builder = sts_client
252            .assume_role()
253            .set_role_arn(Some(self.role_arn))
254            .set_external_id(self.external_id)
255            .set_role_session_name(Some(session_name))
256            .set_policy(self.policy)
257            .set_policy_arns(self.policy_arns)
258            .set_duration_seconds(self.session_length.map(|dur| dur.as_secs() as i32));
259
260        AssumeRoleProvider {
261            inner: Inner { fluent_builder },
262        }
263    }
264
265    /// Build a credentials provider for this role authorized by the given `provider`.
266    pub async fn build_from_provider(
267        mut self,
268        provider: impl ProvideCredentials + 'static,
269    ) -> AssumeRoleProvider {
270        let conf = match self.sdk_config {
271            Some(conf) => conf,
272            None => crate::load_defaults(crate::BehaviorVersion::latest()).await,
273        };
274        let conf = conf
275            .into_builder()
276            .credentials_provider(SharedCredentialsProvider::new(provider))
277            .build();
278        self.sdk_config = Some(conf);
279        self.build().await
280    }
281}
282
283impl Inner {
284    async fn credentials(&self) -> provider::Result {
285        tracing::debug!("retrieving assumed credentials");
286
287        let assumed = self.fluent_builder.clone().send().in_current_span().await;
288        match assumed {
289            Ok(assumed) => {
290                tracing::debug!(
291                    access_key_id = ?assumed.credentials.as_ref().map(|c| &c.access_key_id),
292                    "obtained assumed credentials"
293                );
294                super::util::into_credentials(assumed.credentials, "AssumeRoleProvider")
295            }
296            Err(SdkError::ServiceError(ref context))
297                if matches!(
298                    context.err(),
299                    AssumeRoleError::RegionDisabledException(_)
300                        | AssumeRoleError::MalformedPolicyDocumentException(_)
301                ) =>
302            {
303                Err(CredentialsError::invalid_configuration(
304                    assumed.err().unwrap(),
305                ))
306            }
307            Err(SdkError::ServiceError(ref context)) => {
308                tracing::warn!(error = %DisplayErrorContext(context.err()), "STS refused to grant assume role");
309                Err(CredentialsError::provider_error(assumed.err().unwrap()))
310            }
311            Err(err) => Err(CredentialsError::provider_error(err)),
312        }
313    }
314}
315
316impl ProvideCredentials for AssumeRoleProvider {
317    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
318    where
319        Self: 'a,
320    {
321        future::ProvideCredentials::new(
322            self.inner
323                .credentials()
324                .instrument(tracing::debug_span!("assume_role")),
325        )
326    }
327}
328
#[cfg(test)]
mod test {
    use crate::sts::AssumeRoleProvider;
    use aws_credential_types::credential_fn::provide_credentials_fn;
    use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
    use aws_credential_types::Credentials;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_async::test_util::instant_time_and_sleep;
    use aws_smithy_async::time::StaticTimeSource;
    use aws_smithy_runtime::client::http::test_util::{
        capture_request, ReplayEvent, StaticReplayClient,
    };
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
    use aws_smithy_types::body::SdkBody;
    use aws_types::os_shim_internal::Env;
    use aws_types::region::Region;
    use aws_types::SdkConfig;
    use http::header::AUTHORIZATION;
    use std::time::{Duration, UNIX_EPOCH};

    /// An explicit `session_length` is sent in the request body, and `.region(...)` overrides
    /// the region carried by the `SdkConfig`.
    #[tokio::test]
    async fn configures_session_length() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .region(Region::from_static("this-will-be-overridden"))
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let str_body = std::str::from_utf8(req.body().bytes().unwrap()).unwrap();
        assert!(str_body.contains("1234567"), "{}", str_body);
        assert_eq!(req.uri(), "https://sts.us-east-1.amazonaws.com/");
    }

    /// Without a region override, the region comes from the `SdkConfig`, and the provider passed
    /// to `build_from_provider` replaces the config's credentials provider.
    #[tokio::test]
    async fn loads_region_from_sdk_config() {
        let (http_client, request) = capture_request(None);
        let sdk_config = SdkConfig::builder()
            .behavior_version(crate::BehaviorVersion::latest())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .time_source(StaticTimeSource::new(
                UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
            ))
            .http_client(http_client)
            .credentials_provider(SharedCredentialsProvider::new(provide_credentials_fn(
                || async {
                    panic!("don't call me — will be overridden");
                },
            )))
            .region(Region::from_static("us-west-2"))
            .build();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .session_length(Duration::from_secs(1234567))
            .build_from_provider(provide_credentials_fn(|| async {
                Ok(Credentials::for_tests())
            }))
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        assert_eq!(req.uri(), "https://sts.us-west-2.amazonaws.com/");
    }

    /// Test that `build()` where no provider is passed still works
    #[tokio::test]
    async fn build_method_from_sdk_config() {
        let _guard = capture_test_logs();
        let (http_client, request) = capture_request(Some(
            http::Response::builder()
                .status(404)
                .body(SdkBody::from(""))
                .unwrap(),
        ));
        let conf = crate::defaults(BehaviorVersion::latest())
            .env(Env::from_slice(&[
                ("AWS_ACCESS_KEY_ID", "123-key"),
                ("AWS_SECRET_ACCESS_KEY", "456"),
                ("AWS_REGION", "us-west-17"),
            ]))
            .use_dual_stack(true)
            .use_fips(true)
            .time_source(StaticTimeSource::from_secs(1234567890))
            .http_client(http_client)
            .load()
            .await;
        let provider = AssumeRoleProvider::builder("role")
            .configure(&conf)
            .build()
            .await;
        let _ = dbg!(provider.provide_credentials().await);
        let req = request.expect_request();
        let auth_header = req.headers().get(AUTHORIZATION).unwrap().to_string();
        let expect = "Credential=123-key/20090213/us-west-17/sts/aws4_request";
        assert!(
            auth_header.contains(expect),
            "Expected header to contain {expect} but it was {auth_header}"
        );
        // ensure that FIPS & DualStack are also respected
        assert_eq!("https://sts-fips.us-west-17.api.aws/", req.uri())
    }

    /// The base credentials provider is consulted on every call, even while its previous
    /// credentials are still unexpired.
    #[tokio::test]
    async fn provider_does_not_cache_credentials_by_default() {
        let http_client = StaticReplayClient::new(vec![
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(
                        "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>secretkeycorrect</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:31:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>d9d47248-fd55-4686-ad7c-0fb7cd1cddd7</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n",
                    ))
                    .unwrap(),
            ),
            ReplayEvent::new(
                http::Request::new(SdkBody::from("request body")),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(
                        "<AssumeRoleResponse xmlns=\"https://sts.amazonaws.com/doc/2011-06-15/\">\n  <AssumeRoleResult>\n    <AssumedRoleUser>\n      <AssumedRoleId>AROAR42TAWARILN3MNKUT:assume-role-from-profile-1632246085998</AssumedRoleId>\n      <Arn>arn:aws:sts::130633740322:assumed-role/assume-provider-test/assume-role-from-profile-1632246085998</Arn>\n    </AssumedRoleUser>\n    <Credentials>\n      <AccessKeyId>ASIARCORRECT</AccessKeyId>\n      <SecretAccessKey>TESTSECRET</SecretAccessKey>\n      <SessionToken>tokencorrect</SessionToken>\n      <Expiration>2009-02-13T23:33:30Z</Expiration>\n    </Credentials>\n  </AssumeRoleResult>\n  <ResponseMetadata>\n    <RequestId>c2e971c2-702d-4124-9b1f-1670febbea18</RequestId>\n  </ResponseMetadata>\n</AssumeRoleResponse>\n",
                    ))
                    .unwrap(),
            ),
        ]);

        // Start 120 seconds before 1234567890 seconds since UNIX_EPOCH (2009-02-13T23:31:30Z).
        let (testing_time_source, sleep) = instant_time_and_sleep(
            UNIX_EPOCH + Duration::from_secs(1234567890 - 120),
        );

        let sdk_config = SdkConfig::builder()
            .sleep_impl(SharedAsyncSleep::new(sleep))
            .time_source(testing_time_source.clone())
            .http_client(http_client)
            .behavior_version(crate::BehaviorVersion::latest())
            .build();
        let credentials_list = std::sync::Arc::new(std::sync::Mutex::new(vec![
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 1)),
                "test",
            ),
            Credentials::new(
                "test",
                "test",
                None,
                Some(UNIX_EPOCH + Duration::from_secs(1234567890 + 120)),
                "test",
            ),
        ]));
        let credentials_list_cloned = credentials_list.clone();
        let provider = AssumeRoleProvider::builder("myrole")
            .configure(&sdk_config)
            .region(Region::new("us-east-1"))
            .build_from_provider(provide_credentials_fn(move || {
                let list = credentials_list.clone();
                async move {
                    let next = list.lock().unwrap().remove(0);
                    Ok(next)
                }
            }))
            .await;

        let creds_first = provider
            .provide_credentials()
            .await
            .expect("should return valid credentials");

        // After advancing the clock by 120 seconds, the first base credentials (expiring at
        // 1234567890 + 1) are still unexpired, so a caching layer could have reused them. The
        // STS client is built with `IdentityCache::no_cache()`, so the base provider must be
        // called again and hand out the second set.
        testing_time_source.advance(Duration::from_secs(120));

        let creds_second = provider
            .provide_credentials()
            .await
            .expect("should return the second credentials");
        assert_ne!(creds_first, creds_second);
        assert!(credentials_list_cloned.lock().unwrap().is_empty());
    }
}