aws_config/
ecs.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Ecs Credentials Provider
7//!
8//! This credential provider is frequently used with an AWS-provided credentials service (e.g.
9//! [IAM Roles for tasks](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html)).
10//! However, it's possible to use environment variables to configure this provider to use your own
11//! credentials sources.
12//!
13//! This provider is part of the [default credentials chain](crate::default_provider::credentials).
14//!
15//! ## Configuration
16//! **First**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`. It will use this
17//! to construct a URI rooted at `http://169.254.170.2`. For example, if the value of the environment
18//! variable was `/credentials`, the SDK would look for credentials at `http://169.254.170.2/credentials`.
19//!
//! **Next**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_FULL_URI`. This specifies the full
//! URL to load credentials. The URL MUST satisfy one of the following two properties:
22//! 1. The URL begins with `https`
23//! 2. The URL refers to an allowed IP address. If a URL contains a domain name instead of an IP address,
24//!    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP address, or
25//!    the credentials provider will return `CredentialsError::InvalidConfiguration`. Valid IP addresses are:
26//!     a) Loopback interfaces
//!     b) The [ECS Task Metadata V2](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html)
//!        address, i.e. 169.254.170.2.
//!     c) [EKS Pod Identity](https://docs.aws.amazon.com/eks/latest/userguide/pod-identities.html) addresses,
//!        i.e. 169.254.170.23 or fd00:ec2::23.
31//!
32//! **Next**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE`. If this is set,
33//! the filename specified will be read, and the value passed in the `Authorization` header. If the file
34//! cannot be read, an error is returned.
35//!
//! **Finally**: If `$AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE` is not set, it will check the value of
//! `$AWS_CONTAINER_AUTHORIZATION_TOKEN`. If this is set, the value will be passed in the
//! `Authorization` header.
38//!
39//! ## Credentials Format
40//! Credentials MUST be returned in a JSON format:
41//! ```json
42//! {
43//!    "AccessKeyId" : "MUA...",
44//!    "SecretAccessKey" : "/7PC5om....",
45//!    "Token" : "AQoDY....=",
46//!    "Expiration" : "2016-02-25T06:03:31Z"
47//!  }
48//! ```
49//!
50//! Credentials errors MAY be returned with a `code` and `message` field:
51//! ```json
52//! {
53//!   "code": "ErrorCode",
54//!   "message": "Helpful error message."
55//! }
56//! ```
57
58use crate::http_credential_provider::HttpCredentialProvider;
59use crate::provider_config::ProviderConfig;
60use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
61use aws_smithy_runtime::client::endpoint::apply_endpoint;
62use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
63use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
64use aws_smithy_runtime_api::shared::IntoShared;
65use aws_smithy_types::error::display::DisplayErrorContext;
66use aws_types::os_shim_internal::{Env, Fs};
67use http::header::InvalidHeaderValue;
68use http::uri::{InvalidUri, PathAndQuery, Scheme};
69use http::{HeaderValue, Uri};
70use std::error::Error;
71use std::fmt::{Display, Formatter};
72use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
73use std::time::Duration;
74use tokio::sync::OnceCell;
75
76const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
77const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
78
79// URL from https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html
80const BASE_HOST: &str = "http://169.254.170.2";
81const ENV_RELATIVE_URI: &str = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
82const ENV_FULL_URI: &str = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
83const ENV_AUTHORIZATION_TOKEN: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN";
84const ENV_AUTHORIZATION_TOKEN_FILE: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE";
85
86/// Credential provider for ECS and generalized HTTP credentials
87///
88/// See the [module](crate::ecs) documentation for more details.
89///
90/// This credential provider is part of the default chain.
91#[derive(Debug)]
92pub struct EcsCredentialsProvider {
93    inner: OnceCell<Provider>,
94    env: Env,
95    fs: Fs,
96    builder: Builder,
97}
98
99impl EcsCredentialsProvider {
100    /// Builder for [`EcsCredentialsProvider`]
101    pub fn builder() -> Builder {
102        Builder::default()
103    }
104
105    /// Load credentials from this credentials provider
106    pub async fn credentials(&self) -> provider::Result {
107        let env_token_file = self.env.get(ENV_AUTHORIZATION_TOKEN_FILE).ok();
108        let env_token = self.env.get(ENV_AUTHORIZATION_TOKEN).ok();
109        let auth = if let Some(auth_token_file) = env_token_file {
110            let auth = self
111                .fs
112                .read_to_end(auth_token_file)
113                .await
114                .map_err(CredentialsError::provider_error)?;
115            Some(HeaderValue::from_bytes(auth.as_slice()).map_err(|err| {
116                let auth_token = String::from_utf8_lossy(auth.as_slice()).to_string();
117                tracing::warn!(token = %auth_token, "invalid auth token");
118                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
119                    err,
120                    value: auth_token,
121                })
122            })?)
123        } else if let Some(auth_token) = env_token {
124            Some(HeaderValue::from_str(&auth_token).map_err(|err| {
125                tracing::warn!(token = %auth_token, "invalid auth token");
126                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
127                    err,
128                    value: auth_token,
129                })
130            })?)
131        } else {
132            None
133        };
134        match self.provider().await {
135            Provider::NotConfigured => {
136                Err(CredentialsError::not_loaded("ECS provider not configured"))
137            }
138            Provider::InvalidConfiguration(err) => {
139                Err(CredentialsError::invalid_configuration(format!("{}", err)))
140            }
141            Provider::Configured(provider) => provider.credentials(auth).await,
142        }
143    }
144
145    async fn provider(&self) -> &Provider {
146        self.inner
147            .get_or_init(|| Provider::make(self.builder.clone()))
148            .await
149    }
150}
151
152impl ProvideCredentials for EcsCredentialsProvider {
153    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
154    where
155        Self: 'a,
156    {
157        future::ProvideCredentials::new(self.credentials())
158    }
159}
160
161/// Inner Provider that can record failed configuration state
162#[derive(Debug)]
163#[allow(clippy::large_enum_variant)]
164enum Provider {
165    Configured(HttpCredentialProvider),
166    NotConfigured,
167    InvalidConfiguration(EcsConfigurationError),
168}
169
170impl Provider {
171    async fn uri(env: Env, dns: Option<SharedDnsResolver>) -> Result<Uri, EcsConfigurationError> {
172        let relative_uri = env.get(ENV_RELATIVE_URI).ok();
173        let full_uri = env.get(ENV_FULL_URI).ok();
174        if let Some(relative_uri) = relative_uri {
175            Self::build_full_uri(relative_uri)
176        } else if let Some(full_uri) = full_uri {
177            let dns = dns.or_else(default_dns);
178            validate_full_uri(&full_uri, dns)
179                .await
180                .map_err(|err| EcsConfigurationError::InvalidFullUri { err, uri: full_uri })
181        } else {
182            Err(EcsConfigurationError::NotConfigured)
183        }
184    }
185
186    async fn make(builder: Builder) -> Self {
187        let provider_config = builder.provider_config.unwrap_or_default();
188        let env = provider_config.env();
189        let uri = match Self::uri(env, builder.dns).await {
190            Ok(uri) => uri,
191            Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured,
192            Err(err) => return Provider::InvalidConfiguration(err),
193        };
194        let path_and_query = match uri.path_and_query() {
195            Some(path_and_query) => path_and_query.to_string(),
196            None => uri.path().to_string(),
197        };
198        let endpoint = {
199            let mut parts = uri.into_parts();
200            parts.path_and_query = Some(PathAndQuery::from_static("/"));
201            Uri::from_parts(parts)
202        }
203        .expect("parts will be valid")
204        .to_string();
205
206        let http_provider = HttpCredentialProvider::builder()
207            .configure(&provider_config)
208            .http_connector_settings(
209                HttpConnectorSettings::builder()
210                    .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
211                    .read_timeout(DEFAULT_READ_TIMEOUT)
212                    .build(),
213            )
214            .build("EcsContainer", &endpoint, path_and_query);
215        Provider::Configured(http_provider)
216    }
217
218    fn build_full_uri(relative_uri: String) -> Result<Uri, EcsConfigurationError> {
219        let mut relative_uri = match relative_uri.parse::<Uri>() {
220            Ok(uri) => uri,
221            Err(invalid_uri) => {
222                tracing::warn!(uri = %DisplayErrorContext(&invalid_uri), "invalid URI loaded from environment");
223                return Err(EcsConfigurationError::InvalidRelativeUri {
224                    err: invalid_uri,
225                    uri: relative_uri,
226                });
227            }
228        };
229        let endpoint = Uri::from_static(BASE_HOST);
230        apply_endpoint(&mut relative_uri, &endpoint, None)
231            .expect("appending relative URLs to the ECS endpoint should always succeed");
232        Ok(relative_uri)
233    }
234}
235
236#[derive(Debug)]
237enum EcsConfigurationError {
238    InvalidRelativeUri {
239        err: InvalidUri,
240        uri: String,
241    },
242    InvalidFullUri {
243        err: InvalidFullUriError,
244        uri: String,
245    },
246    InvalidAuthToken {
247        err: InvalidHeaderValue,
248        value: String,
249    },
250    NotConfigured,
251}
252
253impl Display for EcsConfigurationError {
254    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
255        match self {
256            EcsConfigurationError::InvalidRelativeUri { err, uri } => write!(
257                f,
258                "invalid relative URI for ECS provider ({}): {}",
259                err, uri
260            ),
261            EcsConfigurationError::InvalidFullUri { err, uri } => {
262                write!(f, "invalid full URI for ECS provider ({}): {}", err, uri)
263            }
264            EcsConfigurationError::NotConfigured => write!(
265                f,
266                "No environment variables were set to configure ECS provider"
267            ),
268            EcsConfigurationError::InvalidAuthToken { err, value } => write!(
269                f,
270                "`{}` could not be used as a header value for the auth token. {}",
271                value, err
272            ),
273        }
274    }
275}
276
277impl Error for EcsConfigurationError {
278    fn source(&self) -> Option<&(dyn Error + 'static)> {
279        match &self {
280            EcsConfigurationError::InvalidRelativeUri { err, .. } => Some(err),
281            EcsConfigurationError::InvalidFullUri { err, .. } => Some(err),
282            EcsConfigurationError::InvalidAuthToken { err, .. } => Some(err),
283            EcsConfigurationError::NotConfigured => None,
284        }
285    }
286}
287
288/// Builder for [`EcsCredentialsProvider`]
289#[derive(Default, Debug, Clone)]
290pub struct Builder {
291    provider_config: Option<ProviderConfig>,
292    dns: Option<SharedDnsResolver>,
293    connect_timeout: Option<Duration>,
294    read_timeout: Option<Duration>,
295}
296
297impl Builder {
298    /// Override the configuration used for this provider
299    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
300        self.provider_config = Some(provider_config.clone());
301        self
302    }
303
304    /// Override the DNS resolver used to validate URIs
305    ///
306    /// URIs must refer to valid IP addresses as defined in the module documentation. The [`ResolveDns`]
307    /// implementation is used to retrieve IP addresses for a given domain.
308    pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
309        self.dns = Some(dns.into_shared());
310        self
311    }
312
313    /// Override the connect timeout for the HTTP client
314    ///
315    /// This value defaults to 2 seconds
316    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
317        self.connect_timeout = Some(timeout);
318        self
319    }
320
321    /// Override the read timeout for the HTTP client
322    ///
323    /// This value defaults to 5 seconds
324    pub fn read_timeout(mut self, timeout: Duration) -> Self {
325        self.read_timeout = Some(timeout);
326        self
327    }
328
329    /// Create an [`EcsCredentialsProvider`] from this builder
330    pub fn build(self) -> EcsCredentialsProvider {
331        let env = self
332            .provider_config
333            .as_ref()
334            .map(|config| config.env())
335            .unwrap_or_default();
336        let fs = self
337            .provider_config
338            .as_ref()
339            .map(|config| config.fs())
340            .unwrap_or_default();
341        EcsCredentialsProvider {
342            inner: OnceCell::new(),
343            env,
344            fs,
345            builder: self,
346        }
347    }
348}
349
350#[derive(Debug)]
351enum InvalidFullUriErrorKind {
352    /// The provided URI could not be parsed as a URI
353    #[non_exhaustive]
354    InvalidUri(InvalidUri),
355
356    /// No Dns resolver was provided
357    #[non_exhaustive]
358    NoDnsResolver,
359
360    /// The URI did not specify a host
361    #[non_exhaustive]
362    MissingHost,
363
364    /// The URI did not refer to an allowed IP address
365    #[non_exhaustive]
366    DisallowedIP,
367
368    /// DNS lookup failed when attempting to resolve the host to an IP Address for validation.
369    DnsLookupFailed(ResolveDnsError),
370}
371
372/// Invalid Full URI
373///
374/// When the full URI setting is used, the URI must either be HTTPS, point to a loopback interface,
375/// or point to known ECS/EKS container IPs.
376#[derive(Debug)]
377pub struct InvalidFullUriError {
378    kind: InvalidFullUriErrorKind,
379}
380
381impl Display for InvalidFullUriError {
382    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
383        use InvalidFullUriErrorKind::*;
384        match self.kind {
385            InvalidUri(_) => write!(f, "URI was invalid"),
386            MissingHost => write!(f, "URI did not specify a host"),
387            DisallowedIP => {
388                write!(f, "URI did not refer to an allowed IP address")
389            }
390            DnsLookupFailed(_) => {
391                write!(
392                    f,
393                    "failed to perform DNS lookup while validating URI"
394                )
395            }
396            NoDnsResolver => write!(f, "no DNS resolver was provided. Enable `rt-tokio` or provide a `dns` resolver to the builder.")
397        }
398    }
399}
400
401impl Error for InvalidFullUriError {
402    fn source(&self) -> Option<&(dyn Error + 'static)> {
403        use InvalidFullUriErrorKind::*;
404        match &self.kind {
405            InvalidUri(err) => Some(err),
406            DnsLookupFailed(err) => Some(err as _),
407            _ => None,
408        }
409    }
410}
411
412impl From<InvalidFullUriErrorKind> for InvalidFullUriError {
413    fn from(kind: InvalidFullUriErrorKind) -> Self {
414        Self { kind }
415    }
416}
417
418/// Validate that `uri` is valid to be used as a full provider URI
419/// Either:
420/// 1. The URL is uses `https`
421/// 2. The URL refers to an allowed IP. If a URL contains a domain name instead of an IP address,
422///    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP, or
423///    the credentials provider will return `CredentialsError::InvalidConfiguration`. Allowed IPs
424///    are the loopback interfaces, and the known ECS/EKS container IPs.
425async fn validate_full_uri(
426    uri: &str,
427    dns: Option<SharedDnsResolver>,
428) -> Result<Uri, InvalidFullUriError> {
429    let uri = uri
430        .parse::<Uri>()
431        .map_err(InvalidFullUriErrorKind::InvalidUri)?;
432    if uri.scheme() == Some(&Scheme::HTTPS) {
433        return Ok(uri);
434    }
435    // For HTTP URIs, we need to validate that it points to a valid IP
436    let host = uri.host().ok_or(InvalidFullUriErrorKind::MissingHost)?;
437    let maybe_ip = if host.starts_with('[') && host.ends_with(']') {
438        host[1..host.len() - 1].parse::<IpAddr>()
439    } else {
440        host.parse::<IpAddr>()
441    };
442    let is_allowed = match maybe_ip {
443        Ok(addr) => is_full_uri_ip_allowed(&addr),
444        Err(_domain_name) => {
445            let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
446            dns.resolve_dns(host)
447                .await
448                .map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
449                .iter()
450                    .all(|addr| {
451                        if !is_full_uri_ip_allowed(addr) {
452                            tracing::warn!(
453                                addr = ?addr,
454                                "HTTP credential provider cannot be used: Address does not resolve to an allowed IP."
455                            )
456                        };
457                        is_full_uri_ip_allowed(addr)
458                    })
459        }
460    };
461    match is_allowed {
462        true => Ok(uri),
463        false => Err(InvalidFullUriErrorKind::DisallowedIP.into()),
464    }
465}
466
467// "169.254.170.2"
468const ECS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 2));
469
470// "169.254.170.23"
471const EKS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 23));
472
473// "fd00:ec2::23"
474const EKS_CONTAINER_IPV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0xFD00, 0x0EC2, 0, 0, 0, 0, 0, 0x23));
475fn is_full_uri_ip_allowed(ip: &IpAddr) -> bool {
476    ip.is_loopback()
477        || ip.eq(&ECS_CONTAINER_IPV4)
478        || ip.eq(&EKS_CONTAINER_IPV4)
479        || ip.eq(&EKS_CONTAINER_IPV6)
480}
481
482/// Default DNS resolver impl
483///
484/// DNS resolution is required to validate that provided URIs point to a valid IP address
485#[cfg(any(not(feature = "rt-tokio"), target_family = "wasm"))]
486fn default_dns() -> Option<SharedDnsResolver> {
487    None
488}
489#[cfg(all(feature = "rt-tokio", not(target_family = "wasm")))]
490fn default_dns() -> Option<SharedDnsResolver> {
491    use aws_smithy_runtime::client::dns::TokioDnsResolver;
492    Some(TokioDnsResolver::new().into_shared())
493}
494
495#[cfg(test)]
496mod test {
497    use super::*;
498    use crate::provider_config::ProviderConfig;
499    use crate::test_case::{no_traffic_client, GenericTestResult};
500    use aws_credential_types::provider::ProvideCredentials;
501    use aws_credential_types::Credentials;
502    use aws_smithy_async::future::never::Never;
503    use aws_smithy_async::rt::sleep::TokioSleep;
504    use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient};
505    use aws_smithy_runtime_api::client::dns::DnsFuture;
506    use aws_smithy_runtime_api::client::http::HttpClient;
507    use aws_smithy_runtime_api::shared::IntoShared;
508    use aws_smithy_types::body::SdkBody;
509    use aws_types::os_shim_internal::Env;
510    use futures_util::FutureExt;
511    use http::header::AUTHORIZATION;
512    use http::Uri;
513    use serde::Deserialize;
514    use std::collections::HashMap;
515    use std::error::Error;
516    use std::ffi::OsString;
517    use std::net::IpAddr;
518    use std::time::{Duration, UNIX_EPOCH};
519    use tracing_test::traced_test;
520
521    fn provider(
522        env: Env,
523        fs: Fs,
524        http_client: impl HttpClient + 'static,
525    ) -> EcsCredentialsProvider {
526        let provider_config = ProviderConfig::empty()
527            .with_env(env)
528            .with_fs(fs)
529            .with_http_client(http_client)
530            .with_sleep_impl(TokioSleep::new());
531        Builder::default().configure(&provider_config).build()
532    }
533
534    #[derive(Deserialize)]
535    struct EcsUriTest {
536        env: HashMap<String, String>,
537        result: GenericTestResult<String>,
538    }
539
540    impl EcsUriTest {
541        async fn check(&self) {
542            let env = Env::from(self.env.clone());
543            let uri = Provider::uri(env, Some(TestDns::default().into_shared()))
544                .await
545                .map(|uri| uri.to_string());
546            self.result.assert_matches(uri.as_ref());
547        }
548    }
549
550    #[tokio::test]
551    async fn run_config_tests() -> Result<(), Box<dyn Error>> {
552        let test_cases = std::fs::read_to_string("test-data/ecs-tests.json")?;
553        #[derive(Deserialize)]
554        struct TestCases {
555            tests: Vec<EcsUriTest>,
556        }
557
558        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
559        let test_cases = test_cases.tests;
560        for test in test_cases {
561            test.check().await
562        }
563        Ok(())
564    }
565
566    #[test]
567    fn validate_uri_https() {
568        // over HTTPs, any URI is fine
569        let dns = Some(NeverDns.into_shared());
570        assert_eq!(
571            validate_full_uri("https://amazon.com", None)
572                .now_or_never()
573                .unwrap()
574                .expect("valid"),
575            Uri::from_static("https://amazon.com")
576        );
577        // over HTTP, it will try to lookup
578        assert!(
579            validate_full_uri("http://amazon.com", dns)
580                .now_or_never()
581                .is_none(),
582            "DNS lookup should occur, but it will never return"
583        );
584
585        let no_dns_error = validate_full_uri("http://amazon.com", None)
586            .now_or_never()
587            .unwrap()
588            .expect_err("DNS service is required");
589        assert!(
590            matches!(
591                no_dns_error,
592                InvalidFullUriError {
593                    kind: InvalidFullUriErrorKind::NoDnsResolver
594                }
595            ),
596            "expected no dns service, got: {}",
597            no_dns_error
598        );
599    }
600
601    #[test]
602    fn valid_uri_loopback() {
603        assert_eq!(
604            validate_full_uri("http://127.0.0.1:8080/get-credentials", None)
605                .now_or_never()
606                .unwrap()
607                .expect("valid uri"),
608            Uri::from_static("http://127.0.0.1:8080/get-credentials")
609        );
610
611        let err = validate_full_uri("http://192.168.10.120/creds", None)
612            .now_or_never()
613            .unwrap()
614            .expect_err("not a loopback");
615        assert!(matches!(
616            err,
617            InvalidFullUriError {
618                kind: InvalidFullUriErrorKind::DisallowedIP
619            }
620        ));
621    }
622
623    #[test]
624    fn valid_uri_ecs_eks() {
625        assert_eq!(
626            validate_full_uri("http://169.254.170.2:8080/get-credentials", None)
627                .now_or_never()
628                .unwrap()
629                .expect("valid uri"),
630            Uri::from_static("http://169.254.170.2:8080/get-credentials")
631        );
632        assert_eq!(
633            validate_full_uri("http://169.254.170.23:8080/get-credentials", None)
634                .now_or_never()
635                .unwrap()
636                .expect("valid uri"),
637            Uri::from_static("http://169.254.170.23:8080/get-credentials")
638        );
639        assert_eq!(
640            validate_full_uri("http://[fd00:ec2::23]:8080/get-credentials", None)
641                .now_or_never()
642                .unwrap()
643                .expect("valid uri"),
644            Uri::from_static("http://[fd00:ec2::23]:8080/get-credentials")
645        );
646
647        let err = validate_full_uri("http://169.254.171.23/creds", None)
648            .now_or_never()
649            .unwrap()
650            .expect_err("not an ecs/eks container address");
651        assert!(matches!(
652            err,
653            InvalidFullUriError {
654                kind: InvalidFullUriErrorKind::DisallowedIP
655            }
656        ));
657
658        let err = validate_full_uri("http://[fd00:ec2::2]/creds", None)
659            .now_or_never()
660            .unwrap()
661            .expect_err("not an ecs/eks container address");
662        assert!(matches!(
663            err,
664            InvalidFullUriError {
665                kind: InvalidFullUriErrorKind::DisallowedIP
666            }
667        ));
668    }
669
670    #[test]
671    fn all_addrs_local() {
672        let dns = Some(
673            TestDns::with_fallback(vec![
674                "127.0.0.1".parse().unwrap(),
675                "127.0.0.2".parse().unwrap(),
676                "169.254.170.23".parse().unwrap(),
677                "fd00:ec2::23".parse().unwrap(),
678            ])
679            .into_shared(),
680        );
681        let resp = validate_full_uri("http://localhost:8888", dns)
682            .now_or_never()
683            .unwrap();
684        assert!(resp.is_ok(), "Should be valid: {:?}", resp);
685    }
686
687    #[test]
688    fn all_addrs_not_local() {
689        let dns = Some(
690            TestDns::with_fallback(vec![
691                "127.0.0.1".parse().unwrap(),
692                "192.168.0.1".parse().unwrap(),
693            ])
694            .into_shared(),
695        );
696        let resp = validate_full_uri("http://localhost:8888", dns)
697            .now_or_never()
698            .unwrap();
699        assert!(
700            matches!(
701                resp,
702                Err(InvalidFullUriError {
703                    kind: InvalidFullUriErrorKind::DisallowedIP
704                })
705            ),
706            "Should be invalid: {:?}",
707            resp
708        );
709    }
710
711    fn creds_request(uri: &str, auth: Option<&str>) -> http::Request<SdkBody> {
712        let mut builder = http::Request::builder();
713        if let Some(auth) = auth {
714            builder = builder.header(AUTHORIZATION, auth);
715        }
716        builder.uri(uri).body(SdkBody::empty()).unwrap()
717    }
718
719    fn ok_creds_response() -> http::Response<SdkBody> {
720        http::Response::builder()
721            .status(200)
722            .body(SdkBody::from(
723                r#" {
724                       "AccessKeyId" : "AKID",
725                       "SecretAccessKey" : "SECRET",
726                       "Token" : "TOKEN....=",
727                       "Expiration" : "2009-02-13T23:31:30Z"
728                     }"#,
729            ))
730            .unwrap()
731    }
732
733    #[track_caller]
734    fn assert_correct(creds: Credentials) {
735        assert_eq!(creds.access_key_id(), "AKID");
736        assert_eq!(creds.secret_access_key(), "SECRET");
737        assert_eq!(creds.session_token().unwrap(), "TOKEN....=");
738        assert_eq!(
739            creds.expiry().unwrap(),
740            UNIX_EPOCH + Duration::from_secs(1234567890)
741        );
742    }
743
744    #[tokio::test]
745    async fn load_valid_creds_auth() {
746        let env = Env::from_slice(&[
747            ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials"),
748            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "Basic password"),
749        ]);
750        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
751            creds_request("http://169.254.170.2/credentials", Some("Basic password")),
752            ok_creds_response(),
753        )]);
754        let provider = provider(env, Fs::default(), http_client.clone());
755        let creds = provider
756            .provide_credentials()
757            .await
758            .expect("valid credentials");
759        assert_correct(creds);
760        http_client.assert_requests_match(&[]);
761    }
762
763    #[tokio::test]
764    async fn load_valid_creds_auth_file() {
765        let env = Env::from_slice(&[
766            (
767                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
768                "http://169.254.170.23/v1/credentials",
769            ),
770            (
771                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
772                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
773            ),
774        ]);
775        let fs = Fs::from_raw_map(HashMap::from([(
776            OsString::from(
777                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
778            ),
779            "Basic password".into(),
780        )]));
781
782        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
783            creds_request(
784                "http://169.254.170.23/v1/credentials",
785                Some("Basic password"),
786            ),
787            ok_creds_response(),
788        )]);
789        let provider = provider(env, fs, http_client.clone());
790        let creds = provider
791            .provide_credentials()
792            .await
793            .expect("valid credentials");
794        assert_correct(creds);
795        http_client.assert_requests_match(&[]);
796    }
797
798    #[tokio::test]
799    async fn auth_file_precedence_over_env() {
800        let env = Env::from_slice(&[
801            (
802                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
803                "http://169.254.170.23/v1/credentials",
804            ),
805            (
806                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
807                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
808            ),
809            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
810        ]);
811        let fs = Fs::from_raw_map(HashMap::from([(
812            OsString::from(
813                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
814            ),
815            "Basic password".into(),
816        )]));
817
818        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
819            creds_request(
820                "http://169.254.170.23/v1/credentials",
821                Some("Basic password"),
822            ),
823            ok_creds_response(),
824        )]);
825        let provider = provider(env, fs, http_client.clone());
826        let creds = provider
827            .provide_credentials()
828            .await
829            .expect("valid credentials");
830        assert_correct(creds);
831        http_client.assert_requests_match(&[]);
832    }
833
834    #[tokio::test]
835    async fn query_params_should_be_included_in_credentials_http_request() {
836        let env = Env::from_slice(&[
837            (
838                "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
839                "/my-credentials/?applicationName=test2024",
840            ),
841            (
842                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
843                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
844            ),
845            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
846        ]);
847        let fs = Fs::from_raw_map(HashMap::from([(
848            OsString::from(
849                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
850            ),
851            "Basic password".into(),
852        )]));
853
854        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
855            creds_request(
856                "http://169.254.170.2/my-credentials/?applicationName=test2024",
857                Some("Basic password"),
858            ),
859            ok_creds_response(),
860        )]);
861        let provider = provider(env, fs, http_client.clone());
862        let creds = provider
863            .provide_credentials()
864            .await
865            .expect("valid credentials");
866        assert_correct(creds);
867        http_client.assert_requests_match(&[]);
868    }
869
870    #[tokio::test]
871    async fn fs_missing_file() {
872        let env = Env::from_slice(&[
873            (
874                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
875                "http://169.254.170.23/v1/credentials",
876            ),
877            (
878                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
879                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
880            ),
881        ]);
882        let fs = Fs::from_raw_map(HashMap::new());
883
884        let provider = provider(env, fs, no_traffic_client());
885        let err = provider.credentials().await.expect_err("no JWT token file");
886        match err {
887            CredentialsError::ProviderError { .. } => { /* ok */ }
888            _ => panic!("incorrect error variant"),
889        }
890    }
891
892    #[tokio::test]
893    async fn retry_5xx() {
894        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
895        let http_client = StaticReplayClient::new(vec![
896            ReplayEvent::new(
897                creds_request("http://169.254.170.2/credentials", None),
898                http::Response::builder()
899                    .status(500)
900                    .body(SdkBody::empty())
901                    .unwrap(),
902            ),
903            ReplayEvent::new(
904                creds_request("http://169.254.170.2/credentials", None),
905                ok_creds_response(),
906            ),
907        ]);
908        tokio::time::pause();
909        let provider = provider(env, Fs::default(), http_client.clone());
910        let creds = provider
911            .provide_credentials()
912            .await
913            .expect("valid credentials");
914        assert_correct(creds);
915    }
916
917    #[tokio::test]
918    async fn load_valid_creds_no_auth() {
919        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
920        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
921            creds_request("http://169.254.170.2/credentials", None),
922            ok_creds_response(),
923        )]);
924        let provider = provider(env, Fs::default(), http_client.clone());
925        let creds = provider
926            .provide_credentials()
927            .await
928            .expect("valid credentials");
929        assert_correct(creds);
930        http_client.assert_requests_match(&[]);
931    }
932
933    // ignored by default because it relies on actual DNS resolution
934    #[allow(unused_attributes)]
935    #[tokio::test]
936    #[traced_test]
937    #[ignore]
938    async fn real_dns_lookup() {
939        let dns = Some(
940            default_dns()
941                .expect("feature must be enabled")
942                .into_shared(),
943        );
944        let err = validate_full_uri("http://www.amazon.com/creds", dns.clone())
945            .await
946            .expect_err("not a valid IP");
947        assert!(
948            matches!(
949                err,
950                InvalidFullUriError {
951                    kind: InvalidFullUriErrorKind::DisallowedIP
952                }
953            ),
954            "{:?}",
955            err
956        );
957        assert!(logs_contain("Address does not resolve to an allowed IP"));
958        validate_full_uri("http://localhost:8888/creds", dns.clone())
959            .await
960            .expect("localhost is the loopback interface");
961        validate_full_uri("http://169.254.170.2.backname.io:8888/creds", dns.clone())
962            .await
963            .expect("169.254.170.2.backname.io is the ecs container address");
964        validate_full_uri("http://169.254.170.23.backname.io:8888/creds", dns.clone())
965            .await
966            .expect("169.254.170.23.backname.io is the eks pod identity address");
967        validate_full_uri("http://fd00-ec2--23.backname.io:8888/creds", dns)
968            .await
969            .expect("fd00-ec2--23.backname.io is the eks pod identity address");
970    }
971
972    /// Always returns the same IP addresses
973    #[derive(Clone, Debug)]
974    struct TestDns {
975        addrs: HashMap<String, Vec<IpAddr>>,
976        fallback: Vec<IpAddr>,
977    }
978
979    /// Default that returns a loopback for `localhost` and a non-loopback for all other hostnames
980    impl Default for TestDns {
981        fn default() -> Self {
982            let mut addrs = HashMap::new();
983            addrs.insert(
984                "localhost".into(),
985                vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
986            );
987            TestDns {
988                addrs,
989                // non-loopback address
990                fallback: vec!["72.21.210.29".parse().unwrap()],
991            }
992        }
993    }
994
995    impl TestDns {
996        fn with_fallback(fallback: Vec<IpAddr>) -> Self {
997            TestDns {
998                addrs: Default::default(),
999                fallback,
1000            }
1001        }
1002    }
1003
1004    impl ResolveDns for TestDns {
1005        fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
1006            DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
1007        }
1008    }
1009
1010    #[derive(Debug)]
1011    struct NeverDns;
1012    impl ResolveDns for NeverDns {
1013        fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
1014            DnsFuture::new(async {
1015                Never::new().await;
1016                unreachable!()
1017            })
1018        }
1019    }
1020}