aws_config/imds/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Raw IMDSv2 Client
7//!
8//! Client for direct access to IMDSv2.
9
10use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::orchestrator::operation::Operation;
16use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
19use aws_smithy_runtime_api::client::endpoint::{
20    EndpointFuture, EndpointResolverParams, ResolveEndpoint,
21};
22use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
23use aws_smithy_runtime_api::client::orchestrator::{
24    HttpRequest, OrchestratorError, SensitiveOutput,
25};
26use aws_smithy_runtime_api::client::result::ConnectorError;
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::{
29    ClassifyRetry, RetryAction, SharedRetryClassifier,
30};
31use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
32use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
33use aws_smithy_types::body::SdkBody;
34use aws_smithy_types::config_bag::{FrozenLayer, Layer};
35use aws_smithy_types::endpoint::Endpoint;
36use aws_smithy_types::retry::RetryConfig;
37use aws_smithy_types::timeout::TimeoutConfig;
38use aws_types::os_shim_internal::Env;
39use http::Uri;
40use std::borrow::Cow;
41use std::error::Error as _;
42use std::fmt;
43use std::str::FromStr;
44use std::sync::Arc;
45use std::time::Duration;
46
47pub mod error;
48mod token;
49
/// Default lifetime requested for IMDS session tokens (6 hours).
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default maximum number of attempts for token and metadata requests.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default connect timeout for IMDS requests.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Default read timeout for IMDS requests.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Default timeout for a whole operation, across all of its attempts.
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
/// Default timeout for a single attempt of an operation.
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);
57
58fn user_agent() -> AwsUserAgent {
59    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
60}
61
62/// IMDSv2 Client
63///
64/// Client for IMDSv2. This client handles fetching tokens, retrying on failure, and token
65/// caching according to the specified token TTL.
66///
67/// _Note: This client ONLY supports IMDSv2. It will not fallback to IMDSv1. See
68/// [transitioning to IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html#instance-metadata-transition-to-version-2)
69/// for more information._
70///
71/// **Note**: When running in a Docker container, all network requests will incur an additional hop. When combined with the default IMDS hop limit of 1, this will cause requests to IMDS to timeout! To fix this issue, you'll need to set the following instance metadata settings :
72/// ```txt
73/// amazonec2-metadata-token=required
74/// amazonec2-metadata-token-response-hop-limit=2
75/// ```
76///
77/// On an instance that is already running, these can be set with [ModifyInstanceMetadataOptions](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_ModifyInstanceMetadataOptions.html). On a new instance, these can be set with the `MetadataOptions` field on [RunInstances](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_RunInstances.html).
78///
79/// For more information about IMDSv2 vs. IMDSv1 see [this guide](https://docs.aws.amazon.com/AWSEC2/latest/WindowsGuide/configuring-instance-metadata-service.html)
80///
81/// # Client Configuration
82/// The IMDS client can load configuration explicitly, via environment variables, or via
83/// `~/.aws/config`. It will first attempt to resolve an endpoint override. If no endpoint
84/// override exists, it will attempt to resolve an [`EndpointMode`]. If no
85/// [`EndpointMode`] override exists, it will fallback to [`IpV4`](EndpointMode::IpV4). An exhaustive
86/// list is below:
87///
88/// ## Endpoint configuration list
89/// 1. Explicit configuration of `Endpoint` via the [builder](Builder):
90/// ```no_run
91/// use aws_config::imds::client::Client;
92/// # async fn docs() {
93/// let client = Client::builder()
94///   .endpoint("http://customimds:456/").expect("valid URI")
95///   .build();
96/// # }
97/// ```
98///
99/// 2. The `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable. Note: If this environment variable
100///    is set, it MUST contain a valid URI or client construction will fail.
101///
102/// 3. The `ec2_metadata_service_endpoint` field in `~/.aws/config`:
103/// ```ini
104/// [default]
105/// # ... other configuration
106/// ec2_metadata_service_endpoint = http://my-custom-endpoint:444
107/// ```
108///
109/// 4. An explicitly set endpoint mode:
110/// ```no_run
111/// use aws_config::imds::client::{Client, EndpointMode};
112/// # async fn docs() {
113/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build();
114/// # }
115/// ```
116///
117/// 5. An [endpoint mode](EndpointMode) loaded from the `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE` environment
118///    variable. Valid values: `IPv4`, `IPv6`
119///
120/// 6. An [endpoint mode](EndpointMode) loaded from the `ec2_metadata_service_endpoint_mode` field in
121///    `~/.aws/config`:
122/// ```ini
123/// [default]
124/// # ... other configuration
125/// ec2_metadata_service_endpoint_mode = IPv4
126/// ```
127///
128/// 7. The default value of `http://169.254.169.254` will be used.
129///
130#[derive(Clone, Debug)]
131pub struct Client {
132    operation: Operation<String, SensitiveString, InnerImdsError>,
133}
134
135impl Client {
136    /// IMDS client builder
137    pub fn builder() -> Builder {
138        Builder::default()
139    }
140
141    /// Retrieve information from IMDS
142    ///
143    /// This method will handle loading and caching a session token, combining the `path` with the
144    /// configured IMDS endpoint, and retrying potential errors.
145    ///
146    /// For more information about IMDSv2 methods and functionality, see
147    /// [Instance metadata and user data](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html)
148    ///
149    /// # Examples
150    ///
151    /// ```no_run
152    /// use aws_config::imds::client::Client;
153    /// # async fn docs() {
154    /// let client = Client::builder().build();
155    /// let ami_id = client
156    ///   .get("/latest/meta-data/ami-id")
157    ///   .await
158    ///   .expect("failure communicating with IMDS");
159    /// # }
160    /// ```
161    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
162        self.operation
163            .invoke(path.into())
164            .await
165            .map_err(|err| match err {
166                SdkError::ConstructionFailure(_) if err.source().is_some() => {
167                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
168                        Ok(Ok(token_failure)) => *token_failure,
169                        Ok(Err(err)) => ImdsError::unexpected(err),
170                        Err(err) => ImdsError::unexpected(err),
171                    }
172                }
173                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
174                SdkError::ServiceError(context) => match context.err() {
175                    InnerImdsError::InvalidUtf8 => {
176                        ImdsError::unexpected("IMDS returned invalid UTF-8")
177                    }
178                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
179                },
180                // If the error source is an ImdsError, then we need to directly return that source.
181                // That way, the IMDS token provider's errors can become the top-level ImdsError.
182                // There is a unit test that checks the correct error is being extracted.
183                err @ SdkError::DispatchFailure(_) => match err.into_source() {
184                    Ok(source) => match source.downcast::<ConnectorError>() {
185                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
186                            Ok(source) => *source,
187                            Err(err) => ImdsError::unexpected(err),
188                        },
189                        Err(err) => ImdsError::unexpected(err),
190                    },
191                    Err(err) => ImdsError::unexpected(err),
192                },
193                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
194                _ => ImdsError::unexpected(err),
195            })
196    }
197}
198
/// New-type around `String` whose `Debug` output never reveals the wrapped value.
///
/// Read the value through `AsRef<str>`, or convert it back into a `String` with `.into()`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print a fixed placeholder in place of the contents so values never end up in logs.
        let mut tuple = f.debug_tuple("SensitiveString");
        tuple.field(&"** redacted **");
        tuple.finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
228
229/// Runtime plugin that is used by both the IMDS client and the inner client that resolves
230/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as
231/// sensitive, configures user agent headers, and sets up retries and timeouts.
232#[derive(Debug)]
233struct ImdsCommonRuntimePlugin {
234    config: FrozenLayer,
235    components: RuntimeComponentsBuilder,
236}
237
238impl ImdsCommonRuntimePlugin {
239    fn new(
240        config: &ProviderConfig,
241        endpoint_resolver: ImdsEndpointResolver,
242        retry_config: RetryConfig,
243        retry_classifier: SharedRetryClassifier,
244        timeout_config: TimeoutConfig,
245    ) -> Self {
246        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
247        layer.store_put(AuthSchemeOptionResolverParams::new(()));
248        layer.store_put(EndpointResolverParams::new(()));
249        layer.store_put(SensitiveOutput);
250        layer.store_put(retry_config);
251        layer.store_put(timeout_config);
252        layer.store_put(user_agent());
253
254        Self {
255            config: layer.freeze(),
256            components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
257                .with_http_client(config.http_client())
258                .with_endpoint_resolver(Some(endpoint_resolver))
259                .with_interceptor(UserAgentInterceptor::new())
260                .with_retry_classifier(retry_classifier)
261                .with_retry_strategy(Some(StandardRetryStrategy::new()))
262                .with_time_source(Some(config.time_source()))
263                .with_sleep_impl(config.sleep_impl()),
264        }
265    }
266}
267
268impl RuntimePlugin for ImdsCommonRuntimePlugin {
269    fn config(&self) -> Option<FrozenLayer> {
270        Some(self.config.clone())
271    }
272
273    fn runtime_components(
274        &self,
275        _current_components: &RuntimeComponentsBuilder,
276    ) -> Cow<'_, RuntimeComponentsBuilder> {
277        Cow::Borrowed(&self.components)
278    }
279}
280
281/// IMDSv2 Endpoint Mode
282///
283/// IMDS can be accessed in two ways:
284/// 1. Via the IpV4 endpoint: `http://169.254.169.254`
285/// 2. Via the Ipv6 endpoint: `http://[fd00:ec2::254]`
286#[derive(Debug, Clone)]
287#[non_exhaustive]
288pub enum EndpointMode {
289    /// IpV4 mode: `http://169.254.169.254`
290    ///
291    /// This mode is the default unless otherwise specified.
292    IpV4,
293    /// IpV6 mode: `http://[fd00:ec2::254]`
294    IpV6,
295}
296
297impl FromStr for EndpointMode {
298    type Err = InvalidEndpointMode;
299
300    fn from_str(value: &str) -> Result<Self, Self::Err> {
301        match value {
302            _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
303            _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
304            other => Err(InvalidEndpointMode::new(other.to_owned())),
305        }
306    }
307}
308
309impl EndpointMode {
310    /// IMDS URI for this endpoint mode
311    fn endpoint(&self) -> Uri {
312        match self {
313            EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
314            EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
315        }
316    }
317}
318
319/// IMDSv2 Client Builder
320#[derive(Default, Debug, Clone)]
321pub struct Builder {
322    max_attempts: Option<u32>,
323    endpoint: Option<EndpointSource>,
324    mode_override: Option<EndpointMode>,
325    token_ttl: Option<Duration>,
326    connect_timeout: Option<Duration>,
327    read_timeout: Option<Duration>,
328    operation_timeout: Option<Duration>,
329    operation_attempt_timeout: Option<Duration>,
330    config: Option<ProviderConfig>,
331    retry_classifier: Option<SharedRetryClassifier>,
332}
333
334impl Builder {
335    /// Override the number of retries for fetching tokens & metadata
336    ///
337    /// By default, 4 attempts will be made.
338    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
339        self.max_attempts = Some(max_attempts);
340        self
341    }
342
343    /// Configure generic options of the [`Client`]
344    ///
345    /// # Examples
346    /// ```no_run
347    /// # async fn test() {
348    /// use aws_config::imds::Client;
349    /// use aws_config::provider_config::ProviderConfig;
350    ///
351    /// let provider = Client::builder()
352    ///     .configure(&ProviderConfig::with_default_region().await)
353    ///     .build();
354    /// # }
355    /// ```
356    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
357        self.config = Some(provider_config.clone());
358        self
359    }
360
361    /// Override the endpoint for the [`Client`]
362    ///
363    /// By default, the client will resolve an endpoint from the environment, AWS config, and endpoint mode.
364    ///
365    /// See [`Client`] for more information.
366    pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
367        let uri: Uri = endpoint.as_ref().parse()?;
368        self.endpoint = Some(EndpointSource::Explicit(uri));
369        Ok(self)
370    }
371
372    /// Override the endpoint mode for [`Client`]
373    ///
374    /// * When set to [`IpV4`](EndpointMode::IpV4), the endpoint will be `http://169.254.169.254`.
375    /// * When set to [`IpV6`](EndpointMode::IpV6), the endpoint will be `http://[fd00:ec2::254]`.
376    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
377        self.mode_override = Some(mode);
378        self
379    }
380
381    /// Override the time-to-live for the session token
382    ///
383    /// Requests to IMDS utilize a session token for authentication. By default, session tokens last
384    /// for 6 hours. When the TTL for the token expires, a new token must be retrieved from the
385    /// metadata service.
386    pub fn token_ttl(mut self, ttl: Duration) -> Self {
387        self.token_ttl = Some(ttl);
388        self
389    }
390
391    /// Override the connect timeout for IMDS
392    ///
393    /// This value defaults to 1 second
394    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
395        self.connect_timeout = Some(timeout);
396        self
397    }
398
399    /// Override the read timeout for IMDS
400    ///
401    /// This value defaults to 1 second
402    pub fn read_timeout(mut self, timeout: Duration) -> Self {
403        self.read_timeout = Some(timeout);
404        self
405    }
406
407    /// Override the operation timeout for IMDS
408    ///
409    /// This value defaults to 1 second
410    pub fn operation_timeout(mut self, timeout: Duration) -> Self {
411        self.operation_timeout = Some(timeout);
412        self
413    }
414
415    /// Override the operation attempt timeout for IMDS
416    ///
417    /// This value defaults to 1 second
418    pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
419        self.operation_attempt_timeout = Some(timeout);
420        self
421    }
422
423    /// Override the retry classifier for IMDS
424    ///
425    /// This defaults to only retrying on server errors and 401s. The [ImdsResponseRetryClassifier] in this
426    /// module offers some configuration options and can be wrapped by[SharedRetryClassifier::new()] for use
427    /// here or you can create your own fully customized [SharedRetryClassifier].
428    pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
429        self.retry_classifier = Some(retry_classifier);
430        self
431    }
432
433    /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
434    /*
435    pub fn port(mut self, port: u32) -> Self {
436        self.port_override = Some(port);
437        self
438    }*/
439
440    /// Build an IMDSv2 Client
441    pub fn build(self) -> Client {
442        let config = self.config.unwrap_or_default();
443        let timeout_config = TimeoutConfig::builder()
444            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
445            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
446            .operation_attempt_timeout(
447                self.operation_attempt_timeout
448                    .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
449            )
450            .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
451            .build();
452        let endpoint_source = self
453            .endpoint
454            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
455        let endpoint_resolver = ImdsEndpointResolver {
456            endpoint_source: Arc::new(endpoint_source),
457            mode_override: self.mode_override,
458        };
459        let retry_config = RetryConfig::standard()
460            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
461        let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new(
462            ImdsResponseRetryClassifier::default(),
463        ));
464        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
465            &config,
466            endpoint_resolver,
467            retry_config,
468            retry_classifier,
469            timeout_config,
470        ));
471        let operation = Operation::builder()
472            .service_name("imds")
473            .operation_name("get")
474            .runtime_plugin(common_plugin.clone())
475            .runtime_plugin(TokenRuntimePlugin::new(
476                common_plugin,
477                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
478            ))
479            .with_connection_poisoning()
480            .serializer(|path| {
481                Ok(HttpRequest::try_from(
482                    http::Request::builder()
483                        .uri(path)
484                        .body(SdkBody::empty())
485                        .expect("valid request"),
486                )
487                .unwrap())
488            })
489            .deserializer(|response| {
490                if response.status().is_success() {
491                    std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
492                        .map(|data| SensitiveString::from(data.to_string()))
493                        .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
494                } else {
495                    Err(OrchestratorError::operation(InnerImdsError::BadStatus))
496                }
497            })
498            .build();
499        Client { operation }
500    }
501}
502
/// Keys read from the active `~/.aws/config` profile when resolving the IMDS endpoint.
mod profile_keys {
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// IMDS endpoint mode (`IPv4` / `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}

/// Environment variables read when resolving the IMDS endpoint.
mod env {
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// IMDS endpoint mode (`IPv4` / `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
512
513/// Endpoint Configuration Abstraction
514#[derive(Debug, Clone)]
515enum EndpointSource {
516    Explicit(Uri),
517    Env(ProviderConfig),
518}
519
520impl EndpointSource {
521    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
522        match self {
523            EndpointSource::Explicit(uri) => {
524                if mode_override.is_some() {
525                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
526                        "Endpoint mode override was set in combination with an explicit endpoint. \
527                        The mode override will be ignored.")
528                }
529                Ok(uri.clone())
530            }
531            EndpointSource::Env(conf) => {
532                let env = conf.env();
533                // load an endpoint override from the environment
534                let profile = conf.profile().await;
535                let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
536                    Some(Cow::Owned(uri))
537                } else {
538                    profile
539                        .and_then(|profile| profile.get(profile_keys::ENDPOINT))
540                        .map(Cow::Borrowed)
541                };
542                if let Some(uri) = uri_override {
543                    return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
544                }
545
546                // if not, load a endpoint mode from the environment
547                let mode = if let Some(mode) = mode_override {
548                    mode
549                } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
550                    mode.parse::<EndpointMode>()
551                        .map_err(BuildError::invalid_endpoint_mode)?
552                } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
553                {
554                    mode.parse::<EndpointMode>()
555                        .map_err(BuildError::invalid_endpoint_mode)?
556                } else {
557                    EndpointMode::IpV4
558                };
559
560                Ok(mode.endpoint())
561            }
562        }
563    }
564}
565
566#[derive(Clone, Debug)]
567struct ImdsEndpointResolver {
568    endpoint_source: Arc<EndpointSource>,
569    mode_override: Option<EndpointMode>,
570}
571
572impl ResolveEndpoint for ImdsEndpointResolver {
573    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
574        EndpointFuture::new(async move {
575            self.endpoint_source
576                .endpoint(self.mode_override.clone())
577                .await
578                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
579                .map_err(|err| err.into())
580        })
581    }
582}
583
584/// IMDS Response Retry Classifier
585///
586/// Possible status codes:
587/// - 200 (OK)
588/// - 400 (Missing or invalid parameters) **Not Retryable**
589/// - 401 (Unauthorized, expired token) **Retryable**
590/// - 403 (IMDS disabled): **Not Retryable**
591/// - 404 (Not found): **Not Retryable**
592/// - >=500 (server error): **Retryable**
593/// - Timeouts: Not retried by default, but this is configurable via [Self::with_retry_connect_timeouts()]
594#[derive(Clone, Debug, Default)]
595#[non_exhaustive]
596pub struct ImdsResponseRetryClassifier {
597    retry_connect_timeouts: bool,
598}
599
600impl ImdsResponseRetryClassifier {
601    /// Indicate whether the IMDS client should retry on connection timeouts
602    pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self {
603        self.retry_connect_timeouts = retry_connect_timeouts;
604        self
605    }
606}
607
608impl ClassifyRetry for ImdsResponseRetryClassifier {
609    fn name(&self) -> &'static str {
610        "ImdsResponseRetryClassifier"
611    }
612
613    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
614        if let Some(response) = ctx.response() {
615            let status = response.status();
616            match status {
617                _ if status.is_server_error() => RetryAction::server_error(),
618                // 401 indicates that the token has expired, this is retryable
619                _ if status.as_u16() == 401 => RetryAction::server_error(),
620                // This catch-all includes successful responses that fail to parse. These should not be retried.
621                _ => RetryAction::NoActionIndicated,
622            }
623        } else if self.retry_connect_timeouts {
624            RetryAction::server_error()
625        } else {
626            // This is the default behavior.
627            // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
628            // credentials provider chain to fail to provide credentials.
629            // Also don't retry non-responses.
630            RetryAction::NoActionIndicated
631        }
632    }
633}
634
635#[cfg(test)]
636pub(crate) mod test {
637    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
638    use crate::provider_config::ProviderConfig;
639    use aws_smithy_async::rt::sleep::TokioSleep;
640    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
641    use aws_smithy_runtime::client::http::test_util::{
642        capture_request, ReplayEvent, StaticReplayClient,
643    };
644    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
645    use aws_smithy_runtime_api::client::interceptors::context::{
646        Input, InterceptorContext, Output,
647    };
648    use aws_smithy_runtime_api::client::orchestrator::{
649        HttpRequest, HttpResponse, OrchestratorError,
650    };
651    use aws_smithy_runtime_api::client::result::ConnectorError;
652    use aws_smithy_runtime_api::client::retries::classifiers::{
653        ClassifyRetry, RetryAction, SharedRetryClassifier,
654    };
655    use aws_smithy_types::body::SdkBody;
656    use aws_smithy_types::error::display::DisplayErrorContext;
657    use aws_types::os_shim_internal::{Env, Fs};
658    use http::header::USER_AGENT;
659    use http::Uri;
660    use serde::Deserialize;
661    use std::collections::HashMap;
662    use std::error::Error;
663    use std::io;
664    use std::time::SystemTime;
665    use std::time::{Duration, UNIX_EPOCH};
666    use tracing_test::traced_test;
667
668    macro_rules! assert_full_error_contains {
669        ($err:expr, $contains:expr) => {
670            let err = $err;
671            let message = format!(
672                "{}",
673                aws_smithy_types::error::display::DisplayErrorContext(&err)
674            );
675            assert!(
676                message.contains($contains),
677                "Error message '{message}' didn't contain text '{}'",
678                $contains
679            );
680        };
681    }
682
683    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
684    const TOKEN_B: &str = "alternatetoken==";
685
686    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
687        http::Request::builder()
688            .uri(format!("{}/latest/api/token", base))
689            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
690            .method("PUT")
691            .body(SdkBody::empty())
692            .unwrap()
693            .try_into()
694            .unwrap()
695    }
696
697    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
698        HttpResponse::try_from(
699            http::Response::builder()
700                .status(200)
701                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
702                .body(SdkBody::from(token))
703                .unwrap(),
704        )
705        .unwrap()
706    }
707
708    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
709        http::Request::builder()
710            .uri(Uri::from_static(path))
711            .method("GET")
712            .header("x-aws-ec2-metadata-token", token)
713            .body(SdkBody::empty())
714            .unwrap()
715            .try_into()
716            .unwrap()
717    }
718
719    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
720        HttpResponse::try_from(
721            http::Response::builder()
722                .status(200)
723                .body(SdkBody::from(body))
724                .unwrap(),
725        )
726        .unwrap()
727    }
728
729    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
730        tokio::time::pause();
731        super::Client::builder()
732            .configure(
733                &ProviderConfig::no_configuration()
734                    .with_sleep_impl(InstantSleep::unlogged())
735                    .with_http_client(http_client.clone()),
736            )
737            .build()
738    }
739
740    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
741        let http_client = StaticReplayClient::new(events);
742        let client = make_imds_client(&http_client);
743        (client, http_client)
744    }
745
746    #[tokio::test]
747    async fn client_caches_token() {
748        let (client, http_client) = mock_imds_client(vec![
749            ReplayEvent::new(
750                token_request("http://169.254.169.254", 21600),
751                token_response(21600, TOKEN_A),
752            ),
753            ReplayEvent::new(
754                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
755                imds_response(r#"test-imds-output"#),
756            ),
757            ReplayEvent::new(
758                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
759                imds_response("output2"),
760            ),
761        ]);
762        // load once
763        let metadata = client.get("/latest/metadata").await.expect("failed");
764        assert_eq!("test-imds-output", metadata.as_ref());
765        // load again: the cached token should be used
766        let metadata = client.get("/latest/metadata2").await.expect("failed");
767        assert_eq!("output2", metadata.as_ref());
768        http_client.assert_requests_match(&[]);
769    }
770
771    #[tokio::test]
772    async fn token_can_expire() {
773        let (_, http_client) = mock_imds_client(vec![
774            ReplayEvent::new(
775                token_request("http://[fd00:ec2::254]", 600),
776                token_response(600, TOKEN_A),
777            ),
778            ReplayEvent::new(
779                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
780                imds_response(r#"test-imds-output1"#),
781            ),
782            ReplayEvent::new(
783                token_request("http://[fd00:ec2::254]", 600),
784                token_response(600, TOKEN_B),
785            ),
786            ReplayEvent::new(
787                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
788                imds_response(r#"test-imds-output2"#),
789            ),
790        ]);
791        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
792        let client = super::Client::builder()
793            .configure(
794                &ProviderConfig::no_configuration()
795                    .with_http_client(http_client.clone())
796                    .with_time_source(time_source.clone())
797                    .with_sleep_impl(sleep),
798            )
799            .endpoint_mode(EndpointMode::IpV6)
800            .token_ttl(Duration::from_secs(600))
801            .build();
802
803        let resp1 = client.get("/latest/metadata").await.expect("success");
804        // now the cached credential has expired
805        time_source.advance(Duration::from_secs(600));
806        let resp2 = client.get("/latest/metadata").await.expect("success");
807        http_client.assert_requests_match(&[]);
808        assert_eq!("test-imds-output1", resp1.as_ref());
809        assert_eq!("test-imds-output2", resp2.as_ref());
810    }
811
812    /// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
813    #[tokio::test]
814    async fn token_refresh_buffer() {
815        let _logs = capture_test_logs();
816        let (_, http_client) = mock_imds_client(vec![
817            ReplayEvent::new(
818                token_request("http://[fd00:ec2::254]", 600),
819                token_response(600, TOKEN_A),
820            ),
821            // t = 0
822            ReplayEvent::new(
823                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
824                imds_response(r#"test-imds-output1"#),
825            ),
826            // t = 400 (no refresh)
827            ReplayEvent::new(
828                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
829                imds_response(r#"test-imds-output2"#),
830            ),
831            // t = 550 (within buffer)
832            ReplayEvent::new(
833                token_request("http://[fd00:ec2::254]", 600),
834                token_response(600, TOKEN_B),
835            ),
836            ReplayEvent::new(
837                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
838                imds_response(r#"test-imds-output3"#),
839            ),
840        ]);
841        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
842        let client = super::Client::builder()
843            .configure(
844                &ProviderConfig::no_configuration()
845                    .with_sleep_impl(sleep)
846                    .with_http_client(http_client.clone())
847                    .with_time_source(time_source.clone()),
848            )
849            .endpoint_mode(EndpointMode::IpV6)
850            .token_ttl(Duration::from_secs(600))
851            .build();
852
853        tracing::info!("resp1 -----------------------------------------------------------");
854        let resp1 = client.get("/latest/metadata").await.expect("success");
855        // now the cached credential has expired
856        time_source.advance(Duration::from_secs(400));
857        tracing::info!("resp2 -----------------------------------------------------------");
858        let resp2 = client.get("/latest/metadata").await.expect("success");
859        time_source.advance(Duration::from_secs(150));
860        tracing::info!("resp3 -----------------------------------------------------------");
861        let resp3 = client.get("/latest/metadata").await.expect("success");
862        http_client.assert_requests_match(&[]);
863        assert_eq!("test-imds-output1", resp1.as_ref());
864        assert_eq!("test-imds-output2", resp2.as_ref());
865        assert_eq!("test-imds-output3", resp3.as_ref());
866    }
867
868    /// 500 error during the GET should be retried
869    #[tokio::test]
870    #[traced_test]
871    async fn retry_500() {
872        let (client, http_client) = mock_imds_client(vec![
873            ReplayEvent::new(
874                token_request("http://169.254.169.254", 21600),
875                token_response(21600, TOKEN_A),
876            ),
877            ReplayEvent::new(
878                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
879                http::Response::builder()
880                    .status(500)
881                    .body(SdkBody::empty())
882                    .unwrap(),
883            ),
884            ReplayEvent::new(
885                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
886                imds_response("ok"),
887            ),
888        ]);
889        assert_eq!(
890            "ok",
891            client
892                .get("/latest/metadata")
893                .await
894                .expect("success")
895                .as_ref()
896        );
897        http_client.assert_requests_match(&[]);
898
899        // all requests should have a user agent header
900        for request in http_client.actual_requests() {
901            assert!(request.headers().get(USER_AGENT).is_some());
902        }
903    }
904
905    /// 500 error during token acquisition should be retried
906    #[tokio::test]
907    #[traced_test]
908    async fn retry_token_failure() {
909        let (client, http_client) = mock_imds_client(vec![
910            ReplayEvent::new(
911                token_request("http://169.254.169.254", 21600),
912                http::Response::builder()
913                    .status(500)
914                    .body(SdkBody::empty())
915                    .unwrap(),
916            ),
917            ReplayEvent::new(
918                token_request("http://169.254.169.254", 21600),
919                token_response(21600, TOKEN_A),
920            ),
921            ReplayEvent::new(
922                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
923                imds_response("ok"),
924            ),
925        ]);
926        assert_eq!(
927            "ok",
928            client
929                .get("/latest/metadata")
930                .await
931                .expect("success")
932                .as_ref()
933        );
934        http_client.assert_requests_match(&[]);
935    }
936
937    /// 401 error during metadata retrieval must be retried
938    #[tokio::test]
939    #[traced_test]
940    async fn retry_metadata_401() {
941        let (client, http_client) = mock_imds_client(vec![
942            ReplayEvent::new(
943                token_request("http://169.254.169.254", 21600),
944                token_response(0, TOKEN_A),
945            ),
946            ReplayEvent::new(
947                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
948                http::Response::builder()
949                    .status(401)
950                    .body(SdkBody::empty())
951                    .unwrap(),
952            ),
953            ReplayEvent::new(
954                token_request("http://169.254.169.254", 21600),
955                token_response(21600, TOKEN_B),
956            ),
957            ReplayEvent::new(
958                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
959                imds_response("ok"),
960            ),
961        ]);
962        assert_eq!(
963            "ok",
964            client
965                .get("/latest/metadata")
966                .await
967                .expect("success")
968                .as_ref()
969        );
970        http_client.assert_requests_match(&[]);
971    }
972
973    /// 403 responses from IMDS during token acquisition MUST NOT be retried
974    #[tokio::test]
975    #[traced_test]
976    async fn no_403_retry() {
977        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
978            token_request("http://169.254.169.254", 21600),
979            http::Response::builder()
980                .status(403)
981                .body(SdkBody::empty())
982                .unwrap(),
983        )]);
984        let err = client.get("/latest/metadata").await.expect_err("no token");
985        assert_full_error_contains!(err, "forbidden");
986        http_client.assert_requests_match(&[]);
987    }
988
989    /// The classifier should return `None` when classifying a successful response.
990    #[test]
991    fn successful_response_properly_classified() {
992        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
993        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
994        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
995        let classifier = ImdsResponseRetryClassifier::default();
996        assert_eq!(
997            RetryAction::NoActionIndicated,
998            classifier.classify_retry(&ctx)
999        );
1000
1001        // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test)
1002        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
1003        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
1004            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
1005        ))));
1006        assert_eq!(
1007            RetryAction::NoActionIndicated,
1008            classifier.classify_retry(&ctx)
1009        );
1010    }
1011
1012    /// User provided retry classifier works
1013    #[tokio::test]
1014    async fn user_provided_retry_classifier() {
1015        #[derive(Clone, Debug)]
1016        struct UserProvidedRetryClassifier;
1017
1018        impl ClassifyRetry for UserProvidedRetryClassifier {
1019            fn name(&self) -> &'static str {
1020                "UserProvidedRetryClassifier"
1021            }
1022
1023            // Don't retry anything
1024            fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
1025                RetryAction::RetryForbidden
1026            }
1027        }
1028
1029        let events = vec![
1030            ReplayEvent::new(
1031                token_request("http://169.254.169.254", 21600),
1032                token_response(0, TOKEN_A),
1033            ),
1034            ReplayEvent::new(
1035                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1036                http::Response::builder()
1037                    .status(401)
1038                    .body(SdkBody::empty())
1039                    .unwrap(),
1040            ),
1041            ReplayEvent::new(
1042                token_request("http://169.254.169.254", 21600),
1043                token_response(21600, TOKEN_B),
1044            ),
1045            ReplayEvent::new(
1046                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
1047                imds_response("ok"),
1048            ),
1049        ];
1050        let http_client = StaticReplayClient::new(events);
1051
1052        let imds_client = super::Client::builder()
1053            .configure(
1054                &ProviderConfig::no_configuration()
1055                    .with_sleep_impl(InstantSleep::unlogged())
1056                    .with_http_client(http_client.clone()),
1057            )
1058            .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
1059            .build();
1060
1061        let res = imds_client
1062            .get("/latest/metadata")
1063            .await
1064            .expect_err("Client should error");
1065
1066        // Assert that the operation errored on the initial 401 and did not retry and get
1067        // the 200 (since the user provided retry classifier never retries)
1068        assert_full_error_contains!(res, "401");
1069    }
1070
1071    // since tokens are sent as headers, the tokens need to be valid header values
1072    #[tokio::test]
1073    async fn invalid_token() {
1074        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
1075            token_request("http://169.254.169.254", 21600),
1076            token_response(21600, "invalid\nheader\nvalue\0"),
1077        )]);
1078        let err = client.get("/latest/metadata").await.expect_err("no token");
1079        assert_full_error_contains!(err, "invalid token");
1080        http_client.assert_requests_match(&[]);
1081    }
1082
1083    #[tokio::test]
1084    async fn non_utf8_response() {
1085        let (client, http_client) = mock_imds_client(vec![
1086            ReplayEvent::new(
1087                token_request("http://169.254.169.254", 21600),
1088                token_response(21600, TOKEN_A).map(SdkBody::from),
1089            ),
1090            ReplayEvent::new(
1091                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
1092                http::Response::builder()
1093                    .status(200)
1094                    .body(SdkBody::from(vec![0xA0, 0xA1]))
1095                    .unwrap(),
1096            ),
1097        ]);
1098        let err = client.get("/latest/metadata").await.expect_err("no token");
1099        assert_full_error_contains!(err, "invalid UTF-8");
1100        http_client.assert_requests_match(&[]);
1101    }
1102
1103    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
1104    #[cfg_attr(windows, ignore)]
1105    /// Verify that the end-to-end real client has a 1-second connect timeout
1106    #[tokio::test]
1107    #[cfg(feature = "rustls")]
1108    async fn one_second_connect_timeout() {
1109        use crate::imds::client::ImdsError;
1110        let client = Client::builder()
1111            // 240.* can never be resolved
1112            .endpoint("http://240.0.0.0")
1113            .expect("valid uri")
1114            .build();
1115        let now = SystemTime::now();
1116        let resp = client
1117            .get("/latest/metadata")
1118            .await
1119            .expect_err("240.0.0.0 will never resolve");
1120        match resp {
1121            err @ ImdsError::FailedToLoadToken(_)
1122                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
1123            other => panic!(
1124                "wrong error, expected construction failure with TimedOutError inside: {}",
1125                DisplayErrorContext(&other)
1126            ),
1127        }
1128        let time_elapsed = now.elapsed().unwrap();
1129        assert!(
1130            time_elapsed > Duration::from_secs(1),
1131            "time_elapsed should be greater than 1s but was {:?}",
1132            time_elapsed
1133        );
1134        assert!(
1135            time_elapsed < Duration::from_secs(2),
1136            "time_elapsed should be less than 2s but was {:?}",
1137            time_elapsed
1138        );
1139    }
1140
1141    /// Retry classifier properly retries timeouts when configured to (meaning it takes ~30s to fail)
1142    #[tokio::test]
1143    async fn retry_connect_timeouts() {
1144        let http_client = StaticReplayClient::new(vec![]);
1145        let imds_client = super::Client::builder()
1146            .retry_classifier(SharedRetryClassifier::new(
1147                ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
1148            ))
1149            .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
1150            .operation_timeout(Duration::from_secs(1))
1151            .endpoint("http://240.0.0.0")
1152            .expect("valid uri")
1153            .build();
1154
1155        let now = SystemTime::now();
1156        let _res = imds_client
1157            .get("/latest/metadata")
1158            .await
1159            .expect_err("240.0.0.0 will never resolve");
1160        let time_elapsed: Duration = now.elapsed().unwrap();
1161
1162        assert!(
1163            time_elapsed > Duration::from_secs(1),
1164            "time_elapsed should be greater than 1s but was {:?}",
1165            time_elapsed
1166        );
1167
1168        assert!(
1169            time_elapsed < Duration::from_secs(2),
1170            "time_elapsed should be less than 2s but was {:?}",
1171            time_elapsed
1172        );
1173    }
1174
    /// One test case from `test-data/imds-config/imds-endpoint-tests.json`, deserialized with
    /// serde and executed by `check`.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        /// Environment variables handed to the provider config (via `Env::from`).
        env: HashMap<String, String>,
        /// Filesystem contents handed to the provider config (via `Fs::from_map`);
        /// presumably path -> file contents — confirm against the test data.
        fs: HashMap<String, String>,
        /// Optional endpoint passed to the builder's `endpoint(..)`.
        endpoint_override: Option<String>,
        /// Optional endpoint mode string, parsed and passed to the builder's `endpoint_mode(..)`.
        mode_override: Option<String>,
        /// `Ok(uri)`: the URI of the request the client is expected to send.
        /// `Err(msg)`: a substring expected in the client's error message.
        result: Result<String, String>,
        /// Human-readable description of the case, included in the failure message.
        docs: String,
    }
1184
1185    #[tokio::test]
1186    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1187        let _logs = capture_test_logs();
1188
1189        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1190        #[derive(Deserialize)]
1191        struct TestCases {
1192            tests: Vec<ImdsConfigTest>,
1193        }
1194
1195        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1196        let test_cases = test_cases.tests;
1197        for test in test_cases {
1198            check(test).await;
1199        }
1200        Ok(())
1201    }
1202
1203    async fn check(test_case: ImdsConfigTest) {
1204        let (http_client, watcher) = capture_request(None);
1205        let provider_config = ProviderConfig::no_configuration()
1206            .with_sleep_impl(TokioSleep::new())
1207            .with_env(Env::from(test_case.env))
1208            .with_fs(Fs::from_map(test_case.fs))
1209            .with_http_client(http_client);
1210        let mut imds_client = Client::builder().configure(&provider_config);
1211        if let Some(endpoint_override) = test_case.endpoint_override {
1212            imds_client = imds_client
1213                .endpoint(endpoint_override)
1214                .expect("invalid URI");
1215        }
1216
1217        if let Some(mode_override) = test_case.mode_override {
1218            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1219        }
1220
1221        let imds_client = imds_client.build();
1222        match &test_case.result {
1223            Ok(uri) => {
1224                // this request will fail, we just want to capture the endpoint configuration
1225                let _ = imds_client.get("/hello").await;
1226                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1227            }
1228            Err(expected) => {
1229                let err = imds_client.get("/hello").await.expect_err("it should fail");
1230                let message = format!("{}", DisplayErrorContext(&err));
1231                assert!(
1232                    message.contains(expected),
1233                    "{}\nexpected error: {expected}\nactual error: {message}",
1234                    test_case.docs
1235                );
1236            }
1237        };
1238    }
1239}