aws_config/imds/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Raw IMDSv2 Client
7//!
8//! Client for direct access to IMDSv2.
9
10use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::metrics::MetricsRuntimePlugin;
16use aws_smithy_runtime::client::orchestrator::operation::Operation;
17use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
18use aws_smithy_runtime_api::box_error::BoxError;
19use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
20use aws_smithy_runtime_api::client::endpoint::{
21    EndpointFuture, EndpointResolverParams, ResolveEndpoint,
22};
23use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
24use aws_smithy_runtime_api::client::orchestrator::{
25    HttpRequest, Metadata, OrchestratorError, SensitiveOutput,
26};
27use aws_smithy_runtime_api::client::result::ConnectorError;
28use aws_smithy_runtime_api::client::result::SdkError;
29use aws_smithy_runtime_api::client::retries::classifiers::{
30    ClassifyRetry, RetryAction, SharedRetryClassifier,
31};
32use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
33use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
34use aws_smithy_types::body::SdkBody;
35use aws_smithy_types::config_bag::{FrozenLayer, Layer};
36use aws_smithy_types::endpoint::Endpoint;
37use aws_smithy_types::retry::RetryConfig;
38use aws_smithy_types::timeout::TimeoutConfig;
39use aws_types::os_shim_internal::Env;
40use http::Uri;
41use std::borrow::Cow;
42use std::error::Error as _;
43use std::fmt;
44use std::str::FromStr;
45use std::sync::Arc;
46use std::time::Duration;
47
48pub mod error;
49mod token;
50
/// Default lifetime requested for IMDS session tokens: six hours.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default maximum number of request attempts.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default connect timeout for IMDS requests.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Default read timeout for IMDS requests.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
/// Default timeout for a whole operation, across all attempts.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Default timeout for a single operation attempt.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
58
59fn user_agent() -> AwsUserAgent {
60    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
61}
62
63/// IMDSv2 Client
64///
65/// Client for IMDSv2. This client handles fetching tokens, retrying on failure, and token
66/// caching according to the specified token TTL.
67///
68/// _Note: This client ONLY supports IMDSv2. It will not fallback to IMDSv1. See
69/// [transitioning to IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html#instance-metadata-transition-to-version-2)
70/// for more information._
71///
72/// **Note**: When running in a Docker container, all network requests will incur an additional hop. When combined with the default IMDS hop limit of 1, this will cause requests to IMDS to timeout! To fix this issue, you'll need to set the following instance metadata settings :
73/// ```txt
74/// amazonec2-metadata-token=required
75/// amazonec2-metadata-token-response-hop-limit=2
76/// ```
77///
78/// On an instance that is already running, these can be set with [ModifyInstanceMetadataOptions](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_ModifyInstanceMetadataOptions.html). On a new instance, these can be set with the `MetadataOptions` field on [RunInstances](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_RunInstances.html).
79///
80/// For more information about IMDSv2 vs. IMDSv1 see [this guide](https://docs.aws.amazon.com/AWSEC2/latest/WindowsGuide/configuring-instance-metadata-service.html)
81///
82/// # Client Configuration
83/// The IMDS client can load configuration explicitly, via environment variables, or via
84/// `~/.aws/config`. It will first attempt to resolve an endpoint override. If no endpoint
85/// override exists, it will attempt to resolve an [`EndpointMode`]. If no
86/// [`EndpointMode`] override exists, it will fallback to [`IpV4`](EndpointMode::IpV4). An exhaustive
87/// list is below:
88///
89/// ## Endpoint configuration list
90/// 1. Explicit configuration of `Endpoint` via the [builder](Builder):
91/// ```no_run
92/// use aws_config::imds::client::Client;
93/// # async fn docs() {
94/// let client = Client::builder()
95///   .endpoint("http://customimds:456/").expect("valid URI")
96///   .build();
97/// # }
98/// ```
99///
100/// 2. The `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable. Note: If this environment variable
101///    is set, it MUST contain a valid URI or client construction will fail.
102///
103/// 3. The `ec2_metadata_service_endpoint` field in `~/.aws/config`:
104/// ```ini
105/// [default]
106/// # ... other configuration
107/// ec2_metadata_service_endpoint = http://my-custom-endpoint:444
108/// ```
109///
110/// 4. An explicitly set endpoint mode:
111/// ```no_run
112/// use aws_config::imds::client::{Client, EndpointMode};
113/// # async fn docs() {
114/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build();
115/// # }
116/// ```
117///
118/// 5. An [endpoint mode](EndpointMode) loaded from the `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE` environment
119///    variable. Valid values: `IPv4`, `IPv6`
120///
121/// 6. An [endpoint mode](EndpointMode) loaded from the `ec2_metadata_service_endpoint_mode` field in
122///    `~/.aws/config`:
123/// ```ini
124/// [default]
125/// # ... other configuration
126/// ec2_metadata_service_endpoint_mode = IPv4
127/// ```
128///
129/// 7. The default value of `http://169.254.169.254` will be used.
130///
131#[derive(Clone, Debug)]
132pub struct Client {
133    operation: Operation<String, SensitiveString, InnerImdsError>,
134}
135
136impl Client {
137    /// IMDS client builder
138    pub fn builder() -> Builder {
139        Builder::default()
140    }
141
142    /// Retrieve information from IMDS
143    ///
144    /// This method will handle loading and caching a session token, combining the `path` with the
145    /// configured IMDS endpoint, and retrying potential errors.
146    ///
147    /// For more information about IMDSv2 methods and functionality, see
148    /// [Instance metadata and user data](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html)
149    ///
150    /// # Examples
151    ///
152    /// ```no_run
153    /// use aws_config::imds::client::Client;
154    /// # async fn docs() {
155    /// let client = Client::builder().build();
156    /// let ami_id = client
157    ///   .get("/latest/meta-data/ami-id")
158    ///   .await
159    ///   .expect("failure communicating with IMDS");
160    /// # }
161    /// ```
162    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
163        self.operation
164            .invoke(path.into())
165            .await
166            .map_err(|err| match err {
167                SdkError::ConstructionFailure(_) if err.source().is_some() => {
168                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
169                        Ok(Ok(token_failure)) => *token_failure,
170                        Ok(Err(err)) => ImdsError::unexpected(err),
171                        Err(err) => ImdsError::unexpected(err),
172                    }
173                }
174                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
175                SdkError::ServiceError(context) => match context.err() {
176                    InnerImdsError::InvalidUtf8 => {
177                        ImdsError::unexpected("IMDS returned invalid UTF-8")
178                    }
179                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
180                },
181                // If the error source is an ImdsError, then we need to directly return that source.
182                // That way, the IMDS token provider's errors can become the top-level ImdsError.
183                // There is a unit test that checks the correct error is being extracted.
184                err @ SdkError::DispatchFailure(_) => match err.into_source() {
185                    Ok(source) => match source.downcast::<ConnectorError>() {
186                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
187                            Ok(source) => *source,
188                            Err(err) => ImdsError::unexpected(err),
189                        },
190                        Err(err) => ImdsError::unexpected(err),
191                    },
192                    Err(err) => ImdsError::unexpected(err),
193                },
194                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
195                _ => ImdsError::unexpected(err),
196            })
197    }
198}
199
/// A `String` wrapper whose `Debug` output never reveals the wrapped value.
///
/// The contents are available through [`AsRef<str>`] or by converting back into a [`String`].
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Deliberately print a placeholder instead of the wrapped value.
        const REDACTED: &str = "** redacted **";
        f.debug_tuple("SensitiveString").field(&REDACTED).finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(inner: String) -> Self {
        SensitiveString(inner)
    }
}

impl From<SensitiveString> for String {
    fn from(sensitive: SensitiveString) -> Self {
        let SensitiveString(inner) = sensitive;
        inner
    }
}
229
230/// Runtime plugin that is used by both the IMDS client and the inner client that resolves
231/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as
232/// sensitive, configures user agent headers, and sets up retries and timeouts.
233#[derive(Debug)]
234struct ImdsCommonRuntimePlugin {
235    config: FrozenLayer,
236    components: RuntimeComponentsBuilder,
237}
238
239impl ImdsCommonRuntimePlugin {
240    fn new(
241        config: &ProviderConfig,
242        endpoint_resolver: ImdsEndpointResolver,
243        retry_config: RetryConfig,
244        retry_classifier: SharedRetryClassifier,
245        timeout_config: TimeoutConfig,
246    ) -> Self {
247        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
248        layer.store_put(AuthSchemeOptionResolverParams::new(()));
249        layer.store_put(EndpointResolverParams::new(()));
250        layer.store_put(SensitiveOutput);
251        layer.store_put(retry_config);
252        layer.store_put(timeout_config);
253        layer.store_put(user_agent());
254
255        Self {
256            config: layer.freeze(),
257            components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
258                .with_http_client(config.http_client())
259                .with_endpoint_resolver(Some(endpoint_resolver))
260                .with_interceptor(UserAgentInterceptor::new())
261                .with_retry_classifier(retry_classifier)
262                .with_retry_strategy(Some(StandardRetryStrategy::new()))
263                .with_time_source(Some(config.time_source()))
264                .with_sleep_impl(config.sleep_impl()),
265        }
266    }
267}
268
269impl RuntimePlugin for ImdsCommonRuntimePlugin {
270    fn config(&self) -> Option<FrozenLayer> {
271        Some(self.config.clone())
272    }
273
274    fn runtime_components(
275        &self,
276        _current_components: &RuntimeComponentsBuilder,
277    ) -> Cow<'_, RuntimeComponentsBuilder> {
278        Cow::Borrowed(&self.components)
279    }
280}
281
282/// IMDSv2 Endpoint Mode
283///
284/// IMDS can be accessed in two ways:
285/// 1. Via the IpV4 endpoint: `http://169.254.169.254`
286/// 2. Via the Ipv6 endpoint: `http://[fd00:ec2::254]`
287#[derive(Debug, Clone)]
288#[non_exhaustive]
289pub enum EndpointMode {
290    /// IpV4 mode: `http://169.254.169.254`
291    ///
292    /// This mode is the default unless otherwise specified.
293    IpV4,
294    /// IpV6 mode: `http://[fd00:ec2::254]`
295    IpV6,
296}
297
298impl FromStr for EndpointMode {
299    type Err = InvalidEndpointMode;
300
301    fn from_str(value: &str) -> Result<Self, Self::Err> {
302        match value {
303            _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
304            _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
305            other => Err(InvalidEndpointMode::new(other.to_owned())),
306        }
307    }
308}
309
310impl EndpointMode {
311    /// IMDS URI for this endpoint mode
312    fn endpoint(&self) -> Uri {
313        match self {
314            EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
315            EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
316        }
317    }
318}
319
320/// IMDSv2 Client Builder
321#[derive(Default, Debug, Clone)]
322pub struct Builder {
323    max_attempts: Option<u32>,
324    endpoint: Option<EndpointSource>,
325    mode_override: Option<EndpointMode>,
326    token_ttl: Option<Duration>,
327    connect_timeout: Option<Duration>,
328    read_timeout: Option<Duration>,
329    operation_timeout: Option<Duration>,
330    operation_attempt_timeout: Option<Duration>,
331    config: Option<ProviderConfig>,
332    retry_classifier: Option<SharedRetryClassifier>,
333}
334
335impl Builder {
336    /// Override the number of retries for fetching tokens & metadata
337    ///
338    /// By default, 4 attempts will be made.
339    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
340        self.max_attempts = Some(max_attempts);
341        self
342    }
343
344    /// Configure generic options of the [`Client`]
345    ///
346    /// # Examples
347    /// ```no_run
348    /// # async fn test() {
349    /// use aws_config::imds::Client;
350    /// use aws_config::provider_config::ProviderConfig;
351    ///
352    /// let provider = Client::builder()
353    ///     .configure(&ProviderConfig::with_default_region().await)
354    ///     .build();
355    /// # }
356    /// ```
357    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
358        self.config = Some(provider_config.clone());
359        self
360    }
361
362    /// Override the endpoint for the [`Client`]
363    ///
364    /// By default, the client will resolve an endpoint from the environment, AWS config, and endpoint mode.
365    ///
366    /// See [`Client`] for more information.
367    pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
368        let uri: Uri = endpoint.as_ref().parse()?;
369        self.endpoint = Some(EndpointSource::Explicit(uri));
370        Ok(self)
371    }
372
373    /// Override the endpoint mode for [`Client`]
374    ///
375    /// * When set to [`IpV4`](EndpointMode::IpV4), the endpoint will be `http://169.254.169.254`.
376    /// * When set to [`IpV6`](EndpointMode::IpV6), the endpoint will be `http://[fd00:ec2::254]`.
377    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
378        self.mode_override = Some(mode);
379        self
380    }
381
382    /// Override the time-to-live for the session token
383    ///
384    /// Requests to IMDS utilize a session token for authentication. By default, session tokens last
385    /// for 6 hours. When the TTL for the token expires, a new token must be retrieved from the
386    /// metadata service.
387    pub fn token_ttl(mut self, ttl: Duration) -> Self {
388        self.token_ttl = Some(ttl);
389        self
390    }
391
392    /// Override the connect timeout for IMDS
393    ///
394    /// This value defaults to 1 second
395    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
396        self.connect_timeout = Some(timeout);
397        self
398    }
399
400    /// Override the read timeout for IMDS
401    ///
402    /// This value defaults to 1 second
403    pub fn read_timeout(mut self, timeout: Duration) -> Self {
404        self.read_timeout = Some(timeout);
405        self
406    }
407
408    /// Override the operation timeout for IMDS
409    ///
410    /// This value defaults to 1 second
411    pub fn operation_timeout(mut self, timeout: Duration) -> Self {
412        self.operation_timeout = Some(timeout);
413        self
414    }
415
416    /// Override the operation attempt timeout for IMDS
417    ///
418    /// This value defaults to 1 second
419    pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
420        self.operation_attempt_timeout = Some(timeout);
421        self
422    }
423
424    /// Override the retry classifier for IMDS
425    ///
426    /// This defaults to only retrying on server errors and 401s. The [ImdsResponseRetryClassifier] in this
427    /// module offers some configuration options and can be wrapped by[SharedRetryClassifier::new()] for use
428    /// here or you can create your own fully customized [SharedRetryClassifier].
429    pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
430        self.retry_classifier = Some(retry_classifier);
431        self
432    }
433
434    /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
435    /*
436    pub fn port(mut self, port: u32) -> Self {
437        self.port_override = Some(port);
438        self
439    }*/
440
441    /// Build an IMDSv2 Client
442    pub fn build(self) -> Client {
443        let config = self.config.unwrap_or_default();
444        let timeout_config = TimeoutConfig::builder()
445            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
446            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
447            .operation_attempt_timeout(
448                self.operation_attempt_timeout
449                    .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
450            )
451            .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
452            .build();
453        let endpoint_source = self
454            .endpoint
455            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
456        let endpoint_resolver = ImdsEndpointResolver {
457            endpoint_source: Arc::new(endpoint_source),
458            mode_override: self.mode_override,
459        };
460        let retry_config = RetryConfig::standard()
461            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
462        let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
463            ImdsResponseRetryClassifier::default(),
464        ));
465        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
466            &config,
467            endpoint_resolver,
468            retry_config,
469            retry_classifier,
470            timeout_config,
471        ));
472        let operation = Operation::builder()
473            .service_name("imds")
474            .operation_name("get")
475            .runtime_plugin(common_plugin.clone())
476            .runtime_plugin(TokenRuntimePlugin::new(
477                common_plugin,
478                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
479            ))
480            .runtime_plugin(
481                MetricsRuntimePlugin::builder()
482                    .with_scope("aws_config::imds_credentials")
483                    .with_time_source(config.time_source())
484                    .with_metadata(Metadata::new("get_credentials", "imds"))
485                    .build()
486                    .expect("All required fields have been set"),
487            )
488            .with_connection_poisoning()
489            .serializer(|path| {
490                Ok(HttpRequest::try_from(
491                    http::Request::builder()
492                        .uri(path)
493                        .body(SdkBody::empty())
494                        .expect("valid request"),
495                )
496                .unwrap())
497            })
498            .deserializer(|response| {
499                if response.status().is_success() {
500                    std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
501                        .map(|data| SensitiveString::from(data.to_string()))
502                        .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
503                } else {
504                    Err(OrchestratorError::operation(InnerImdsError::BadStatus))
505                }
506            })
507            .build();
508        Client { operation }
509    }
510}
511
/// Environment variables consulted when resolving the IMDS endpoint.
mod env {
    /// Explicit IMDS endpoint URI override.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// IMDS endpoint mode (`IPv4` or `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}

/// `~/.aws/config` profile keys consulted when resolving the IMDS endpoint.
mod profile_keys {
    /// Explicit IMDS endpoint URI override.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// IMDS endpoint mode (`IPv4` or `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
521
522/// Endpoint Configuration Abstraction
523#[derive(Debug, Clone)]
524enum EndpointSource {
525    Explicit(Uri),
526    Env(ProviderConfig),
527}
528
529impl EndpointSource {
530    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
531        match self {
532            EndpointSource::Explicit(uri) => {
533                if mode_override.is_some() {
534                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
535                        "Endpoint mode override was set in combination with an explicit endpoint. \
536                        The mode override will be ignored.")
537                }
538                Ok(uri.clone())
539            }
540            EndpointSource::Env(conf) => {
541                let env = conf.env();
542                // load an endpoint override from the environment
543                let profile = conf.profile().await;
544                let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
545                    Some(Cow::Owned(uri))
546                } else {
547                    profile
548                        .and_then(|profile| profile.get(profile_keys::ENDPOINT))
549                        .map(Cow::Borrowed)
550                };
551                if let Some(uri) = uri_override {
552                    return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
553                }
554
555                // if not, load a endpoint mode from the environment
556                let mode = if let Some(mode) = mode_override {
557                    mode
558                } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
559                    mode.parse::<EndpointMode>()
560                        .map_err(BuildError::invalid_endpoint_mode)?
561                } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
562                {
563                    mode.parse::<EndpointMode>()
564                        .map_err(BuildError::invalid_endpoint_mode)?
565                } else {
566                    EndpointMode::IpV4
567                };
568
569                Ok(mode.endpoint())
570            }
571        }
572    }
573}
574
575#[derive(Clone, Debug)]
576struct ImdsEndpointResolver {
577    endpoint_source: Arc<EndpointSource>,
578    mode_override: Option<EndpointMode>,
579}
580
581impl ResolveEndpoint for ImdsEndpointResolver {
582    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
583        EndpointFuture::new(async move {
584            self.endpoint_source
585                .endpoint(self.mode_override.clone())
586                .await
587                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
588                .map_err(|err| err.into())
589        })
590    }
591}
592
/// IMDS Response Retry Classifier
///
/// Classification by HTTP status code:
/// - 200 (OK)
/// - 400 (Missing or invalid parameters) **Not Retryable**
/// - 401 (Unauthorized, expired token) **Retryable**
/// - 403 (IMDS disabled): **Not Retryable**
/// - 404 (Not found): **Not Retryable**
/// - >=500 (server error): **Retryable**
/// - Attempts that produced no response at all (e.g. timeouts): not retried by default, but this
///   is configurable via [Self::with_retry_connect_timeouts()]
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct ImdsResponseRetryClassifier {
    retry_connect_timeouts: bool,
}

impl ImdsResponseRetryClassifier {
    /// Indicate whether the IMDS client should retry on connection timeouts
    ///
    /// When `true`, every attempt that did not receive an HTTP response is classified as a
    /// retryable server error, not only connect timeouts.
    pub fn with_retry_connect_timeouts(self, retry_connect_timeouts: bool) -> Self {
        Self {
            retry_connect_timeouts,
            ..self
        }
    }
}
616
617impl ClassifyRetry for ImdsResponseRetryClassifier {
618    fn name(&self) -> &'static str {
619        "ImdsResponseRetryClassifier"
620    }
621
622    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
623        if let Some(response) = ctx.response() {
624            let status = response.status();
625            match status {
626                _ if status.is_server_error() => RetryAction::server_error(),
627                // 401 indicates that the token has expired, this is retryable
628                _ if status.as_u16() == 401 => RetryAction::server_error(),
629                // This catch-all includes successful responses that fail to parse. These should not be retried.
630                _ => RetryAction::NoActionIndicated,
631            }
632        } else if self.retry_connect_timeouts {
633            RetryAction::server_error()
634        } else {
635            // This is the default behavior.
636            // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
637            // credentials provider chain to fail to provide credentials.
638            // Also don't retry non-responses.
639            RetryAction::NoActionIndicated
640        }
641    }
642}
643
644#[cfg(test)]
645pub(crate) mod test {
646    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
647    use crate::provider_config::ProviderConfig;
648    use aws_smithy_async::rt::sleep::TokioSleep;
649    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
650    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
651    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
652    use aws_smithy_runtime_api::client::interceptors::context::{
653        Input, InterceptorContext, Output,
654    };
655    use aws_smithy_runtime_api::client::orchestrator::{
656        HttpRequest, HttpResponse, OrchestratorError,
657    };
658    use aws_smithy_runtime_api::client::result::ConnectorError;
659    use aws_smithy_runtime_api::client::retries::classifiers::{
660        ClassifyRetry, RetryAction, SharedRetryClassifier,
661    };
662    use aws_smithy_types::body::SdkBody;
663    use aws_smithy_types::error::display::DisplayErrorContext;
664    use aws_types::os_shim_internal::{Env, Fs};
665    use http::header::USER_AGENT;
666    use http::Uri;
667    use serde::Deserialize;
668    use std::collections::HashMap;
669    use std::error::Error;
670    use std::io;
671    use std::time::SystemTime;
672    use std::time::{Duration, UNIX_EPOCH};
673    use tracing_test::traced_test;
674
675    macro_rules! assert_full_error_contains {
676        ($err:expr, $contains:expr) => {
677            let err = $err;
678            let message = format!(
679                "{}",
680                aws_smithy_types::error::display::DisplayErrorContext(&err)
681            );
682            assert!(
683                message.contains($contains),
684                "Error message '{message}' didn't contain text '{}'",
685                $contains
686            );
687        };
688    }
689
690    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
691    const TOKEN_B: &str = "alternatetoken==";
692
693    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
694        http::Request::builder()
695            .uri(format!("{}/latest/api/token", base))
696            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
697            .method("PUT")
698            .body(SdkBody::empty())
699            .unwrap()
700            .try_into()
701            .unwrap()
702    }
703
704    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
705        HttpResponse::try_from(
706            http::Response::builder()
707                .status(200)
708                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
709                .body(SdkBody::from(token))
710                .unwrap(),
711        )
712        .unwrap()
713    }
714
715    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
716        http::Request::builder()
717            .uri(Uri::from_static(path))
718            .method("GET")
719            .header("x-aws-ec2-metadata-token", token)
720            .body(SdkBody::empty())
721            .unwrap()
722            .try_into()
723            .unwrap()
724    }
725
726    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
727        HttpResponse::try_from(
728            http::Response::builder()
729                .status(200)
730                .body(SdkBody::from(body))
731                .unwrap(),
732        )
733        .unwrap()
734    }
735
736    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
737        tokio::time::pause();
738        super::Client::builder()
739            .configure(
740                &ProviderConfig::no_configuration()
741                    .with_sleep_impl(InstantSleep::unlogged())
742                    .with_http_client(http_client.clone()),
743            )
744            .build()
745    }
746
747    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
748        let http_client = StaticReplayClient::new(events);
749        let client = make_imds_client(&http_client);
750        (client, http_client)
751    }
752
753    #[tokio::test]
754    async fn client_caches_token() {
755        let (client, http_client) = mock_imds_client(vec![
756            ReplayEvent::new(
757                token_request("http://169.254.169.254", 21600),
758                token_response(21600, TOKEN_A),
759            ),
760            ReplayEvent::new(
761                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
762                imds_response(r#"test-imds-output"#),
763            ),
764            ReplayEvent::new(
765                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
766                imds_response("output2"),
767            ),
768        ]);
769        // load once
770        let metadata = client.get("/latest/metadata").await.expect("failed");
771        assert_eq!("test-imds-output", metadata.as_ref());
772        // load again: the cached token should be used
773        let metadata = client.get("/latest/metadata2").await.expect("failed");
774        assert_eq!("output2", metadata.as_ref());
775        http_client.assert_requests_match(&[]);
776    }
777
778    #[tokio::test]
779    async fn token_can_expire() {
780        let (_, http_client) = mock_imds_client(vec![
781            ReplayEvent::new(
782                token_request("http://[fd00:ec2::254]", 600),
783                token_response(600, TOKEN_A),
784            ),
785            ReplayEvent::new(
786                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
787                imds_response(r#"test-imds-output1"#),
788            ),
789            ReplayEvent::new(
790                token_request("http://[fd00:ec2::254]", 600),
791                token_response(600, TOKEN_B),
792            ),
793            ReplayEvent::new(
794                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
795                imds_response(r#"test-imds-output2"#),
796            ),
797        ]);
798        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
799        let client = super::Client::builder()
800            .configure(
801                &ProviderConfig::no_configuration()
802                    .with_http_client(http_client.clone())
803                    .with_time_source(time_source.clone())
804                    .with_sleep_impl(sleep),
805            )
806            .endpoint_mode(EndpointMode::IpV6)
807            .token_ttl(Duration::from_secs(600))
808            .build();
809
810        let resp1 = client.get("/latest/metadata").await.expect("success");
811        // now the cached credential has expired
812        time_source.advance(Duration::from_secs(600));
813        let resp2 = client.get("/latest/metadata").await.expect("success");
814        http_client.assert_requests_match(&[]);
815        assert_eq!("test-imds-output1", resp1.as_ref());
816        assert_eq!("test-imds-output2", resp2.as_ref());
817    }
818
819    /// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
820    #[tokio::test]
821    async fn token_refresh_buffer() {
822        let _logs = capture_test_logs();
823        let (_, http_client) = mock_imds_client(vec![
824            ReplayEvent::new(
825                token_request("http://[fd00:ec2::254]", 600),
826                token_response(600, TOKEN_A),
827            ),
828            // t = 0
829            ReplayEvent::new(
830                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
831                imds_response(r#"test-imds-output1"#),
832            ),
833            // t = 400 (no refresh)
834            ReplayEvent::new(
835                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
836                imds_response(r#"test-imds-output2"#),
837            ),
838            // t = 550 (within buffer)
839            ReplayEvent::new(
840                token_request("http://[fd00:ec2::254]", 600),
841                token_response(600, TOKEN_B),
842            ),
843            ReplayEvent::new(
844                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
845                imds_response(r#"test-imds-output3"#),
846            ),
847        ]);
848        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
849        let client = super::Client::builder()
850            .configure(
851                &ProviderConfig::no_configuration()
852                    .with_sleep_impl(sleep)
853                    .with_http_client(http_client.clone())
854                    .with_time_source(time_source.clone()),
855            )
856            .endpoint_mode(EndpointMode::IpV6)
857            .token_ttl(Duration::from_secs(600))
858            .build();
859
860        tracing::info!("resp1 -----------------------------------------------------------");
861        let resp1 = client.get("/latest/metadata").await.expect("success");
862        // now the cached credential has expired
863        time_source.advance(Duration::from_secs(400));
864        tracing::info!("resp2 -----------------------------------------------------------");
865        let resp2 = client.get("/latest/metadata").await.expect("success");
866        time_source.advance(Duration::from_secs(150));
867        tracing::info!("resp3 -----------------------------------------------------------");
868        let resp3 = client.get("/latest/metadata").await.expect("success");
869        http_client.assert_requests_match(&[]);
870        assert_eq!("test-imds-output1", resp1.as_ref());
871        assert_eq!("test-imds-output2", resp2.as_ref());
872        assert_eq!("test-imds-output3", resp3.as_ref());
873    }
874
875    /// 500 error during the GET should be retried
876    #[tokio::test]
877    #[traced_test]
878    async fn retry_500() {
879        let (client, http_client) = mock_imds_client(vec![
880            ReplayEvent::new(
881                token_request("http://169.254.169.254", 21600),
882                token_response(21600, TOKEN_A),
883            ),
884            ReplayEvent::new(
885                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
886                http::Response::builder()
887                    .status(500)
888                    .body(SdkBody::empty())
889                    .unwrap(),
890            ),
891            ReplayEvent::new(
892                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
893                imds_response("ok"),
894            ),
895        ]);
896        assert_eq!(
897            "ok",
898            client
899                .get("/latest/metadata")
900                .await
901                .expect("success")
902                .as_ref()
903        );
904        http_client.assert_requests_match(&[]);
905
906        // all requests should have a user agent header
907        for request in http_client.actual_requests() {
908            assert!(request.headers().get(USER_AGENT).is_some());
909        }
910    }
911
912    /// 500 error during token acquisition should be retried
913    #[tokio::test]
914    #[traced_test]
915    async fn retry_token_failure() {
916        let (client, http_client) = mock_imds_client(vec![
917            ReplayEvent::new(
918                token_request("http://169.254.169.254", 21600),
919                http::Response::builder()
920                    .status(500)
921                    .body(SdkBody::empty())
922                    .unwrap(),
923            ),
924            ReplayEvent::new(
925                token_request("http://169.254.169.254", 21600),
926                token_response(21600, TOKEN_A),
927            ),
928            ReplayEvent::new(
929                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
930                imds_response("ok"),
931            ),
932        ]);
933        assert_eq!(
934            "ok",
935            client
936                .get("/latest/metadata")
937                .await
938                .expect("success")
939                .as_ref()
940        );
941        http_client.assert_requests_match(&[]);
942    }
943
944    /// 401 error during metadata retrieval must be retried
945    #[tokio::test]
946    #[traced_test]
947    async fn retry_metadata_401() {
948        let (client, http_client) = mock_imds_client(vec![
949            ReplayEvent::new(
950                token_request("http://169.254.169.254", 21600),
951                token_response(0, TOKEN_A),
952            ),
953            ReplayEvent::new(
954                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
955                http::Response::builder()
956                    .status(401)
957                    .body(SdkBody::empty())
958                    .unwrap(),
959            ),
960            ReplayEvent::new(
961                token_request("http://169.254.169.254", 21600),
962                token_response(21600, TOKEN_B),
963            ),
964            ReplayEvent::new(
965                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
966                imds_response("ok"),
967            ),
968        ]);
969        assert_eq!(
970            "ok",
971            client
972                .get("/latest/metadata")
973                .await
974                .expect("success")
975                .as_ref()
976        );
977        http_client.assert_requests_match(&[]);
978    }
979
980    /// 403 responses from IMDS during token acquisition MUST NOT be retried
981    #[tokio::test]
982    #[traced_test]
983    async fn no_403_retry() {
984        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
985            token_request("http://169.254.169.254", 21600),
986            http::Response::builder()
987                .status(403)
988                .body(SdkBody::empty())
989                .unwrap(),
990        )]);
991        let err = client.get("/latest/metadata").await.expect_err("no token");
992        assert_full_error_contains!(err, "forbidden");
993        http_client.assert_requests_match(&[]);
994    }
995
996    /// The classifier should return `None` when classifying a successful response.
997    #[test]
998    fn successful_response_properly_classified() {
999        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1000        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
1001        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
1002        let classifier = ImdsResponseRetryClassifier::default();
1003        assert_eq!(
1004            RetryAction::NoActionIndicated,
1005            classifier.classify_retry(&ctx)
1006        );
1007
1008        // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test)
1009        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1010        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
1011            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
1012        ))));
1013        assert_eq!(
1014            RetryAction::NoActionIndicated,
1015            classifier.classify_retry(&ctx)
1016        );
1017    }
1018
1019    /// User provided retry classifier works
1020    #[tokio::test]
1021    async fn user_provided_retry_classifier() {
1022        #[derive(Clone, Debug)]
1023        struct UserProvidedRetryClassifier;
1024
1025        impl ClassifyRetry for UserProvidedRetryClassifier {
1026            fn name(&self) -> &'static str {
1027                "UserProvidedRetryClassifier"
1028            }
1029
1030            // Don't retry anything
1031            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
1032                RetryAction::RetryForbidden
1033            }
1034        }
1035
1036        let events = vec![
1037            ReplayEvent::new(
1038                token_request("http://169.254.169.254", 21600),
1039                token_response(0, TOKEN_A),
1040            ),
1041            ReplayEvent::new(
1042                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1043                http::Response::builder()
1044                    .status(401)
1045                    .body(SdkBody::empty())
1046                    .unwrap(),
1047            ),
1048            ReplayEvent::new(
1049                token_request("http://169.254.169.254", 21600),
1050                token_response(21600, TOKEN_B),
1051            ),
1052            ReplayEvent::new(
1053                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
1054                imds_response("ok"),
1055            ),
1056        ];
1057        let http_client = StaticReplayClient::new(events);
1058
1059        let imds_client = super::Client::builder()
1060            .configure(
1061                &ProviderConfig::no_configuration()
1062                    .with_sleep_impl(InstantSleep::unlogged())
1063                    .with_http_client(http_client.clone()),
1064            )
1065            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
1066            .build();
1067
1068        let res = imds_client
1069            .get("/latest/metadata")
1070            .await
1071            .expect_err("Client should error");
1072
1073        // Assert that the operation errored on the initial 401 and did not retry and get
1074        // the 200 (since the user provided retry classifier never retries)
1075        assert_full_error_contains!(res, "401");
1076    }
1077
1078    // since tokens are sent as headers, the tokens need to be valid header values
1079    #[tokio::test]
1080    async fn invalid_token() {
1081        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
1082            token_request("http://169.254.169.254", 21600),
1083            token_response(21600, "invalid\nheader\nvalue\0"),
1084        )]);
1085        let err = client.get("/latest/metadata").await.expect_err("no token");
1086        assert_full_error_contains!(err, "invalid token");
1087        http_client.assert_requests_match(&[]);
1088    }
1089
1090    #[tokio::test]
1091    async fn non_utf8_response() {
1092        let (client, http_client) = mock_imds_client(vec![
1093            ReplayEvent::new(
1094                token_request("http://169.254.169.254", 21600),
1095                token_response(21600, TOKEN_A).map(SdkBody::from),
1096            ),
1097            ReplayEvent::new(
1098                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1099                http::Response::builder()
1100                    .status(200)
1101                    .body(SdkBody::from(vec![0xA0, 0xA1]))
1102                    .unwrap(),
1103            ),
1104        ]);
1105        let err = client.get("/latest/metadata").await.expect_err("no token");
1106        assert_full_error_contains!(err, "invalid UTF-8");
1107        http_client.assert_requests_match(&[]);
1108    }
1109
1110    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
1111    #[cfg_attr(windows, ignore)]
1112    /// Verify that the end-to-end real client has a 1-second connect timeout
1113    #[tokio::test]
1114    #[cfg(feature = "default-https-client")]
1115    async fn one_second_connect_timeout() {
1116        use crate::imds::client::ImdsError;
1117        let client = Client::builder()
1118            // 240.* can never be resolved
1119            .endpoint("http://240.0.0.0")
1120            .expect("valid uri")
1121            .build();
1122        let now = SystemTime::now();
1123        let resp = client
1124            .get("/latest/metadata")
1125            .await
1126            .expect_err("240.0.0.0 will never resolve");
1127        match resp {
1128            err @ ImdsError::FailedToLoadToken(_)
1129                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
1130            other => panic!(
1131                "wrong error, expected construction failure with TimedOutError inside: {}",
1132                DisplayErrorContext(&other)
1133            ),
1134        }
1135        let time_elapsed = now.elapsed().unwrap();
1136        assert!(
1137            time_elapsed > Duration::from_secs(1),
1138            "time_elapsed should be greater than 1s but was {:?}",
1139            time_elapsed
1140        );
1141        assert!(
1142            time_elapsed < Duration::from_secs(2),
1143            "time_elapsed should be less than 2s but was {:?}",
1144            time_elapsed
1145        );
1146    }
1147
1148    /// Retry classifier properly retries timeouts when configured to (meaning it takes ~30s to fail)
1149    #[tokio::test]
1150    async fn retry_connect_timeouts() {
1151        let http_client = StaticReplayClient::new(vec![]);
1152        let imds_client = super::Client::builder()
1153            .retry_classifier(SharedRetryClassifier::new(
1154                ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
1155            ))
1156            .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
1157            .operation_timeout(Duration::from_secs(1))
1158            .endpoint("http://240.0.0.0")
1159            .expect("valid uri")
1160            .build();
1161
1162        let now = SystemTime::now();
1163        let _res = imds_client
1164            .get("/latest/metadata")
1165            .await
1166            .expect_err("240.0.0.0 will never resolve");
1167        let time_elapsed: Duration = now.elapsed().unwrap();
1168
1169        assert!(
1170            time_elapsed > Duration::from_secs(1),
1171            "time_elapsed should be greater than 1s but was {:?}",
1172            time_elapsed
1173        );
1174
1175        assert!(
1176            time_elapsed < Duration::from_secs(2),
1177            "time_elapsed should be less than 2s but was {:?}",
1178            time_elapsed
1179        );
1180    }
1181
    /// One case from `test-data/imds-config/imds-endpoint-tests.json`.
    ///
    /// A case describes an environment and builder overrides, plus the request URI or error
    /// message they should produce.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        /// Environment variables, passed to `Env::from` when building the provider config.
        env: HashMap<String, String>,
        /// Filesystem contents, passed to `Fs::from_map` (presumably path -> file contents).
        fs: HashMap<String, String>,
        /// Endpoint set explicitly on the client builder, if any.
        endpoint_override: Option<String>,
        /// Endpoint mode set explicitly on the client builder (parsed with `str::parse`), if any.
        mode_override: Option<String>,
        /// `Ok(uri)`: the URI the captured request is expected to have.
        /// `Err(msg)`: a substring expected in the displayed error.
        /// Uses serde's default `Result` encoding (`{"Ok": ...}` / `{"Err": ...}`).
        result: Result<String, String>,
        /// Human-readable description of the case, included in assertion failure messages.
        docs: String,
    }
1191
1192    #[tokio::test]
1193    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1194        let _logs = capture_test_logs();
1195
1196        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1197        #[derive(Deserialize)]
1198        struct TestCases {
1199            tests: Vec<ImdsConfigTest>,
1200        }
1201
1202        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1203        let test_cases = test_cases.tests;
1204        for test in test_cases {
1205            check(test).await;
1206        }
1207        Ok(())
1208    }
1209
1210    async fn check(test_case: ImdsConfigTest) {
1211        let (http_client, watcher) = capture_request(None);
1212        let provider_config = ProviderConfig::no_configuration()
1213            .with_sleep_impl(TokioSleep::new())
1214            .with_env(Env::from(test_case.env))
1215            .with_fs(Fs::from_map(test_case.fs))
1216            .with_http_client(http_client);
1217        let mut imds_client = Client::builder().configure(&provider_config);
1218        if let Some(endpoint_override) = test_case.endpoint_override {
1219            imds_client = imds_client
1220                .endpoint(endpoint_override)
1221                .expect("invalid URI");
1222        }
1223
1224        if let Some(mode_override) = test_case.mode_override {
1225            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1226        }
1227
1228        let imds_client = imds_client.build();
1229        match &test_case.result {
1230            Ok(uri) => {
1231                // this request will fail, we just want to capture the endpoint configuration
1232                let _ = imds_client.get("/hello").await;
1233                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1234            }
1235            Err(expected) => {
1236                let err = imds_client.get("/hello").await.expect_err("it should fail");
1237                let message = format!("{}", DisplayErrorContext(&err));
1238                assert!(
1239                    message.contains(expected),
1240                    "{}\nexpected error: {expected}\nactual error: {message}",
1241                    test_case.docs
1242                );
1243            }
1244        };
1245    }
1246}